In [ ]:
from transformers import GPT2LMHeadModel, GPT2Config
from datasets import load_dataset
from torch import nn
from tqdm import tqdm
from transformers import get_scheduler
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
import torch
from torch.utils.data import DataLoader
# Tokenizer class used to load the custom tokenizer
from transformers import PreTrainedTokenizerFast

In [ ]:
# Length (in tokens) of each training sequence; kept equal to n_positions below
block_size = 1024

# Training schedule
num_epochs = 50
batch_size = 64

# Model hyperparameters for a from-scratch GPT-2.
# NOTE(review): vocab_size and pad_token_id are hard-coded here; they should
# agree with the custom tokenizer loaded later — confirm.
config = GPT2Config(
    n_layer=12,          # Number of transformer layers
    n_head=12,           # Number of attention heads
    n_embd=768,          # Embedding size
    n_positions=1024,    # Maximum sequence length
    n_ctx=1024,          # Context window size
    vocab_size=50257,    # Size of the vocabulary
    pad_token_id=50256,  # Padding token ID (original note: same as eos_token)
)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

In [ ]:
# Load the tokenizer saved under the local 'custom_tokenizer' directory/name
tokenizer = PreTrainedTokenizerFast.from_pretrained('custom_tokenizer')

# Build a GPT-2 language model from `config` (no pretrained weights are loaded here)
model = GPT2LMHeadModel(config=config)

In [ ]:
# Tokenization function
def tokenize_function(examples):
    """Tokenize the 'text' field of a batch with the module-level tokenizer.

    Args:
        examples: Mapping that provides the raw strings under the 'text' key.

    Returns:
        Whatever the tokenizer returns for those texts, with the special-tokens
        mask requested in addition to the default outputs.
    """
    texts = examples['text']
    encoded = tokenizer(texts, return_special_tokens_mask=True)
    return encoded

# Tokenize the dataset in batches and drop the raw 'text' column afterwards.
# NOTE(review): `dataset` is not defined in any visible cell (load_dataset is
# imported but never called) — confirm where it is created.
tokenized_datasets = dataset.map(
    function=tokenize_function,
    batched=True,
    remove_columns=['text'],
)

In [ ]:
# Function to group texts
def group_texts(examples):
    """Concatenate a batch of tokenized examples and cut it into fixed-size blocks.

    Every column in ``examples`` is flattened into one long sequence, then split
    into consecutive chunks of ``block_size`` items (module-level setting). The
    chunk boundaries are derived from the length of the flattened 'input_ids'
    column; trailing items that do not fill a whole block are dropped.

    Args:
        examples: Mapping from column name to a list of per-example sequences.
            Must contain an 'input_ids' column.

    Returns:
        dict: The same columns, each as a list of ``block_size``-long chunks,
        plus a 'labels' column that mirrors 'input_ids' (causal LM targets).
    """
    # Local import keeps this cell self-contained.
    from itertools import chain

    # chain.from_iterable flattens in linear time; sum(lists, []) re-copies the
    # accumulated list on every addition and is quadratic in the batch size.
    concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated['input_ids'])
    # Drop the last chunk if it's smaller than block_size
    total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # Labels are the same as input_ids for causal LM
    result['labels'] = result['input_ids'].copy()
    return result

# Re-chunk the tokenized data into block_size-long examples. Because the map
# is batched, group_texts concatenates (and trims the remainder of) each
# mapped batch independently.
lm_datasets = tokenized_datasets.map(
    function=group_texts,
    batched=True,
)

In [ ]:
# Convert datasets to PyTorch format.
# Without a tensor format each example holds plain Python lists, and the
# DataLoader's default collation turns a batch of such lists into a list of
# per-position tensors instead of one (batch, seq) tensor, so calls such as
# batch['input_ids'].to(device) would fail.
train_dataset = lm_datasets['train'].with_format('torch')
eval_dataset = lm_datasets['validation'].with_format('torch')

# Create DataLoaders (only the training split is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)

In [ ]:
# Create tensorboard logger
writer = SummaryWriter(log_dir='runs/gpt2_text_generation')

# Define optimizer and a linear learning-rate schedule without warmup,
# spanning every optimizer step of the run
optimizer = AdamW(model.parameters(), lr=5e-5)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name='linear', optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

# NOTE: this criterion is kept for reference but is not used below; the loop
# relies on the loss the model returns when `labels` is passed to it.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Training loop
progress_bar = tqdm(range(num_training_steps))
model.train()
model.to(device)

train_loss = []  # mean training loss per epoch
eval_loss = []   # mean evaluation loss per epoch
global_step = 0  # number of optimizer steps taken so far (TensorBoard x-axis)
try:
    for epoch in range(num_epochs):
        # ---- Training pass ----
        loss_sum = 0.0
        num_sequences = 0
        for batch in train_dataloader:
            # Get inputs and labels
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass; GPT-2 directly computes the loss if labels are provided
            outputs = model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels)
            loss = outputs.loss

            # Backward pass and parameter / schedule update
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Log the per-batch loss against the global training step
            batch_loss = loss.item()
            writer.add_scalar('Loss/train', batch_loss, global_step)
            global_step += 1
            progress_bar.update(1)

            # Weight by batch size so a smaller final batch does not skew the mean
            batch_len = input_ids.size(0)
            loss_sum += batch_loss * batch_len
            num_sequences += batch_len

        # Store the mean (not the sum) so the value is independent of dataset size
        train_loss.append(loss_sum / max(num_sequences, 1))

        # ---- Evaluation pass ----
        model.eval()
        loss_sum = 0.0
        num_sequences = 0
        with torch.no_grad():
            for batch in eval_dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

                batch_len = input_ids.size(0)
                loss_sum += outputs.loss.item() * batch_len
                num_sequences += batch_len

        mean_eval_loss = loss_sum / max(num_sequences, 1)
        eval_loss.append(mean_eval_loss)
        # One point per epoch, placed on the same step axis as 'Loss/train'
        # so the two curves can be compared directly
        writer.add_scalar('Loss/eval', mean_eval_loss, global_step)

        print(f'End of epoch: {epoch}')
        print(f'Training loss: {train_loss[-1]}')
        print(f'Eval loss: {eval_loss[-1]}')
        print(' ')

        # Back to training mode for the next epoch
        model.train()

    print('Training completed!')
finally:
    # Release the progress bar and flush TensorBoard events even when the run
    # is interrupted or raises
    progress_bar.close()
    writer.close()

In [ ]:
# Persist the trained model object to 'gpt_model.ph'.
# NOTE(review): this serializes the whole model object rather than its
# state_dict, and the '.ph' extension may be a typo for '.pth' — confirm
# what the loading side expects before changing either.
torch.save(obj=model, f='gpt_model.ph')