In [1]:
!pip install -qq nlp==0.2.0
!pip install -qq transformers==2.10.0

In [2]:
!pip install datasets



In [3]:
from transformers import (
    ReformerModelWithLMHead,
    ReformerTokenizer,
    ReformerConfig,
    Trainer,
    DataCollator,
    TrainingArguments,
)
import torch
import os
import re
import pandas as pd
import nlp
import datasets

In [4]:
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment",pad_token="<pad>",bos_token='<s>',eos_token='</s>')
model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')



In [5]:
FLICKR_PATH = ""
with open(os.path.join(FLICKR_PATH, "Flickr8k.token.txt")) as tokens:
    a=tokens.readlines()
tokens=[re.split("\n|#[0-9]\t",i)[1] for i in a]

In [6]:
my_dict={'lines': tokens}
dataset =datasets.Dataset.from_dict(my_dict)

In [7]:
sequence_length = 2 ** 19  # 524288

# define our map function to reduce the dataset to one sample
def flatten_and_tokenize(batch):
  all_input_text = ['<s>']+["</s> <s>".join(batch["lines"])]+['</s>']
  input_ids_dict = tokenizer.batch_encode_plus(
      all_input_text, pad_to_max_length=True, max_length=sequence_length
  )

    # duplicate data 8 times to have have 8 examples in dataset
  for key in input_ids_dict.keys():
    input_ids_dict[key] = [ [x] for x in input_ids_dict[key]][0]

  return input_ids_dict

# reduce the dataset and set batch_size to all inputs
dataset = dataset.map(
  flatten_and_tokenize, batched=True, batch_size=256, remove_columns=["lines"]
)

# prepare dataset to be in torch format
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

HBox(children=(FloatProgress(value=0.0, max=159.0), HTML(value='')))




In [8]:
#sequence_length=2**19
#def flatten_and_tokenize(batch):
#    all_input_text = ["".join(batch["lines"])]
##    input_ids_dict = tokenizer.batch_encode_plus(all_input_text, padding='max_length',max_length=sequence_length,truncation = True)
#    for key in input_ids_dict.keys():
#        input_ids_dict[key] = [[x[0] for x in input_ids_dict[key] if x]]*8
#    return input_ids_dict
#dataset = dataset.map(flatten_and_tokenize, batched=True, batch_size=1024, remove_columns=["lines"])
#dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

In [9]:
class ReformerCollator():
    def __init__(self, max_roll_length):
        self.max_roll_length = max_roll_length
    def collate_batch(self, features):
        # get random shift int
        random_shift_length = torch.randint(self.max_roll_length, (1,)).item()
        # shift input and mask
        rolled_input_ids = torch.roll(
            features[0]["input_ids"], random_shift_length
        ).unsqueeze(0)
        rolled_attention_mask = torch.roll(
            features[0]["attention_mask"], random_shift_length
        ).unsqueeze(0)
        return {
            "input_ids": rolled_input_ids,  # BS x SEQ_LEN
            "labels": rolled_input_ids,  # BS x SEQ_LEN
            "attention_mask": rolled_attention_mask,  # BS x SEQ_LEN
        }

In [10]:
non_padded_sequence_length = sequence_length - sum(dataset["attention_mask"][0])
data_collator = ReformerCollator(non_padded_sequence_length)

In [11]:
#for name, param in model.named_parameters():
#    if name not in ['reformer.encoder.layer_norm.weight','reformer.encoder.layer_norm.bias',
#                'lm_head.bias','lm_head.decoder.weight']:
#        param.requires_grad = False

In [12]:
training_args = {
    "learning_rate": 1e-3,
    "max_steps": 100,
    "do_train": True,
    "evaluate_during_training": True,
    "gradient_accumulation_steps": 8,
    "logging_steps": 50,
    "warmup_steps": 500,
    "weight_decay": 0.001,
    "fp16": False,
    "save_steps": 50,
    "output_dir": "./"
}

training_args = TrainingArguments(**training_args)

In [13]:
def compute_metrics(pred):
    non_padded_indices = (pred.label_ids != -100)

    # correctly shift labels and pred as it's done in forward()
    labels = pred.label_ids[..., 1:][non_padded_indices[..., 1:]]
    pred = np.argmax(pred.predictions[:, :-1], axis=-1)[non_padded_indices[..., :-1]]

    acc = np.mean(np.asarray(pred == labels), dtype=np.float)
    return {"accuracy": acc}

In [14]:
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    train_dataset=dataset,
    eval_dataset=dataset,
    prediction_loss_only=True,
)

# train
trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=51.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…

  return torch.tensor(x, **format_kwargs)





HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…

{"loss": 3.3544596575336074, "learning_rate": 0.0001, "epoch": 24.8, "step": 50}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=20.0, style=ProgressStyle(description_wi…


{"eval_loss": 5.926865887886379e-05, "epoch": 24.8, "step": 50}







HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…

{"loss": 7.298398649254523e-05, "learning_rate": 0.0002, "epoch": 49.8, "step": 100}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=20.0, style=ProgressStyle(description_wi…


{"eval_loss": 5.738620566262398e-05, "epoch": 49.8, "step": 100}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=20.0, style=ProgressStyle(description_wid…





TrainOutput(global_step=101, training_loss=1.6606605506620433)

In [15]:
torch.save(model, './finetuned_full.pt')

In [16]:
model.eval()



ReformerModelWithLMHead(
  (reformer): ReformerModel(
    (embeddings): ReformerEmbeddings(
      (word_embeddings): Embedding(320, 256)
      (position_embeddings): AxialPositionEmbeddings(
        (weights): ParameterList(
            (0): Parameter containing: [torch.cuda.FloatTensor of size 512x1x64 (GPU 0)]
            (1): Parameter containing: [torch.cuda.FloatTensor of size 1x1024x192 (GPU 0)]
        )
      )
    )
    (encoder): ReformerEncoder(
      (layers): ModuleList(
        (0): ReformerLayer(
          (attention): ReformerAttention(
            (layer_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
            (self_attention): LocalSelfAttention(
              (query): Linear(in_features=256, out_features=128, bias=False)
              (key): Linear(in_features=256, out_features=128, bias=False)
              (value): Linear(in_features=256, out_features=128, bias=False)
            )
            (output): ReformerSelfOutput(
              (dense): Lin