# PyTorch training script
The following notebook will demonstrate how I trained TinyMistral.

# Installing libraries
We'll install every dependency this notebook needs.

In [None]:
# I'll also be installing one of my other packages for a model that I have private.
# This package contains nice utilities that I'd rather not code again.
%pip install --upgrade sentia datasets transformers evaluate rouge_score sacrebleu

import math
from dataclasses import dataclass, field
from typing import Optional

import sacrebleu
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from datasets import load_dataset
from evaluate import load as load_metric
from sentia import SENTIAForCausalLM
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# String to dtype mapping

We'll define a constant that maps dtype names to torch dtypes, so a dtype can be chosen by name (e.g. from the command line). This would be more relevant if this were a script rather than a notebook.

In [None]:
# Maps the dtype names accepted by TrainArgs.dtype to the corresponding torch dtypes.
STRING_TO_DTYPE_MAPPING = dict(
    bfloat16=torch.bfloat16,
    float32=torch.float32,
    float16=torch.float16,
    float64=torch.float64,
)


# Dataset classes

We'll define the dataset classes used for data preprocessing.

In [None]:
class ConversationDataset(Dataset):
    """Single-turn chat examples tokenized to a fixed length for causal-LM training.

    Each row is rendered as
    ``"<|USER|> {user} <|ASSISTANT|> {assistant} <|endoftext|>"``, tokenized with
    truncation to ``max_length`` and right-padded to exactly ``max_length`` ids.
    ``labels`` is an unshifted copy of ``input_ids``; presumably the model shifts
    them internally (the Hugging Face ``*ForCausalLM`` convention).

    Two row schemas are supported:
      * InstructMix-style rows with ``"Input"`` / ``"Output"`` keys.
      * MMLU-style rows with ``"question"``, ``"choices"`` and an integer
        ``"answer"`` index into ``choices`` (used when ``"Input"``/``"Output"``
        is missing).

    NOTE(review): padded positions are not masked with -100, so they contribute
    to the loss; train()/evaluate() decode ``labels`` directly, so any masking
    has to be coordinated with them.

    Args:
        tokenizer: Tokenizer exposing ``encode``, ``pad_token_id`` and ``eos_token_id``.
        max_length: Length of every returned ``input_ids``/``labels`` sequence.
        data: Indexable collection of row dicts (e.g. a ``datasets.Dataset``).
        device: Device on which the returned tensors are created.
    """

    def __init__(self, tokenizer, max_length=512, data=None, device="cuda"):
        self.data = data
        self.device = device
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _pad_id(self):
        """Return the id used for right padding (EOS when the tokenizer has no pad token)."""
        # Mistral/Llama tokenizers may ship without a pad token; padding with None
        # would make torch.tensor fail.
        pad_id = self.tokenizer.pad_token_id
        return self.tokenizer.eos_token_id if pad_id is None else pad_id

    def __getitem__(self, idx):
        row = self.data[idx]
        try:
            # Most of the time I'll be using InstructMix for instruction-tuning
            user = row["Input"]
            assistant = row["Output"]
        except KeyError:
            # If I'm using MMLU for evaluation
            user = row["question"]
            assistant = row["choices"][row["answer"]]

        text = f"<|USER|> {user} <|ASSISTANT|> {assistant} <|endoftext|>"
        # Input and target are the same text, so tokenize once and reuse the ids.
        ids = self.tokenizer.encode(text, add_special_tokens=True, max_length=self.max_length, truncation=True)
        ids = list(ids) + [self._pad_id()] * (self.max_length - len(ids))

        return {
            "input_ids": torch.tensor(ids, dtype=torch.int64, device=self.device),
            "labels": torch.tensor(ids, dtype=torch.int64, device=self.device),
            # train() and evaluate() read this key from every batch.
            "target_text": text,
        }
