In [None]:
!pip install datasets
!pip install transformers
!pip install nltk
!pip install matplotlib
!pip install torch
!pip install ipywidgets
!pip install huggingface_hub
!pip install accelerate

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import notebook_login
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
import numpy as np
import nltk
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
import json
from torch.cuda import amp
from huggingface_hub import notebook_login

# Sentence tokenizer models used by nltk.sent_tokenize. Newer NLTK releases load
# the "punkt_tab" resource instead of the pickled "punkt" one, so fetch both.
for punkt_resource in ("punkt", "punkt_tab"):
    nltk.download(punkt_resource)

In [None]:
# Hyperparameters
max_token = 32                  # max tokens per sentence incl. padding; longer sentences are dropped
device = "cuda" if torch.cuda.is_available() else "cpu"
epochs = 100
batch_size = 400
validation_batch_size = 1
weight_decay = 1e-3             # AdamW weight decay
drop_out_rate = 0.5             # dropout inside every transformer MLP
lr = 1e-3                       # initial learning rate
gamma = 0.8                     # per-epoch multiplicative learning-rate decay (ExponentialLR)
num_layer = 4                   # number of transformer layers
gradient_accumulation_step = 1  # training batches per optimizer step

# Seed the RNGs the notebook relies on (parameter init, DataLoader shuffling,
# multinomial sampling) so reruns are reproducible, up to nondeterministic CUDA kernels.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Load Llama 2 and reuse its tokenizer and word embeddings

In [None]:
# Log in to the Hugging Face Hub before downloading the meta-llama checkpoint in the next cell.
# NOTE(review): presumably required because that repository is access-restricted — confirm.
notebook_login()

In [None]:
# Download Llama 2 and its tokenizer. Only the tokenizer and the input word
# embeddings are kept; the rest of the model is discarded below.
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# Pad on the right so every sequence starts with real tokens: with left padding the
# causal + padding mask built in MyLlama would leave the first query positions with
# no visible key, i.e. all -inf attention rows (NaN).
tokenizer.padding_side = "right"
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# The published Llama-2 weights are float16, so loading in that dtype halves host RAM
# versus the float32 default; the embeddings are upcast to float32 below.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", torch_dtype=torch.float16)

# Frozen float32 copy of the input word embeddings. Detaching keeps later lookups
# from building autograd graphs (and accumulating gradients) into these weights.
embedding_layer = model.get_input_embeddings()
base_embeddings = embedding_layer.weight.detach().float()

# Delete the full Llama 2 model because we are no longer using it.
del model

# Number of embeddings and embedding dimension of Llama 2
num_embeddings, embedding_dim = base_embeddings.size()

# Append an all-zero embedding row for the newly added [PAD] token.
padding_embedding = torch.zeros(1, embedding_dim)
word_embeddings_tensor = torch.cat([base_embeddings, padding_embedding], 0)
num_embeddings += 1

# The zero row is only the [PAD] embedding if [PAD] received the id right after the
# original vocabulary.
assert tokenizer.pad_token_id == num_embeddings - 1, (
    f"[PAD] id {tokenizer.pad_token_id} does not match padding row {num_embeddings - 1}"
)

# Preprocessing

In [None]:
# Run this cell only to (re)build the preprocessed data; otherwise load it from JSON in the next cell.
# NOTE(review): the script-based "wikipedia" dataset may not load on recent `datasets` releases — confirm.
tokenized_data = []
attention_data = []

# Download dataset
dataset = load_dataset("wikipedia", "20220301.en")
training_dataset = dataset["train"]

# Split every article into sentences, terminate each with the "</s>" marker, tokenize
# with padding to max_token (no truncation is requested), and keep only the sentences
# whose tokenization still fits in max_token.
# Iterating the dataset directly streams rows in batches, which is much faster than
# indexing it one row at a time.
for article in tqdm(training_dataset):
    sentences = [sentence + "</s>" for sentence in nltk.sent_tokenize(article["text"])]
    if not sentences:  # nothing to tokenize for empty articles
        continue

    encoded = tokenizer(sentences, padding='max_length', max_length=max_token)
    for input_ids, attention_mask in zip(encoded["input_ids"], encoded["attention_mask"]):
        if len(input_ids) <= max_token:
            tokenized_data.append(input_ids)
            attention_data.append(attention_mask)

# Cache the results as JSON so later runs can skip this cell.
with open('tokenized_data.json', 'w') as file:
    json.dump(tokenized_data, file)

