<a href="https://colab.research.google.com/github/aniervs/small-mistral-like-llm/blob/main/my_mistral.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install tensorboard

In [None]:
%load_ext tensorboard

In [None]:
import sys
sys.version

In [None]:
! pip install datasets

In [None]:
! pip install einops

# Imports

In [None]:
import datasets
import torch
import torch.nn as nn
import einops
import numpy as np
from pydantic import BaseModel
from tqdm import tqdm
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, DataProcessor, DataCollatorForLanguageModeling
from collections import Counter
from einops import reduce, repeat, rearrange, einsum
from torch.utils.tensorboard import SummaryWriter

In [None]:
def get_device_name():
    """Return the name of the preferred torch backend as a string.

    CUDA wins when available; otherwise Apple's MPS backend is used if this
    torch build exposes ``torch.backends.mps`` and reports it as available;
    otherwise the answer is ``'cpu'``.
    """
    # Older torch builds have no ``mps`` attribute at all, hence the getattr.
    mps_backend = getattr(torch.backends, 'mps', None)
    if torch.cuda.is_available():
        name = 'cuda'
    elif mps_backend is not None and mps_backend.is_available():
        name = 'mps'
    else:
        name = 'cpu'
    return name

# Resolve the compute device once; later cells move the model and the batches to it.
device = torch.device(get_device_name())
# Trailing expression so the notebook displays the selected device.
device

# Data


In [None]:
# Fetch the HuggingFaceH4/no_robots dataset through the `datasets` library.
dataset = datasets.load_dataset("HuggingFaceH4/no_robots")

In [None]:
# The splits are stored under the keys 'train_sft' and 'test_sft'.
train_dataset = dataset['train_sft']
test_dataset = dataset['test_sft']
# Display the shape of each split.
train_dataset.shape, test_dataset.shape

In [None]:
train_dataset.features

In [None]:
train_dataset[0]

In [None]:
# Count the training examples per 'category' value and plot each category's share.
category_cnt = Counter(train_dataset['category'])
# autopct prints every slice's percentage with one decimal place.
plt.pie(category_cnt.values(), labels = category_cnt.keys(), autopct='%1.1f%%')
plt.title("Distribution of Prompts' Categories")
plt.show()

## Tokenizer

In [None]:
# Reuse the pretrained tokenizer published with Mistral-7B-v0.1.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
# Register '[PAD]' as the padding token; `tokenize` below pads every example to a fixed length.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# mlm=False selects causal-LM collation rather than masked-LM collation.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [None]:
def tokenize(datum: dict, max_length: int = 2048):
    """Flatten one chat example into a fixed-length sequence of token ids.

    The ``content`` of every message in ``datum['messages']`` is joined with
    a blank line and encoded with the notebook-level ``tokenizer``. The
    encoding is truncated or padded to exactly ``max_length`` tokens.

    Args:
        datum: A dataset row holding a ``messages`` list of dicts, each with
            a ``content`` string.
        max_length: Length every example is truncated/padded to. Defaults to
            2048, the value that used to be hard-coded here (it equals
            ``ModelArguments.context_len``).

    Returns:
        ``{'input_ids': tokens}`` where ``tokens`` is the encoded row.
    """
    prompt = "\n\n".join(message["content"] for message in datum['messages'])
    tokens = tokenizer.encode(
        text = prompt,
        truncation = True,
        padding = "max_length",
        max_length = max_length,
        return_tensors = 'pt'
    )[0]  # with return_tensors='pt' the result is batched; keep the only row
    return {'input_ids' : tokens}

In [None]:
print(tokenize(train_dataset[0]))

In [None]:
# Tokenize both splits with 4 worker processes. The original columns are
# removed, leaving only what `tokenize` returns ('input_ids').
columns = train_dataset.column_names
train_tokens = train_dataset.map(tokenize, num_proc = 4, remove_columns = columns)
test_tokens = test_dataset.map(tokenize, num_proc = 4, remove_columns = columns)

# Display the shape of the tokenized training split.
train_tokens.shape

## Data Loaders

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Batches of 4 examples assembled by `collator`; only the training split is shuffled.
train_loader = DataLoader(train_tokens, shuffle = True, batch_size = 4, collate_fn = collator)
test_loader = DataLoader(test_tokens, shuffle = False, batch_size = 4, collate_fn = collator)

In [None]:
batch = next(iter(train_loader))

In [None]:
batch.keys()

In [None]:
batch['input_ids'][0][-10:]

In [None]:
batch['labels'][0][-10:]

In [None]:
batch['attention_mask']

# Model

In [None]:
tokenizer.vocab_size

