## Training

In [1]:
# Notebook-wide settings.
# NOTE(review): `device` is not read by any cell visible here (the Trainer picks
# its own device and the perplexity call passes a literal) — confirm it is still needed.
device = "cpu"
# Hub identifier of the checkpoint whose tokenizer is loaded below.
model_id = "gpt2"

### Dataset

In [2]:
from datasets import load_dataset

# Note: this is the same dataset as https://pytorch.org/text/stable/datasets.html#id22

# Load the "wikitext-2-v1" configuration of the "wikitext" dataset.
raw_dataset = load_dataset("wikitext", "wikitext-2-v1")

Found cached dataset wikitext (C:/Users/Atul Gandre/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

### Tokenizer

In [3]:
from transformers import GPT2TokenizerFast

# Load the tokenizer published under `model_id`.
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

### Model

In [4]:
from transformers import GPT2LMHeadModel, AutoConfig

# Reuse the architecture hyper-parameters published for `model_id` (the same
# identifier the tokenizer was loaded from, so the two cannot drift apart),
# overriding the vocabulary size and special-token ids to match our tokenizer.
# NOTE(review): recent transformers releases size GPT-2's position embeddings
# from `n_positions`, not `n_ctx` — confirm `n_ctx=128` has the intended effect
# on the installed version.
config = AutoConfig.from_pretrained(
    model_id,
    vocab_size=len(tokenizer),
    n_ctx=128,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Building the model from a config (rather than `from_pretrained`) gives
# freshly initialised weights: this notebook trains from scratch.
model = GPT2LMHeadModel(config)
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 size: {model_size/1000**2:.1f}M parameters")

GPT-2 size: 122.4M parameters


In [5]:
from transformers import DataCollatorForLanguageModeling

# Reuse the end-of-sequence token as the padding token
# (presumably because this tokenizer defines no pad token of its own — TODO confirm).
tokenizer.pad_token = tokenizer.eos_token
# mlm=False: disable masked-language-model masking in the collator.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

### Preprocessing

In [7]:
def _has_text(example):
    """Return True for rows whose 'text' field is non-empty."""
    return len(example['text']) > 0


def _tokenize_batch(batch):
    """Tokenize a batch of rows, truncating and padding each one to 128 tokens."""
    return tokenizer(
        batch['text'],
        max_length=128,
        truncation=True,
        padding='max_length',
        return_tensors="pt",
        return_attention_mask=True,
    )


# Stage 1: drop empty rows. Stage 2: tokenize in batches.
non_empty_dataset = raw_dataset.filter(_has_text)
tokenized_dataset = non_empty_dataset.map(_tokenize_batch, batched=True)

# Stage 3: drop the raw text column and have indexing return torch tensors.
processed_dataset = tokenized_dataset.remove_columns('text')
processed_dataset.set_format('torch')

Loading cached processed dataset at C:\Users\Atul Gandre\.cache\huggingface\datasets\wikitext\wikitext-2-v1\1.0.0\a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126\cache-0ef9e62734397d3d.arrow
Loading cached processed dataset at C:\Users\Atul Gandre\.cache\huggingface\datasets\wikitext\wikitext-2-v1\1.0.0\a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126\cache-516f1284a2d69512.arrow
Loading cached processed dataset at C:\Users\Atul Gandre\.cache\huggingface\datasets\wikitext\wikitext-2-v1\1.0.0\a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126\cache-e234ab76a2a238b2.arrow
Loading cached processed dataset at C:\Users\Atul Gandre\.cache\huggingface\datasets\wikitext\wikitext-2-v1\1.0.0\a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126\cache-dffbbba7d859db17.arrow
Loading cached processed dataset at C:\Users\Atul Gandre\.cache\huggingface\datasets\wikitext\wikitext-2-v1\1.0.0\a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f

Map:   0%|          | 0/2461 [00:00<?, ? examples/s]

In [8]:
import torch
from transformers import Trainer, TrainingArguments

# NOTE(review): evaluation, logging and checkpointing all fire every 5_000
# optimizer steps, and one optimizer step consumes 8 * 8 = 64 sequences per
# device. Confirm the run is long enough to reach that count; otherwise no
# evaluation, log line or intermediate checkpoint is ever produced.
args = TrainingArguments(
    output_dir="wikitext-2-gpt2",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="steps",
    eval_steps=5_000,
    logging_steps=5_000,
    gradient_accumulation_steps=8,
    num_train_epochs=10,
    weight_decay=0.1,
    warmup_steps=1_000,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=5_000,
    # Mixed-precision training needs a CUDA device; requesting it unconditionally
    # makes TrainingArguments raise on a CPU-only machine, so enable it only
    # when a GPU is actually present.
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
)

# Train on the "train" split and evaluate on the "validation" split.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["validation"],
)

In [None]:
# Run the training loop configured above.
trainer.train()

In [13]:
# Persist the model; with no argument this presumably writes to the
# `output_dir` given to TrainingArguments ("wikitext-2-gpt2") — TODO confirm.
trainer.save_model()

## Perplexity

In [None]:
from datasets import Dataset
from torch.utils.data import DataLoader
import torch



from typing import Optional


def perplexity(model, dataloader: DataLoader, device: str, max_batches: Optional[int] = 10):
    """Estimate the perplexity of a causal language model.

    For every sequence the mean negative log-likelihood of its non-padded
    next-token targets is computed; the result is ``exp`` of the average of
    those per-sequence means (each sequence has equal weight, regardless of
    its length).

    Args:
        model: Causal LM called as ``model(input_ids, attention_mask=...)``
            whose output exposes ``.logits``. Assumed to be a
            ``torch.nn.Module`` (its train/eval mode is toggled).
        dataloader: Iterable of batches, each a mapping with ``'input_ids'``
            and ``'attention_mask'`` tensors of identical 2-D shape.
        device: Device the batch tensors are moved to; must be the device the
            model lives on.
        max_batches: Evaluate at most this many batches; ``None`` evaluates
            the whole dataloader. Defaults to 10.

    Returns:
        A scalar tensor holding the perplexity.

    Raises:
        ValueError: If no sequence with at least one valid target was seen.
    """
    # Dropout must be off for a meaningful, deterministic estimate; remember the
    # previous mode so the caller's model is left exactly as it was handed in.
    was_training = model.training
    model.eval()

    per_sequence_nll = []
    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                if max_batches is not None and batch_idx >= max_batches:
                    break

                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                # No `labels=` here: the loss is computed below, so letting the
                # model compute its own would be wasted work.
                logits = model(input_ids, attention_mask=attention_mask).logits

                # Position t predicts token t + 1.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                shift_mask = attention_mask[..., 1:].to(shift_logits.dtype)

                # cross_entropy expects the class dimension second: (batch, vocab, seq).
                token_nll = torch.nn.functional.cross_entropy(
                    shift_logits.permute(0, 2, 1),
                    shift_labels,
                    reduction='none',
                ) * shift_mask

                # A sequence with a single real token has no target; dividing by
                # its zero count would yield NaN and poison the final mean.
                target_counts = shift_mask.sum(dim=-1)
                has_targets = target_counts > 0
                per_sequence_nll.append(
                    token_nll.sum(dim=-1)[has_targets] / target_counts[has_targets]
                )
    finally:
        model.train(was_training)

    if not per_sequence_nll:
        raise ValueError("perplexity() received no batches to evaluate")

    # cat (not stack) so a smaller final batch does not break the reduction.
    all_nll = torch.cat(per_sequence_nll)
    if all_nll.numel() == 0:
        raise ValueError("perplexity() found no sequence with a valid target token")

    return torch.exp(all_nll.mean())

dataloader = DataLoader(processed_dataset['validation'], batch_size=4)
# Send the batches to wherever the model currently lives: the Trainer may have
# moved it to a GPU, in which case a hard-coded "cpu" would raise a device
# mismatch inside the forward pass.
perplexity(model, dataloader, model.device)