In [None]:
!pip install transformers
!pip install datasets
!pip install matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import math
from transformers import AutoTokenizer
from torch.cuda import amp
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import time

In [None]:
# Hyperparameters
# Sequence-length budget: sizes the rotary position table and caps generation length.
max_token = 64
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Number of passes over the training set; also used as T_max of the cosine LR schedule.
epochs = 10
# Sequences per training step.
batch_size = 120
# Sequences per validation step.
validation_batch_size = 10
# AdamW weight decay.
weight_decay = 1e-3
# Initial AdamW learning rate (annealed by the cosine schedule).
lr = 1e-3
# Number of stacked transformer layers.
num_layer = 1
# Width of a single attention head.
head_dim = 64
# Optional width to project the input embeddings down to; None disables the projection.
projection_dim = None
# Feed-forward inner size is embedding_dim * expansion_factor.
expansion_factor = 1
# Path of a checkpoint to resume from; None skips checkpoint loading.
checkpoint_filepath = None

# Load llama3 token embedding

In [None]:
# Load the pre-extracted token-embedding table from a .pt file
# (this assumes the preprocessing step that produced the file has already been run).
#
# map_location=device puts the table on the device the batches are moved to.
# The training loop indexes this tensor with index tensors that already live on
# `device`, which only works when the table is on that same device; it also lets
# a file that was saved from a GPU session be opened on a CPU-only machine.
word_embeddings_tensor = torch.load('word_embeddings_tensor.pt', map_location=device)
num_embeddings, embedding_dim = word_embeddings_tensor.shape
# The embedding table is used as a frozen lookup; it is never trained here.
word_embeddings_tensor.requires_grad = False

model_id = "NousResearch/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")
# Id used for padding; targets equal to this id are masked out of the loss.
# NOTE(review): presumably a reserved special token of this vocabulary - confirm
# it matches the id used when the training data was tokenized.
tokenizer.pad_token_id = 128002

# Load training data

In [None]:
# Load the pre-tokenized corpus: one row of token ids per sequence.
tensor = torch.load("llama3_wiki_64.pt")

# Hold out the last 2% of the rows for validation.
total_data_num = tensor.shape[0]
training_data_num = int(total_data_num * 0.98)

training_data = tensor[:training_data_num]
validation_data = tensor[training_data_num:]

# Wrap the splits as datasets; each item is a 1-tuple holding one row.
training_data = TensorDataset(training_data)
validation_data = TensorDataset(validation_data)

# Shuffle only the training split. The per-epoch validation loss is a mean of
# per-batch means, so a fixed batch order keeps it comparable between epochs.
training_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_data, batch_size=validation_batch_size, shuffle=False)

# Drop the extra name. The two splits are slices of the loaded tensor, so its
# storage stays alive for as long as the datasets reference it.
del tensor