In [None]:
class ModelArguments(BaseModel):
    """Hyper-parameters for the scaled-down, Mistral-style model in this notebook.

    Every field has a default, so ``ModelArguments()`` yields the small
    configuration used below. The trailing comments record the values the
    author attributes to the original model.
    """
    # Width of the token/position embeddings and of every layer's input and output.
    dim: int = 1024 # originally, 4096
    # Number of stacked attention layers.
    n_layers: int = 2 # originally, 32 (two layers are enough to see induction heads)
    # Size of each attention head.
    head_dim: int = 32 # originally 128
    # Inner width of the feed-forward block.
    hidden_dim: int = 256 # originally 7 * 2048
    # Number of query heads.
    n_heads: int = 4 # originally 32
    # Key/value head setting for grouped-query attention.
    n_kv_heads: int = 2 # originally 8
    # Sliding-window length used when building the attention mask.
    window_size: int = 1024 # originally 4096
    # Number of rows in the learned positional-embedding table, i.e. the longest supported sequence.
    context_len: int = 2048 # originally 8192
    # Rows of the token-embedding table and size of the output logits.
    # NOTE(review): presumably the tokenizer vocabulary plus the added '[PAD]' token; confirm.
    vocab_size: int = 32001

## Other parts of the model

In [None]:
class Embedding(nn.Module):
    """Token embedding summed with a learned absolute position embedding.

    Both lookup tables produce vectors of width ``args.dim``. The token table
    has ``args.vocab_size`` rows and the position table ``args.context_len``
    rows.
    """

    def __init__(self, args: ModelArguments):
        super().__init__()
        width = args.dim
        self.embedding = nn.Embedding(args.vocab_size, width)
        self.positional_encoding = nn.Embedding(args.context_len, width)

    def forward(self, x):
        """Embed ``x`` and add one position vector per index along axis 1."""
        n_tokens = x.shape[1]
        # Positions are created on the input's device so the lookup needs no copy.
        positions = torch.arange(n_tokens, device=x.device)
        token_vectors = self.embedding(x)
        position_vectors = self.positional_encoding(positions)
        # position_vectors has no batch axis; broadcasting adds it to every row.
        return token_vectors + position_vectors

class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-feature gain.

    Each feature vector (last axis) is divided by the square root of its mean
    squared value plus a small epsilon, then scaled by ``g_val``. The epsilon
    keeps the result finite for an all-zero vector, which would otherwise
    divide by zero and produce NaNs.

    Args:
        args: Configuration object providing ``dim`` (feature width). An
            optional ``norm_eps`` attribute overrides the default epsilon
            of 1e-5.
    """

    # The annotation is quoted so the class can be defined before ModelArguments is.
    def __init__(self, args: "ModelArguments"):
        super().__init__()
        self.g_val = nn.Parameter(torch.ones(args.dim))
        self.eps = getattr(args, 'norm_eps', 1e-5)

    def forward(self, x):
        """Normalise ``x`` over its last axis; any number of leading axes is accepted."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        # rsqrt(ms + eps) == 1 / sqrt(ms + eps), bounded even when ms == 0.
        return x * self.g_val * torch.rsqrt(mean_square + self.eps)