with open('attention_data.json', 'w') as file:
    json.dump(attention_data, file)

In [None]:
# Load the preprocessed data from JSON (use this instead of the previous cell if
# preprocessing has already been run).
def _read_json(path):
    """Return the JSON document stored at ``path``."""
    with open(path, 'r') as handle:
        return json.load(handle)


tokenized_data = _read_json('tokenized_data.json')
attention_data = _read_json('attention_data.json')

In [None]:
# Split the sentences 95% / 5% in their stored order (so validation comes from the
# end of the data), wrap them as tensor datasets and batch them.
total_data_num = len(tokenized_data)
training_data_num = int(total_data_num * 0.95)

training_ids = torch.tensor(tokenized_data[:training_data_num])
training_attention = torch.tensor(attention_data[:training_data_num])
validation_ids = torch.tensor(tokenized_data[training_data_num:])
validation_attention = torch.tensor(attention_data[training_data_num:])

# Create a TensorDataset for each split: items are (input_ids, attention_mask) pairs.
training_data = TensorDataset(training_ids, training_attention)
validation_data = TensorDataset(validation_ids, validation_attention)

# Only the training set is shuffled; the validation loss is an average over all
# batches, so a fixed order changes nothing except making evaluation deterministic.
training_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_data, batch_size=validation_batch_size, shuffle=False)

# Free up memory: torch.tensor copied the data, so the Python lists are no longer needed.
del tokenized_data
del attention_data

# Create positional embedding

In [None]:
# Sinusoidal positional encoding: for every even column j,
#   pos_matrix[i, j]     = sin(i / 10000^(j / embedding_dim))
#   pos_matrix[i, j + 1] = cos(i / 10000^(j / embedding_dim))   (when j + 1 < embedding_dim)
# Only max_token - 1 positions are needed because inputs are shifted by one for teacher forcing.
max_token_pos = max_token - 1
positions = torch.arange(max_token_pos, dtype=torch.float64).unsqueeze(1)       # (P, 1)
even_columns = torch.arange(0, embedding_dim, 2, dtype=torch.float64)           # (ceil(d/2),)
angles = positions / (10000 ** (even_columns / embedding_dim))                  # (P, ceil(d/2))

# Computed in float64 like the original per-element NumPy version, stored as float32.
pos_matrix = torch.empty(max_token_pos, embedding_dim, dtype=torch.float64)
pos_matrix[:, 0::2] = torch.sin(angles)
pos_matrix[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])  # odd embedding_dim has no final cos column
pos_matrix = pos_matrix.float()

# Define and instantiate the LLM

In [None]:
class Swish(nn.Module):
    """Swish (SiLU) activation applied elementwise: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(tensor)
        return tensor * gate


class SwiGLU(nn.Module):
    """Gated linear unit with a Swish gate: ``Swish(W x) * (V x)``.

    Both projections map ``embedding_dim`` to ``embedding_dim``, so the output has
    the same last dimension as the input.
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.W = nn.Linear(embedding_dim, embedding_dim)
        self.V = nn.Linear(embedding_dim, embedding_dim)
        self.swish = Swish()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        gated = self.swish(self.W(tensor))
        return gated * self.V(tensor)

# Root Mean Square Layer Normalization (Zhang & Sennrich, 2019), taken from the
# RMSNorm paper authors' reference implementation; no credit is claimed for it.
class RMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
            Root Mean Square Layer Normalization
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps: epsilon value, default 1e-8
        :param bias: whether to use a bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super().__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Assigning an nn.Parameter attribute already registers it under that name,
        # so no explicit register_parameter call is needed.
        self.scale = nn.Parameter(torch.ones(d))
        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def forward(self, x):
        """Divide ``x`` by its root mean square over the last dimension and apply
        the learned gain (and offset when ``bias``).

        When ``0 <= p <= 1`` the RMS is estimated from only the first
        ``int(d * p)`` features, but the whole of ``x`` is rescaled.
        """
        if self.p < 0. or self.p > 1.:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        # RMS = ||x||_2 / sqrt(d_x); eps is added to the RMS itself, not inside the sqrt.
        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset
        return self.scale * x_normed

