In [1]:
from pathlib import Path
import torch


model_checkpoint = 'distilgpt2'
block_size = 128

data_path = Path('./data/Shakespeare')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset

## Loading data

In [2]:
from datasets import DatasetDict, load_dataset


datasets = load_dataset('text', data_dir=data_path)

train_valid_split = datasets["train"].train_test_split(test_size=0.1)
datasets = DatasetDict({
    'train': train_valid_split['train'],
    'validation': train_valid_split['test']
})
print(f"datasets:\n{datasets}")
datasets["train"][:5]['text']

  from .autonotebook import tqdm as notebook_tqdm


datasets:
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 38967
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 4330
    })
})


['the map of my microcosm, follows it that I am known',
 "Shall bear along impawn'd, away to-night!",
 'My evils conjured to remembrance and',
 'POLIXENES:',
 '']

## Preparing data
Tokenizing & tranforming to sequences of fixed length.

In [3]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

In [4]:
def tokenize_function(examples, tokenizer=tokenizer):
    return tokenizer(examples["text"])

In [5]:
tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])

Map (num_proc=4): 100%|██████████| 38967/38967 [00:06<00:00, 6090.40 examples/s]
Map (num_proc=4): 100%|██████████| 4330/4330 [00:07<00:00, 563.98 examples/s]


In [6]:
tokenized_datasets["train"][1]

{'input_ids': [2484,
  439,
  6842,
  1863,
  848,
  3832,
  1549,
  11,
  1497,
  284,
  12,
  3847,
  0],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [7]:
def group_texts(examples, block_size=block_size):
    f"""
    For each train, test, split - concatenates all texts and
    divides them into sequence of length {block_size}.
    Extra text at last which couldnt make a full block is discarded.
    """
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [8]:
processed_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=512,
    num_proc=4,
)

Map (num_proc=4): 100%|██████████| 38967/38967 [00:03<00:00, 10091.82 examples/s]
Map (num_proc=4): 100%|██████████| 4330/4330 [00:03<00:00, 1126.17 examples/s]


In [9]:
print(f"Tokenized & processed dataset:\n{processed_datasets}")
tokenizer.decode(processed_datasets["train"][0]['input_ids'])

Tokenized & processed dataset:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2243
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 247
    })
})


"the map of my microcosm, follows it that I am knownShall bear along impawn'd, away to-night!My evils conjured to remembrance andPOLIXENES:The royal fool thou copest with,--Quit their own part, and in obsequious fondnessOff with the crown, and with the crown his head;Where is the duke? 'tis he should hear me speak.Ah, know you not the city favours them,And say 'Alas, it was a piteous deed!'Who bare my letter, then, to Romeo?He did not come to hope"

# Model

In [10]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

In [11]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


## Pre-train performance

In [12]:
def generate_sample(seed: str, model):
    input_ids = tokenizer.encode(seed, return_tensors='pt').to(model.device)
    attention_mask = torch.ones(input_ids.shape, device=model.device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=50,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

In [13]:
print(generate_sample("JULIET:", model=model))

JULIET: The U.S. government has been accused of using a U.S. drone to kill suspected militants in the southern Indian state of Andhra Pradesh, a state government official said on Wednesday.








# Training

In [14]:
from transformers import Trainer, TrainingArguments

In [15]:
model_name = model_checkpoint.split("/")[-1]
training_args = TrainingArguments(
    f"{model_name}-finetuned-wikitext2",
    eval_strategy = "epoch",
    learning_rate = 2e-5,
    weight_decay = 0.01
)

In [16]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["validation"],
)

In [17]:
USE_FLASH_ATTENTION=1
trainer.train()

  attn_output = torch.nn.functional.scaled_dot_product_attention(
                                                 
 33%|███▎      | 281/843 [01:05<01:57,  4.79it/s]

{'eval_loss': 4.753508567810059, 'eval_runtime': 2.2207, 'eval_samples_per_second': 111.226, 'eval_steps_per_second': 13.959, 'epoch': 1.0}


 59%|█████▉    | 500/843 [01:57<01:17,  4.41it/s]

{'loss': 4.9407, 'grad_norm': 6.589237689971924, 'learning_rate': 8.137603795966786e-06, 'epoch': 1.78}


                                                 
 67%|██████▋   | 562/843 [02:15<00:53,  5.21it/s]

{'eval_loss': 4.6606669425964355, 'eval_runtime': 2.0283, 'eval_samples_per_second': 121.778, 'eval_steps_per_second': 15.284, 'epoch': 2.0}


                                                 
100%|██████████| 843/843 [03:22<00:00,  4.16it/s]

{'eval_loss': 4.640899658203125, 'eval_runtime': 2.087, 'eval_samples_per_second': 118.35, 'eval_steps_per_second': 14.854, 'epoch': 3.0}
{'train_runtime': 202.7359, 'train_samples_per_second': 33.191, 'train_steps_per_second': 4.158, 'train_loss': 4.838235260045967, 'epoch': 3.0}





TrainOutput(global_step=843, training_loss=4.838235260045967, metrics={'train_runtime': 202.7359, 'train_samples_per_second': 33.191, 'train_steps_per_second': 4.158, 'total_flos': 219783229341696.0, 'train_loss': 4.838235260045967, 'epoch': 3.0})

# Generation

In [22]:
print(generate_sample("JULIET:", model=model))

To be or not, and I'll be, and I'll be.And I'll be, and I'll be, and I'll be,And I'll be, and I'll be,And I'll be, and I'll be,