class FeedForward(nn.Module):
    """Gated (SwiGLU-style) position-wise feed-forward block.

    Two parallel projections map ``dim -> hidden_dim``. The SiLU-activated
    output of ``lin1`` acts as a gate that multiplies the output of ``lin2``
    element-wise, and ``lin3`` projects the gated activation back to ``dim``.
    The two branches used to be added, which turns the gate into a plain
    additive skip; a Mistral-style block multiplies them.

    Args:
        dim: Width of the input and output features.
        hidden_dim: Width of the gated hidden activation.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.lin1 = nn.Linear(dim, hidden_dim)  # gate branch
        self.lin2 = nn.Linear(dim, hidden_dim)  # value branch
        self.lin3 = nn.Linear(hidden_dim, dim)  # output projection

    def forward(self, x):
        """Return ``lin3(silu(lin1(x)) * lin2(x))``; the shape of ``x`` is preserved."""
        gate = nn.functional.silu(self.lin1(x))
        return self.lin3(gate * self.lin2(x))


## Attention

- **Grouped-Query Attention**: Query heads are split into groups, and all heads in the same group share a single key/value head.
- **Sliding Window Attention**: Each token attends only to the tokens inside a fixed-size window ending at itself, and full GQA is applied within that window. Because the windows compose across layers, stacking layers increases the effective context beyond a single window.

### Grouped-Query Attention

In [None]:
def create_mask(seq_len, window_size, device=None):
    """Build the causal sliding-window attention mask.

    Entry ``[i, j]`` is 1.0 when query position ``i`` may attend to key
    position ``j``, i.e. when ``0 <= i - j < window_size``, and 0.0 otherwise.

    Args:
        seq_len: Number of positions; the mask is ``(seq_len, seq_len)``.
        window_size: How many positions, counting the query itself, each
            query can look back over.
        device: Optional device to create the mask on. ``None`` keeps
            torch's default device, which is what existing two-argument
            calls get.

    Returns:
        A float tensor of shape ``(seq_len, seq_len)`` holding 0.0 / 1.0.
    """
    idx = torch.arange(seq_len, device=device)
    # diff[i, j] = i - j: how far key j lies behind query i.
    diff = idx.unsqueeze(1) - idx.unsqueeze(0)
    mask = (diff >= 0) & (diff < window_size)
    return mask.float()

In [None]:
class GroupQueryAttention(nn.Module):
    """Grouped-query self-attention restricted to a causal sliding window.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads: consecutive
    runs of ``n_heads // n_kv_heads`` query heads read the same key/value
    head. Attention scores are masked with ``create_mask`` so a position only
    sees itself and the ``window_size - 1`` positions before it. The block
    returns ``rms_norm(wo(attention)) + x``.

    Args:
        args: Model configuration; ``dim``, ``head_dim``, ``n_heads``,
            ``n_kv_heads`` and ``window_size`` are read here.

    Raises:
        ValueError: If ``n_heads`` is not a multiple of ``n_kv_heads``.
    """

    def __init__(self, args: ModelArguments):
        super().__init__()

        if args.n_kv_heads <= 0 or args.n_heads % args.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({args.n_heads}) must be a positive multiple of "
                f"n_kv_heads ({args.n_kv_heads})"
            )

        self.dim = args.dim
        self.head_dim = args.head_dim
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.window_size = args.window_size
        # Number of query heads served by each key/value head.
        self.n_rep = self.n_heads // self.n_kv_heads

        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim)
        # One projection slice per key/value head. The previous sizing,
        # n_heads * head_dim // n_kv_heads, only matched this when
        # n_heads == n_kv_heads ** 2 (true for the 4/2 defaults).
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim)

        self.rms_norm = RMSNorm(args)

    def forward(self, x):
        """Apply windowed grouped-query attention to ``x`` (batch, token, dim)."""
        # Follow the input's device rather than a notebook-level global.
        mask = create_mask(x.shape[1], self.window_size).to(x.device)

        q = rearrange(self.wq(x), 'batch token_query (head_query dim) -> head_query batch token_query dim', head_query = self.n_heads)
        # Repeat each key/value head n_rep times so that query head h reads
        # key/value head h // n_rep.
        k = repeat(self.wk(x), 'batch token_key (kv_head dim) -> (kv_head rep) batch token_key dim', kv_head = self.n_kv_heads, rep = self.n_rep)
        v = repeat(self.wv(x), 'batch token_key (kv_head dim) -> (kv_head rep) batch token_key dim', kv_head = self.n_kv_heads, rep = self.n_rep)

        att_score = einsum(q, k, 'head batch i dim, head batch j dim -> head batch i j') / (self.head_dim ** 0.5)
        # Use the dtype's own minimum: a literal -1e9 is not representable in
        # half precision.
        att_score = att_score.masked_fill(mask == 0, torch.finfo(att_score.dtype).min)
        att_score = att_score.softmax(dim = -1)

        output = einsum(att_score, v, 'head batch i j, head batch j dim -> head batch i dim')
        output = rearrange(output, 'head batch token dim -> batch token (head dim)')
        output = self.wo(output)
        output = self.rms_norm(output) + x

        return output

## Model

In [None]:
class Mistral(nn.Module):
    """Small Mistral-style language model.

    The pipeline is: embedding, then ``n_layers`` attention layers applied
    in sequence, then a single residual feed-forward block, then a linear
    projection to ``vocab_size`` logits.
    """

    def __init__(self, args: ModelArguments):
        super().__init__()

        self.n_layers = args.n_layers
        self.embedding = Embedding(args)

        self.att_layers = nn.ModuleList(
            GroupQueryAttention(args) for _ in range(self.n_layers)
        )

        self.ff = FeedForward(args.dim, args.hidden_dim)

        self.linear = nn.Linear(args.dim, args.vocab_size)

    def forward(self, x):
        """Map the token ids ``x`` to one logit vector per position."""
        hidden = self.embedding(x)

        for attention in self.att_layers:
            hidden = attention(hidden)

        # The feed-forward output is added back onto its own input.
        hidden = hidden + self.ff(hidden)
        logits = self.linear(hidden)
        return logits


In [None]:
# Build the model with the default hyper-parameters and display its module tree.
model = Mistral(ModelArguments())
model

In [None]:
model = model.to(device)
# Smoke-test the forward pass on one batch. The ids must be on the same
# device as the model, otherwise the embedding lookup fails on GPU/MPS.
input_ids = batch['input_ids'].to(device)
print(input_ids.shape)
# No gradients are needed for a shape check.
with torch.no_grad():
    y = model(input_ids)
print(y.shape)

### Counting the number of parameters

In [None]:
# Count trainable parameters only (requires_grad) and report them in millions, rounded.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{round(total_params / 1e6)}M parameters")

In [None]:
y_true = batch['labels']
y_true.shape, y.shape

# Training

In [None]:
# Start training from a freshly initialised model, replacing the instance used for the checks above.
model = Mistral(ModelArguments())
model = model.to(device)

In [None]:
# Cross-entropy over the vocabulary, with PyTorch's default settings.
loss_fn = nn.CrossEntropyLoss()
# AdamW over all model parameters with a learning rate of 1e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)

In [None]:
%tensorboard --logdir=runs

In [None]:
def train(model, train_loader, loss_fn, optimizer, epoch):
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    The loss is the next-token objective: the logits produced at position
    ``t`` are scored against the label at position ``t + 1``. The collator
    used in this notebook (``DataCollatorForLanguageModeling`` with
    ``mlm=False``) yields labels aligned with ``input_ids``, and the attention
    mask lets a position see its own token. Scoring position ``t`` against
    label ``t`` would therefore only teach the model to copy its input.

    Args:
        model: Module mapping ``(batch, token)`` ids to
            ``(batch, token, vocab)`` logits.
        train_loader: Iterable of batches with ``'input_ids'`` and
            ``'labels'`` tensors.
        loss_fn: Callable taking ``(N, vocab)`` logits and ``(N,)`` targets.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Current epoch index (not used inside the function).

    Returns:
        The mean of the per-batch losses as a Python float, or ``nan`` when
        the loader yields no batches.
    """
    model.train()
    # Follow the model's own device instead of a notebook-level global.
    target_device = next(model.parameters()).device

    total_loss = 0.0
    n_batch = 0
    for batch in train_loader:
        x = batch['input_ids'].to(target_device)
        y = batch['labels'].to(target_device)

        logits = model(x)

        # Drop the last prediction and the first label so that position t
        # is trained to predict token t + 1.
        vocab = logits.size(-1)
        shifted_logits = logits[:, :-1, :].reshape(-1, vocab)
        shifted_labels = y[:, 1:].reshape(-1)

        loss_value = loss_fn(shifted_logits, shifted_labels)

        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

        # .item() detaches the value; summing the tensors themselves would
        # keep autograd history alive across the whole epoch.
        total_loss += loss_value.item()
        n_batch += 1

    if n_batch == 0:
        return float('nan')
    return total_loss / n_batch