class MyLlamaLayer(nn.Module):
    """One pre-norm transformer block.

    ``x + Attention(RMSNorm(x))`` followed by ``h + MLP(RMSNorm(h))``.

    :param embedding_dim: hidden size of the block
    :param num_heads: number of attention heads (must divide embedding_dim)
    :param expand_factor: width multiplier of the MLP hidden layer
    """

    def __init__(self, embedding_dim: int, num_heads: int, expand_factor: int = 4):
        super().__init__()
        self.rms_norm1 = RMSNorm(embedding_dim)
        self.multihead_attention = nn.MultiheadAttention(embedding_dim, num_heads=num_heads)
        self.rms_norm2 = RMSNorm(embedding_dim)
        hidden_dim = embedding_dim * expand_factor
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            RMSNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(drop_out_rate),  # module-level hyperparameter
            RMSNorm(hidden_dim),
            nn.Linear(hidden_dim, embedding_dim),
            nn.Dropout(drop_out_rate),
        )

    def forward(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``tensor`` of shape (batch, seq, embedding_dim).

        ``mask`` is passed to nn.MultiheadAttention as ``attn_mask``; MyLlama.forward
        builds it as an additive float mask of shape (batch * num_heads, seq, seq).
        """
        # nn.MultiheadAttention is used with its default (seq, batch, dim) layout.
        hidden = tensor.transpose(0, 1)

        normed = self.rms_norm1(hidden)
        # The attention weights are never used, so don't ask MHA to compute and return them.
        attn_out, _ = self.multihead_attention(normed, normed, normed, attn_mask=mask, need_weights=False)
        hidden = hidden + attn_out

        hidden = hidden + self.mlp(self.rms_norm2(hidden))
        return hidden.transpose(0, 1)

class MyLlama(nn.Module):
    """Decoder-only transformer operating on precomputed input embeddings.

    Stacks ``num_layer`` MyLlamaLayer blocks, applies a final RMSNorm and projects
    every position to ``vocab_size`` logits.

    :param embedding_dim: width of the input embeddings and hidden states
    :param num_layer: number of transformer layers
    :param num_heads: attention heads per layer; defaults to one head per 64 dims
    :param vocab_size: number of output logits; defaults to the global ``num_embeddings``
    """

    def __init__(self, embedding_dim: int, num_layer: int, num_heads: int = None, vocab_size: int = None):
        super().__init__()
        # Always set num_heads: previously an explicit num_heads left self.num_heads
        # undefined and the layer construction below raised AttributeError.
        self.num_heads = embedding_dim // 64 if num_heads is None else num_heads
        if vocab_size is None:
            vocab_size = num_embeddings

        # Transformer
        self.transformer = nn.ModuleList(
            MyLlamaLayer(embedding_dim, self.num_heads) for _ in range(num_layer)
        )

        self.norm = RMSNorm(embedding_dim)

        # Classifier
        self.classifier = nn.Linear(embedding_dim, vocab_size)

    def forward(self, tensor: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, seq, vocab_size).

        :param tensor: input embeddings, (batch, seq, embedding_dim)
        :param padding_mask: (batch, seq); nonzero marks real tokens, 0 marks padding
        """
        batch_size, sequence_length = padding_mask.shape
        mask_device = padding_mask.device

        # allowed[b, q, k]: query q may attend key k iff k <= q (causal) and key k is a
        # real token.
        causal = torch.ones(sequence_length, sequence_length, dtype=torch.bool, device=mask_device).tril()
        key_is_token = padding_mask.bool().unsqueeze(1)  # (batch, 1, seq)
        allowed = causal & key_is_token                  # (batch, seq, seq)

        # Additive float mask: 0 where allowed, -inf where blocked.
        mask = torch.zeros(allowed.shape, dtype=torch.float32, device=mask_device)
        mask = mask.masked_fill(~allowed, float('-inf'))

        # nn.MultiheadAttention expects one mask per (batch, head) pair, batch-major.
        mask = mask.repeat_interleave(self.num_heads, dim=0)

        # Transformer
        for layer in self.transformer:
            tensor = layer(tensor, mask)

        tensor = self.norm(tensor)

        # Classifier
        return self.classifier(tensor)

In [None]:
# Build the model on the target device and report its size.
llama = MyLlama(embedding_dim, num_layer).to(device)
num_parameters = sum(p.numel() for p in llama.parameters())
print("This model has", num_parameters, "parameters.")
# Loss scaling only applies to CUDA mixed precision; disable it explicitly on CPU
# instead of relying on GradScaler's auto-disable warning.
scaler = amp.GradScaler(enabled=(device == "cuda"))

# Training

In [None]:
# Positions whose target is [PAD] carry no next-token signal; with max_token = 32 most
# targets in a padded sentence are [PAD], so exclude them from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(llama.parameters(), lr=lr, weight_decay=weight_decay)
# Multiplies the learning rate by gamma on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

In [None]:
# Per-epoch mean training / validation losses, appended by the training loop.
loss_train, loss_valid = [], []

In [None]:
def compute_batch_loss(data):
    """Teacher-forced next-token loss for one ``(input_ids, attention_mask)`` batch.

    Position t of the shifted input predicts token t + 1. The forward pass runs under
    CUDA autocast; returns the scalar ``criterion`` loss.
    """
    token_ids, attention = data
    input_ids = token_ids[:, :-1]
    target_ids = token_ids[:, 1:].to(device)
    padding_mask = attention[:, :-1].to(device)

    # Look up the frozen Llama embeddings (detached so no graph reaches back into the
    # table) and add the positional encoding.
    input_embeddings = (word_embeddings_tensor.detach()[input_ids] + pos_matrix).to(device)

    with amp.autocast():
        prediction = llama(input_embeddings, padding_mask)
        return criterion(prediction.reshape(-1, prediction.size(-1)), target_ids.reshape(-1))


for epoch in range(epochs):
    # Training: dropout on. Gradients accumulate over gradient_accumulation_step batches
    # and are cleared only after each optimizer step.
    llama.train()
    optimizer.zero_grad()
    loss_train_epoch = []
    for batch_idx, data in enumerate(tqdm(training_loader)):
        loss = compute_batch_loss(data)
        # Backward outside the autocast region, as the AMP docs recommend.
        scaler.scale(loss / gradient_accumulation_step).backward()

        is_last_batch = (batch_idx + 1) == len(training_loader)
        if (batch_idx + 1) % gradient_accumulation_step == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        loss_train_epoch.append(loss.item())
    loss_train.append(np.mean(loss_train_epoch))

    # Validation: dropout off and no autograd bookkeeping.
    llama.eval()
    with torch.no_grad():
        loss_val_epoch = [compute_batch_loss(data).item() for data in tqdm(validation_loader)]
    loss_valid.append(np.mean(loss_val_epoch))

    scheduler.step()

    print("Training loss: ", loss_train[-1])
    print("Validation loss: ", loss_valid[-1])

    # Once later epochs exist, leave out the first one, presumably because its loss
    # dwarfs the rest and flattens the curve.
    start = 1 if len(loss_train) >= 2 else 0
    epoch_axis = range(start + 1, len(loss_train) + 1)
    fig, ax = plt.subplots()
    ax.plot(epoch_axis, loss_train[start:], label="Training loss")
    ax.plot(epoch_axis, loss_valid[start:], label="Validation loss")
    ax.set(title="Loss per epoch", xlabel="Epoch", ylabel="Cross-entropy loss")
    ax.legend()
    plt.show()

# Inference

In [None]:
# Sampling temperature: logits are divided by it before the softmax, so values below 1
# sharpen the next-token distribution and values above 1 flatten it.
temperature: float = 0.5

In [None]:
def generate(prompt: str, temperature: float, max_length: int = max_token):
    """Sample a continuation of ``prompt`` one token at a time.

    Stops once the EOS token is produced or the sequence (prompt tokens included)
    reaches ``max_length`` tokens. Returns the full list of token ids.

    :raises ValueError: if ``max_length`` needs more positions than ``pos_matrix`` has
    """
    if max_length > pos_matrix.size(0) + 1:
        raise ValueError(f"max_length can be at most {pos_matrix.size(0) + 1}")

    token_ids = tokenizer(prompt)["input_ids"]
    was_training = llama.training
    llama.eval()  # disable dropout while sampling; restored below
    try:
        with torch.no_grad():
            while token_ids[-1] != tokenizer.eos_token_id and len(token_ids) < max_length:
                ids = torch.tensor(token_ids)
                embeddings = word_embeddings_tensor[ids] + pos_matrix[: len(token_ids)]
                embeddings = embeddings.unsqueeze(0).to(device)             # (1, seq, dim)
                attention = torch.ones(1, len(token_ids), dtype=torch.int64, device=device)

                logits = llama(embeddings, attention)[0, -1]  # only the last position matters
                probabilities = F.softmax(logits / temperature, dim=-1)
                token_ids.append(torch.multinomial(probabilities, 1).item())
    finally:
        llama.train(was_training)
    return token_ids


sentence = "An apple is a round, edible fruit produced by"
tokenized_sentence = generate(sentence, temperature)

tokens = tokenizer.decode(tokenized_sentence)
print(tokens)