# Instantiate LLM

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Every vector along the last axis is divided by ``rms + eps`` where
    ``rms = ||x||_2 / sqrt(dim)``. Note that ``eps`` is added to the RMS value
    itself, not placed under the square root.

    Args:
        dim: Size of the last (feature) dimension.
        eps: Small constant added to the RMS to avoid division by zero.
        wandb: When True, a learnable per-feature ``scale`` (initialised to 1)
            and ``bias`` (initialised to 0) are applied after normalising.
            (The name presumably abbreviates "weights and bias".)
    """

    def __init__(self, dim: int, eps: float = 1e-6, wandb: bool = False):
        super().__init__()
        # Despite the name, this stores the reciprocal 1 / sqrt(dim).
        self.sqrt_dim: float = 1 / math.sqrt(dim)
        self.eps: float = eps
        self.wandb: bool = wandb
        if wandb:
            self.scale = nn.Parameter(torch.ones(dim))
            self.bias = nn.Parameter(torch.zeros(dim))

    def find_rms_value(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the RMS of ``tensor`` along its last axis (that axis is dropped)."""
        norm_2 = tensor.norm(2, dim=-1)
        return norm_2 * self.sqrt_dim

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalise ``tensor`` along its last axis; the result is always float32."""
        # Upcast to 32-bit float so the norm is computed at full precision even
        # when the input arrives in a reduced-precision dtype.
        tensor = tensor.float()
        rms: torch.Tensor = self.find_rms_value(tensor)
        tensor = tensor / (rms.unsqueeze(-1) + self.eps)

        if self.wandb:
            tensor = tensor * self.scale
            tensor = tensor + self.bias

        return tensor


class ROPEEmbedding(nn.Module):
    """Rotary position embedding backed by a precomputed cos/sin lookup table.

    The table has one row per position (``max_token`` rows) and ``2 * dim``
    columns: the first ``dim`` columns hold cosines, the last ``dim`` hold
    sines whose sign is already flipped on every even feature index.

    Args:
        max_token: Number of positions to precompute.
        dim: Feature size of the vectors that will be rotated (the head size).
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, max_token: int, dim: int, theta: int):
        super().__init__()
        self.pos_emb = self.create_embedding(max_token, dim, theta)

    def create_embedding(self, max_token: int, dim: int, theta: int) -> torch.Tensor:
        """Build the frozen ``(max_token, 2 * dim)`` cos/sin table."""
        # Pair index i repeated twice: 0, 0, 1, 1, ... so both features of a pair
        # share the frequency theta ** (-2i / dim).
        pair_index = torch.repeat_interleave(torch.arange(0, dim // 2), 2)
        frequencies = torch.pow(theta, -pair_index * 2 / dim)

        # Position m of the rotation formula.
        positions = torch.arange(max_token).float()
        angles = frequencies.unsqueeze(1) * positions.unsqueeze(0)  # (dim, max_token)

        cos_table = angles.cos()
        sin_table = angles.sin()
        # Negate the sine on feature rows 0, 2, 4, ...
        sin_table[0::2] *= -1

        table = torch.cat((cos_table, sin_table), dim=0).transpose(1, 0)
        # Stored as a parameter that never receives gradients.
        return nn.Parameter(table, requires_grad=False)

    def flip_for_sin(self, tensor: torch.Tensor) -> torch.Tensor:
        """Swap the two members of every adjacent feature pair: (a, b) -> (b, a)."""
        pairs = tensor.reshape(tensor.shape[0], tensor.shape[1], -1, 2)
        swapped = torch.stack((pairs[..., 1], pairs[..., 0]), dim=-1)
        return swapped.reshape(tensor.shape)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Rotate ``tensor``, laid out as (batch_size, head, sequence_length, dim)."""
        sequence_length = tensor.shape[2]
        table = self.pos_emb[:sequence_length, :]

        # [x, swap(x)] * [cos, signed sin], then add the two halves together.
        stacked = torch.cat((tensor, self.flip_for_sin(tensor)), dim=-1)
        cos_part, sin_part = (stacked * table).chunk(chunks=2, dim=-1)
        return cos_part + sin_part


class MultiQueryAttention(nn.Module):
    """Self-attention in which several query heads may share one key/value head.

    Args:
        hidden_dim: Feature size of the input and of the output.
        head_dim: Feature size of each attention head.
        q_head: Number of query heads.
        kv_head: Number of key/value heads; must divide ``q_head``.
        embedding: Positional-embedding module applied to the queries and keys
            once they are laid out as (batch, head, seq_len, head_dim). In this
            file it is the shared ``ROPEEmbedding`` instance.
    """

    def __init__(self, hidden_dim: int, head_dim: int, q_head: int, kv_head: int, embedding: nn.Module):
        super().__init__()
        self.head_dim = head_dim
        self.q_head = q_head
        self.kv_head = kv_head
        self.embedding = embedding
        # One fused projection producing the queries, keys and values.
        self.qkv = nn.Linear(hidden_dim, (q_head+kv_head*2)*head_dim)
        self.o = nn.Linear(q_head*head_dim, hidden_dim)
        self.scaler = 1/math.sqrt(head_dim)

        if q_head != kv_head:
            # Fewer key/value heads than query heads: each key/value head is
            # shared by q_head // kv_head query heads.
            assert q_head % kv_head == 0
            self.multi_query_attention = True
            self.q_kv_scale = q_head//kv_head
        else:
            self.multi_query_attention = False

    def forward(self, tensor: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        """Attend over ``tensor`` of shape (batch, seq_len, hidden_dim).

        Args:
            tensor: Input activations.
            attention_mask: Optional additive mask broadcastable to
                (batch, head, seq_len, seq_len); it is added to the scaled
                scores before the softmax.

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim).
        """
        batch_size, seq_len, _ = tensor.shape

        tensor = self.qkv(tensor)
        query, key, value = tensor.split([self.head_dim*self.q_head, self.head_dim*self.kv_head, self.head_dim*self.kv_head], dim=-1)

        query = query.view(batch_size, seq_len, self.q_head, self.head_dim)
        key = key.view(batch_size, seq_len, self.kv_head, self.head_dim)
        value = value.view(batch_size, seq_len, self.kv_head, self.head_dim)

        if self.multi_query_attention:
            # Duplicate the key/value heads so every query head has a partner.
            key = torch.repeat_interleave(key, self.q_kv_scale, dim=-2)
            value = torch.repeat_interleave(value, self.q_kv_scale, dim=-2)

        # Switch to batch_size, head, seq_len, head_dim
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Apply the positional embedding to queries and keys only.
        query = self.embedding(query)
        key = self.embedding(key)

        # Scaled dot-product attention
        attention_raw = torch.matmul(query, key.transpose(2, 3))
        attention_scaled = attention_raw * self.scaler
        if attention_mask is not None:
            attention_scaled += attention_mask
        attention_score = torch.softmax(attention_scaled, dim=-1)
        value = torch.matmul(attention_score, value)

        # Merge the heads back together. The merged width is q_head * head_dim
        # (the input width of self.o), which is not necessarily hidden_dim.
        value = value.transpose(1, 2).contiguous()
        value = value.view(batch_size, seq_len, self.q_head * self.head_dim)

        # Output layer maps back to hidden_dim
        output = self.o(value)

        return output

class FeedForward(nn.Module):
    """Gated feed-forward block: ``down(gelu(gate) * dropout(up))``.

    Args:
        hidden_size: Feature size of the input and of the output.
        inner_size: Width of the gate branch and of the up branch.
        dropout_ratio: Dropout probability applied to the up branch only.
    """

    def __init__(self, hidden_size: int, inner_size: int, dropout_ratio: float = 0.5):
        super().__init__()
        # A single linear layer yields both branches, concatenated on the last axis.
        self.gate_and_up = nn.Linear(hidden_size, inner_size * 2)
        self.down = nn.Linear(inner_size, hidden_size)
        self.dropout = nn.Dropout(p=dropout_ratio)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform along the last axis."""
        gate, up = self.gate_and_up(tensor).chunk(chunks=2, dim=-1)
        activated = F.gelu(gate, approximate="tanh")
        gated = activated * self.dropout(up)
        return self.down(gated)


class GemmaLayer(nn.Module):
    """One transformer layer: normalised attention, then normalised feed-forward.

    Both sub-blocks are wrapped in a residual connection, and the normalisation
    is applied to the input of each sub-block (pre-norm).

    Args:
        hidden_dim: Feature size of the layer's input and output.
        inner_size: Inner width of the feed-forward block.
        head_dim: Feature size of each attention head.
        q_head: Number of query heads.
        kv_head: Number of key/value heads.
        embedding: Positional embedding handed to the attention block.
        dropout_ratio: Dropout probability used by the feed-forward block.
    """

    def __init__(self, hidden_dim: int, inner_size: int, head_dim: int, q_head: int, kv_head: int, embedding: ROPEEmbedding, dropout_ratio: float = 0.5):
        super().__init__()
        # Attention sub-block
        self.norm1 = RMSNorm(hidden_dim)
        self.mqa = MultiQueryAttention(hidden_dim, head_dim, q_head, kv_head, embedding)

        # Feed-forward sub-block
        self.norm2 = RMSNorm(hidden_dim)
        self.ffn = FeedForward(hidden_dim, inner_size, dropout_ratio)

    def forward(self, tensor: torch.Tensor, attention_mask: torch.Tensor = None):
        """Run the layer on ``tensor``; ``attention_mask`` is forwarded to the attention block."""
        residual = tensor
        attended = self.mqa(self.norm1(residual), attention_mask)
        # In-place add: the sum keeps the dtype of the attention output.
        attended += residual

        residual = attended
        transformed = self.ffn(self.norm2(residual))
        # In-place add: the sum keeps the dtype of the feed-forward output.
        transformed += residual

        return transformed

class Gemma(nn.Module):
    """Decoder-only language model operating on pre-computed token embeddings.

    The model does not own an embedding table: ``forward`` expects the already
    embedded sequence and returns one row of vocabulary logits per position.

    Args:
        num_layer: Number of stacked ``GemmaLayer`` blocks.
        vocab_size: Size of the output vocabulary (width of the logits).
        max_token: Number of positions covered by the rotary embedding table.
        hidden_dim: Feature size of the incoming embeddings.
        inner_size: Inner width of each feed-forward block.
        head_dim: Feature size of each attention head.
        q_head: Number of query heads; defaults to ``hidden_dim // head_dim``
            (using the projected width when ``projection_dim`` is given).
        kv_head: Number of key/value heads; same default as ``q_head``.
        dropout_ratio: Dropout probability used by the feed-forward blocks.
        theta: Base of the rotary embedding frequencies.
        projection_dim: When given, the input is first linearly projected from
            ``hidden_dim`` to this width, which becomes the model width.

    Raises:
        ValueError: If the model width is not divisible by
            ``head_dim * q_head`` or by ``head_dim * kv_head``.
    """

    def __init__(self, num_layer: int, vocab_size: int, max_token: int, hidden_dim: int, inner_size: int, head_dim: int, q_head: int | None = None, kv_head: int | None = None, dropout_ratio: float = 0.5, theta: int = 10000, projection_dim: int | None = None):
        super().__init__()
        # A single rotary embedding instance is shared by every layer.
        self.embedding = ROPEEmbedding(max_token, head_dim, theta)
        self.num_layer = num_layer

        # Because of computational power limit, we might want to project input token embedding down.
        if projection_dim is not None:
            self.projection = True
            self.projection_matrix = nn.Linear(hidden_dim, projection_dim)
            hidden_dim = projection_dim
        else:
            self.projection = False

        if q_head is None:
            q_head = (hidden_dim // head_dim)

        if kv_head is None:
            kv_head = (hidden_dim // head_dim)

        if hidden_dim % (head_dim * q_head) != 0 or hidden_dim % (head_dim * kv_head) != 0:
            raise ValueError("Error: hidden_dim or projection_dim (if specified) must be divisible by the product of the number of q or kv heads and the head dimension.")

        self.transformer = nn.ModuleList()
        for _ in range(self.num_layer):
            self.transformer.append(GemmaLayer(hidden_dim, inner_size, head_dim, q_head, kv_head, self.embedding, dropout_ratio))
        self.output_norm = RMSNorm(hidden_dim)

        self.classifier = nn.Linear(hidden_dim, vocab_size)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Map embeddings (batch, seq_len, hidden_dim) to logits (batch, seq_len, vocab_size)."""
        # If projecting input embeddings
        if self.projection:
            tensor = self.projection_matrix(tensor)

        seq_len = tensor.shape[1]
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        # It is created on the input's own device so the model does not depend
        # on a module-level `device` variable.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=tensor.device),
            diagonal=1,
        )
        for layer in self.transformer:
            tensor = layer(tensor, causal_mask)

        tensor = self.output_norm(tensor)

        # Classification over the vocabulary
        tensor = self.classifier(tensor)
        return tensor

In [None]:
# Build the model from the hyperparameters and the loaded embedding table's shape,
# then move it to the selected device.
gemma = Gemma(
    num_layer=num_layer,
    vocab_size=num_embeddings,
    max_token=max_token,
    hidden_dim=embedding_dim,
    inner_size=embedding_dim * expansion_factor,
    head_dim=head_dim,
    projection_dim=projection_dim,
).to(device)
print("This model has", sum(param.numel() for param in gemma.parameters()), "parameters.")
# Gradient scaler used by the mixed-precision training loop.
scaler = amp.GradScaler()

# Prepare for training

In [None]:
# Token-level cross entropy between the logits and the next-token targets.
criterion = nn.CrossEntropyLoss()
# AdamW over every model parameter.
optimizer = opt.AdamW(gemma.parameters(), lr=lr, weight_decay=weight_decay)
# Cosine decay from `lr` towards 1e-5 over `epochs` scheduler steps.
scheduler = opt.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)

In [None]:
# Per-epoch mean losses; the training loop appends to these and plots them after every epoch.
# Re-running this cell resets the recorded history.
loss_train = []
loss_valid = []

In [None]:
def save_checkpoint(model, optimizer, epoch, loss):
    """Write a training checkpoint into the current working directory.

    The file is named ``checkpoint_<epoch>_<YYYYmmdd_HHMMSS>.pth.tar`` and holds
    the epoch, the model and optimizer state dicts, and the given loss.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch value recorded in the file and used in its name.
        loss: Loss value recorded in the file.
    """
    timestamp = time.strftime('%Y%m%d_%H%M%S')
    filename = f'checkpoint_{epoch}_{timestamp}.pth.tar'

    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(state, filename)

    print(f'Checkpoint saved at epoch {epoch} as {filename}')

In [None]:
def load_checkpoint(model, optimizer, filename):
    """Restore model and optimizer state from a checkpoint file.

    The file is expected to contain the keys ``'epoch'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss'``.

    Args:
        model: Module that receives the stored ``model_state_dict``.
        optimizer: Optimizer that receives the stored ``optimizer_state_dict``.
        filename: Path of the checkpoint file.

    Returns:
        A ``(epoch, loss)`` tuple with the values stored in the checkpoint.
    """
    # Deserialize straight onto the device the model already lives on, so a
    # checkpoint written from a GPU session can still be opened on a CPU-only
    # machine. A model without parameters falls back to torch.load's default.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None

    # NOTE: torch.load unpickles the file - only open checkpoints you trust.
    checkpoint = torch.load(filename, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f'Checkpoint loaded from epoch {epoch} with loss {loss}')
    return epoch, loss

# Resume from a saved checkpoint only when a path has been configured.
if checkpoint_filepath is not None:
    load_checkpoint(gemma, optimizer, checkpoint_filepath)

# Training code

In [None]:
# One iteration per epoch: train on every batch, evaluate on the validation
# split, write a checkpoint, advance the LR schedule and redraw the loss curves.
for epoch in range(epochs):
    loss_train_epoch = []
    loss_val_epoch = []

    gemma.train()
    for data in tqdm(training_loader):
        # Teacher forcing: the model sees tokens [0, n-1) and predicts tokens [1, n).
        input_data = data[0][:, :-1].to(device)
        target_data = data[0][:, 1:].to(device)

        # Convert token ids to embeddings with the frozen lookup table.
        input_embeddings = word_embeddings_tensor[input_data]

        # Forward pass under mixed precision
        with amp.autocast():
            prediction = gemma(input_embeddings)

            # Flatten to (tokens, vocab) / (tokens,) for the loss calculation
            prediction = prediction.view(-1, num_embeddings)
            target_data = target_data.reshape(-1)

            # Padding positions do not contribute to the loss.
            mask = target_data != tokenizer.pad_token_id
            prediction = prediction[mask]
            target_data = target_data[mask]

            loss = criterion(prediction, target_data) # Calculate loss

        # Backward pass on the scaled loss
        scaler.scale(loss).backward()
        # The gradients are still multiplied by the loss scale at this point;
        # unscale them first so max_norm is applied to the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(gemma.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        # Clear out grad
        optimizer.zero_grad()

        # Record loss
        loss_train_epoch.append(loss.item())

    loss_train.append(np.mean(loss_train_epoch))

    gemma.eval()
    with torch.no_grad():
        for data in tqdm(validation_loader):
            # Teacher forcing, same input/target shift as in training
            input_data = data[0][:, :-1].to(device)
            target_data = data[0][:, 1:].to(device)

            # Convert token ids to embeddings.
            input_embeddings = word_embeddings_tensor[input_data]

            # Forward pass
            with amp.autocast():
                prediction = gemma(input_embeddings)

                # Flatten for the loss calculation
                prediction = prediction.view(-1, num_embeddings)
                target_data = target_data.reshape(-1)

                # Padding positions do not contribute to the loss.
                mask = target_data != tokenizer.pad_token_id
                prediction = prediction[mask]
                target_data = target_data[mask]

                loss = criterion(prediction, target_data) # Calculate loss

            # Record loss
            loss_val_epoch.append(loss.item())

        loss_valid.append(np.mean(loss_val_epoch))

    # Save checkpoint, tagged with this epoch's mean validation loss
    save_checkpoint(gemma, optimizer, epoch, loss_valid[-1])

    # The schedule advances once per epoch.
    scheduler.step()

    plt.plot(loss_train, label="Training loss")
    plt.plot(loss_valid, label="Validation loss")
    print("Training loss: ", loss_train[-1])
    print("Validation loss: ", loss_valid[-1])
    plt.legend()
    plt.show()

# Inference

In [None]:
# Sampling temperature: the logits are divided by this value before the softmax
# when the next token is drawn, so it must be non-zero.
temperature = 0.1

In [None]:
# Prompt to continue.
sentence = "Apple is a company "
tokenized_sentence = tokenizer(sentence)["input_ids"]
# Drop a trailing end-of-sequence token if the tokenizer appended one, otherwise
# the generation loop below would stop immediately. The id is looked up from the
# tokenizer rather than hard-coded, so an ordinary token is never stripped by mistake.
if tokenized_sentence[-1] == tokenizer.eos_token_id:
    tokenized_sentence = tokenized_sentence[:-1]
gemma.eval()

with torch.no_grad():
    # Keep sampling until the model emits end-of-sequence or the token budget is reached.
    while tokenized_sentence[-1] != tokenizer.eos_token_id and len(tokenized_sentence) < max_token:
        # Prepare the input: ids -> embeddings -> batch of one on the model's device
        tokenized_sentence_tensor = torch.tensor(tokenized_sentence)
        sentence_embedding = word_embeddings_tensor[tokenized_sentence_tensor]
        sentence_embedding = sentence_embedding.unsqueeze(0).to(device)

        # Make prediction
        with amp.autocast():
            prediction = gemma(sentence_embedding)
        # Only the logits of the last position matter for the next token. They are
        # cast to float32 so the temperature scaling and the softmax are not done
        # in the reduced precision that autocast may have produced.
        prediction = prediction[0][-1].float()
        prediction = prediction / temperature
        prediction = F.softmax(prediction, dim=-1)
        output_token = torch.multinomial(prediction, 1)

        # Append the sampled token to the running sequence
        tokenized_sentence.append(output_token.item())

tokens = tokenizer.decode(tokenized_sentence, skip_special_tokens=True)
print(tokens)

# Save the model

In [None]:
# Serialize the entire model object (not only its state_dict) to disk.
# NOTE(review): loading a file saved this way presumably needs the same class
# definitions to be available at load time - confirm before sharing the file;
# saving gemma.state_dict() is the more portable alternative.
torch.save(gemma, 'gemma3point43.pth')