class CompletionDataset(Dataset):
    """Plain-text examples tokenized to a fixed length for causal-LM training.

    Each row's ``"text"`` is suffixed with ``" " + tokenizer.eos_token``, tokenized
    with truncation to ``max_length`` and right-padded to exactly ``max_length`` ids.
    ``labels`` is an unshifted copy of ``input_ids``; presumably the model shifts
    them internally (the Hugging Face ``*ForCausalLM`` convention).

    NOTE(review): padded positions are not masked with -100, so they contribute
    to the loss; train()/evaluate() decode ``labels`` directly, so any masking
    has to be coordinated with them.

    Args:
        tokenizer: Tokenizer exposing ``encode``, ``eos_token``, ``pad_token_id``
            and ``eos_token_id``.
        data: Indexable collection of row dicts with a ``"text"`` key.
        max_length: Length of every returned ``input_ids``/``labels`` sequence.
        device: Device on which the returned tensors are created.
    """

    def __init__(self, tokenizer, data, max_length=256, device="cuda"):
        self.data = data
        self.device = device
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _pad_id(self):
        """Return the id used for right padding (EOS when the tokenizer has no pad token)."""
        # Mistral/Llama tokenizers may ship without a pad token; padding with None
        # would make torch.tensor fail.
        pad_id = self.tokenizer.pad_token_id
        return self.tokenizer.eos_token_id if pad_id is None else pad_id

    def __getitem__(self, idx):
        text = f"{self.data[idx]['text']} {self.tokenizer.eos_token}"
        # Input and target are the same text, so tokenize once and reuse the ids.
        ids = self.tokenizer.encode(text, add_special_tokens=True, max_length=self.max_length, truncation=True)
        ids = list(ids) + [self._pad_id()] * (self.max_length - len(ids))
        return {
            "input_ids": torch.tensor(ids, dtype=torch.int64, device=self.device),
            "labels": torch.tensor(ids, dtype=torch.int64, device=self.device),
            # train() and evaluate() read this key from every batch.
            "target_text": text,
        }

# Training loop

We'll define the training loop, with a few metrics to keep track of the model's learning.

In [None]:
def _batch_bleu_and_accuracy(predictions, target_ids, tokenizer):
    """Return (mean sentence-BLEU, mean SENTIA accuracy) for one batch of predicted/target ids."""
    predictions_str = [tokenizer.decode(pred, skip_special_tokens=True) for pred in predictions.tolist()]
    targets_str = [tokenizer.decode(tgt, skip_special_tokens=True) for tgt in target_ids.tolist()]

    bleu_scores = [
        sacrebleu.sentence_bleu(pred_str, [target_str]).score
        for pred_str, target_str in zip(predictions_str, targets_str)
    ]
    accuracy_scores = [
        SENTIAForCausalLM.calculate_accuracy(pred_id, target_id)
        for pred_id, target_id in zip(predictions, target_ids)
    ]
    return sum(bleu_scores) / len(bleu_scores), sum(accuracy_scores) / len(accuracy_scores)


def train(model, dataloader, optimizer, tokenizer, device="cuda"):
    """Run one training epoch and return the mean training loss.

    For each batch: forward pass with ``labels`` (the model computes the loss),
    backward, ``optimizer.step()``, then report loss, perplexity (``exp(loss)``),
    mean sentence-BLEU, SENTIA accuracy and SENTIA net reward to stdout, and to
    wandb when a run is active.

    Args:
        model: Causal LM whose output starts with ``(loss, logits)`` when called
            with ``labels``.
        dataloader: Yields dicts with ``"input_ids"`` and ``"labels"`` tensors.
        optimizer: Optimizer over ``model``'s parameters.
        tokenizer: Tokenizer used to decode predictions and targets for BLEU.
        device: Device the model and each batch are moved to.

    Returns:
        Mean loss over all batches, or ``nan`` if ``dataloader`` yields no batches.
    """
    model.train()
    model.to(device=device)
    total_loss = 0.0
    num_batches = len(dataloader)

    for i, batch in enumerate(tqdm(dataloader)):
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["labels"].to(device)

        # Generate the output and calculate the loss
        outputs = model(input_ids=input_ids, labels=target_ids)
        loss, logits = outputs[:2]

        # argmax over logits equals argmax over their softmax, so the softmax is skipped.
        # NOTE(review): logits at position t predict token t+1, but predictions are
        # compared with targets position-by-position; confirm this is the alignment
        # SENTIAForCausalLM.calculate_accuracy / get_reward expect.
        predictions = logits.detach().argmax(dim=-1)
        bleu, accuracy = _batch_bleu_and_accuracy(predictions, target_ids, tokenizer)

        # This reward can be used for RLHF, but I prefer using it as a metric.
        # The highest value is typically around 2x the sequence length.
        # The lowest value is typically about the negative of the sequence length.
        reward, penalty = SENTIAForCausalLM.get_reward(predictions[0].tolist(), target_ids[0].tolist(), bleu)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_value = loss.item()
        # torch.exp returns inf instead of raising OverflowError on a huge loss.
        perplexity = torch.exp(loss.detach()).item()
        total_loss += loss_value

        # wandb.run is None until wandb.init() has been called.
        if wandb.run is not None:
            wandb.log({"loss": loss_value, "bleu": bleu, "perplexity": perplexity, "accuracy": accuracy})
        print(
            f"Batch {i + 1}/{num_batches}: Loss - {loss_value:.4f}, NetReward - {reward - penalty:.4f}, BLEU - {bleu:.4f}, Perplexity - {perplexity}, Accuracy - {accuracy}")

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