def test(model, test_loader, loss_fn, epoch):
    """Evaluate ``model`` on ``test_loader`` and return the mean loss.

    No gradients are computed and no parameters are updated. The loss is the
    next-token objective: logits at position ``t`` are scored against the
    label at position ``t + 1``, because the labels coming from this
    notebook's collator are aligned with ``input_ids``.

    Args:
        model: Module mapping ``(batch, token)`` ids to
            ``(batch, token, vocab)`` logits.
        test_loader: Iterable of batches with ``'input_ids'`` and
            ``'labels'`` tensors.
        loss_fn: Callable taking ``(N, vocab)`` logits and ``(N,)`` targets.
        epoch: Current epoch index (not used inside the function).

    Returns:
        The mean of the per-batch losses as a Python float, or ``nan`` when
        the loader yields no batches.
    """
    model.eval()
    # Follow the model's own device instead of a notebook-level global.
    target_device = next(model.parameters()).device

    total_loss = 0.0
    n_batch = 0
    with torch.no_grad():
        for batch in test_loader:
            x = batch['input_ids'].to(target_device)
            y = batch['labels'].to(target_device)

            logits = model(x)

            # Position t is scored on token t + 1.
            vocab = logits.size(-1)
            shifted_logits = logits[:, :-1, :].reshape(-1, vocab)
            shifted_labels = y[:, 1:].reshape(-1)

            total_loss += loss_fn(shifted_logits, shifted_labels).item()
            n_batch += 1

    if n_batch == 0:
        return float('nan')
    return total_loss / n_batch

In [None]:
# Number of full passes over the training loader.
n_epochs = 5

In [None]:
# SummaryWriter is used as a context manager so the event file is flushed
# and closed even if training raises or is interrupted part-way through.
with SummaryWriter() as writer:
    for epoch in tqdm(range(n_epochs)):
        loss = train(model, train_loader, loss_fn, optimizer, epoch)
        # One scalar per epoch, shown by the TensorBoard cell above.
        writer.add_scalar("Training Loss", loss, epoch)

# Post training