# Evaluation loop

We'll define an evaluation loop that tracks the model's performance with several metrics: loss, perplexity, BLEU, and ROUGE.

In [None]:
def evaluate(model, val_loader, tokenizer, use_cuda=True):
    """Evaluate ``model`` on ``val_loader``; return loss, perplexity, BLEU and ROUGE.

    Predictions are the per-position argmax of the logits from one forward pass
    (teacher forcing), decoded with ``skip_special_tokens=True`` and scored
    against the decoded ``labels``.

    Args:
        model: Causal LM whose output exposes ``.loss`` and ``.logits``.
        val_loader: Yields dicts with ``"input_ids"`` and ``"labels"`` tensors;
            any other keys (e.g. ``"target_text"``) are ignored.
        tokenizer: Tokenizer used to decode predictions and labels.
        use_cuda: Evaluate on CUDA when it is available; otherwise on CPU.

    Returns:
        Dict with ``val_loss``, ``val_perplexity``, ``val_bleu`` and ``val_rouge``
        (the dict returned by the ROUGE metric). Also logged to wandb when a run
        is active.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')
    model.to(device)

    bleu_metric = load_metric('bleu')
    rouge_metric = load_metric('rouge')

    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, labels=labels)
            total_loss += outputs.loss.item()
            num_batches += 1

            predictions = tokenizer.batch_decode(outputs.logits.argmax(dim=-1), skip_special_tokens=True)
            references = tokenizer.batch_decode(labels, skip_special_tokens=True)

            # BLEU takes a list of candidate references per prediction; ROUGE takes plain strings.
            bleu_metric.add_batch(predictions=predictions, references=[[ref] for ref in references])
            rouge_metric.add_batch(predictions=predictions, references=references)

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")

    # Everything was already added via add_batch; passing the data to compute()
    # again would score every example twice.
    bleu_score = bleu_metric.compute()
    rouge_score = rouge_metric.compute()

    # Perplexity assumes the loss is the mean token negative log-likelihood.
    mean_loss = total_loss / num_batches
    perplexity = torch.exp(torch.tensor(mean_loss)).item()

    metrics = {
        'val_loss': mean_loss,
        'val_perplexity': perplexity,
        'val_bleu': bleu_score['bleu'],
        'val_rouge': rouge_score,
    }
    # wandb.log takes the metrics dict positionally; wandb.run is None until wandb.init().
    if wandb.run is not None:
        wandb.log(metrics)

    return metrics

# TrainArgs

Here we'll define a dataclass that holds the configuration for training.

In [None]:
@dataclass
class TrainArgs:
    """Settings for one training run, consumed by ``main``.

    Datasets are loaded with ``load_dataset(name, config, split)``; split strings
    may use slice syntax such as ``"train[:100]"``. ``dtype`` is looked up by name
    in ``STRING_TO_DTYPE_MAPPING``.
    """

    # --- Model and optimisation ---
    model: str = "Locutusque/TinyMistral-248M"  # Pretrained model name or path
    batch_size: int = 8
    num_epochs: int = 3
    learning_rate: float = 5e-5
    # Evaluated once, when the class body runs.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Training data ---
    dataset: str = "Skylion007/openwebtext"
    datasetconfig: Optional[str] = None  # Dataset configuration name, if the dataset needs one
    split: str = "train"

    # --- Validation data ---
    val_dataset: str = "Skylion007/openwebtext"
    val_datasetconfig: Optional[str] = None
    val_split: str = "validation"

    # --- Output ---
    save_dir: str = "./saved_models"

    # --- Precision ---
    dtype: str = "float32"  # One of the STRING_TO_DTYPE_MAPPING keys

# Define the main function

This is where the model and data are loaded and the training/evaluation loop runs.

In [None]:
def _save_checkpoint(model, tokenizer, save_dir):
    """Write the model and the extended tokenizer to ``save_dir`` with ``save_pretrained``."""
    model.save_pretrained(save_dir)
    # The tokenizer gained <|USER|>/<|ASSISTANT|> tokens and the embeddings were
    # resized to match, so the tokenizer must be saved alongside the model.
    tokenizer.save_pretrained(save_dir)


def main(args: TrainArgs):
    """Fine-tune ``args.model`` on ``args.dataset``, evaluating on ``args.val_dataset`` each epoch.

    Adds the ``<|USER|>``/``<|ASSISTANT|>`` special tokens, wraps both splits in
    ``ConversationDataset`` (max_length=256), trains with Adamax for
    ``args.num_epochs`` epochs and saves the model and tokenizer to
    ``args.save_dir`` when training finishes or is interrupted with Ctrl-C.

    Raises:
        ValueError: If ``args.dtype`` is not a key of ``STRING_TO_DTYPE_MAPPING``.
    """
    # Validate before any expensive download.
    if args.dtype not in STRING_TO_DTYPE_MAPPING:
        raise ValueError(f"Unknown dtype {args.dtype!r}; expected one of {sorted(STRING_TO_DTYPE_MAPPING)}")
    dtype = STRING_TO_DTYPE_MAPPING[args.dtype]

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.add_special_tokens({"additional_special_tokens": ["<|USER|>", "<|ASSISTANT|>"]})
    if tokenizer.pad_token is None:
        # The datasets pad with pad_token_id; some Mistral-family tokenizers define none.
        tokenizer.pad_token = tokenizer.eos_token

    train_raw = load_dataset(args.dataset, args.datasetconfig, split=args.split)
    val_raw = load_dataset(args.val_dataset, args.val_datasetconfig, split=args.val_split)
    # Uncomment this if you want to use wandb.
    #wandb.init(dir="", project="")
    train_data = ConversationDataset(tokenizer, data=train_raw, max_length=256, device=args.device)
    val_data = ConversationDataset(tokenizer, data=val_raw, max_length=256, device=args.device)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=args.batch_size)

    # Initialize the model; grow the embedding matrix to cover the added tokens.
    model = AutoModelForCausalLM.from_pretrained(args.model)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device, dtype=dtype)

    optimizer = optim.Adamax(model.parameters(), lr=args.learning_rate)
    # Keep evaluation on the same kind of device as training.
    use_cuda = str(args.device).startswith("cuda")

    try:
        for epoch in range(args.num_epochs):
            print(f'Epoch: {epoch+1:02}')
            train_loss = train(model, train_loader, optimizer, tokenizer, args.device)
            val_metrics = evaluate(model, val_loader, tokenizer, use_cuda=use_cuda)

            print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
            print(f'\tValidation metrics: {val_metrics}')
    except KeyboardInterrupt:
        print("Saving and cleaning up the model...")
        print("Do NOT kill the terminal it WILL corrupt the model files")
        _save_checkpoint(model, tokenizer, args.save_dir)
        # Return instead of quit(): quit() would shut down the notebook kernel.
        return

    _save_checkpoint(model, tokenizer, args.save_dir)

In [None]:
# Train on the first 100 InstructMix rows and validate on the next 100
# (both taken from the dataset's "train" split).
args = TrainArgs(
    dataset="Locutusque/InstructMix",
    split="train[:100]",
    val_dataset="Locutusque/InstructMix",
    val_split="train[100:200]",
)

main(args)

# Main loop

This is where we'll run the main function and start the training process.