In [None]:
## --LIBRARY IMPORTS--

import os 

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from model_architecture import GPTModel


# Pick the compute device (GPU when one is visible to torch, otherwise CPU)
# and read the three DailyMail CSV splits into DataFrames.

device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_df, val_df, test_df = (
    pd.read_csv(os.path.join('../dataset/', csv_name))
    for csv_name in ('dailymail/train.csv', 'dailymail/validation.csv', 'dailymail/test.csv')
)

# One value per line: device first, then the (rows, columns) shape of each split.
print(device, train_df.shape, val_df.shape, test_df.shape, sep='\n')


# Build the model and restore the weights stored under the checkpoint's 'module' key.

model = GPTModel()

checkpoint_file = os.path.join('../models/subset19/', 'mp_rank_00_model_states.pt')
model_state = torch.load(checkpoint_file, map_location=device)

# strict=False: keys that do not line up are collected and printed below
# so they can be inspected by hand.
missing, unexpected = model.load_state_dict(model_state['module'], strict=False)

print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

# Move the restored model onto the selected device.
model = model.to(device)

In [2]:
class DailyMailSet(Dataset):
    """Dataset view over a DataFrame that has 'text' and 'highlights' columns."""

    def __init__(self, df):
        # Keep a reference to the frame; rows are looked up lazily in __getitem__.
        self.df = df

    def __len__(self):
        """Number of rows in the wrapped frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the row at positional index `idx` as a dict with keys 'text' and 'highlights'."""
        row = self.df.iloc[idx]
        return {column: row[column] for column in ('text', 'highlights')}

In [3]:
# Wrap each split in a shuffling DataLoader: batches of 32 for train and
# validation, batches of 16 for test.
# NOTE(review): the validation and test loaders are shuffled as well; confirm
# that is intended for evaluation.
train_dl = DataLoader(dataset=DailyMailSet(train_df), shuffle=True, batch_size=32)

val_dl = DataLoader(dataset=DailyMailSet(val_df), shuffle=True, batch_size=32)

test_dl = DataLoader(dataset=DailyMailSet(test_df), shuffle=True, batch_size=16)

In [4]:
# Pull the optimizer, tokenizer and special-token ids off the loaded model.
optimizer = model.optimizer
tokenizer = model.config.tokenizer

eos_token, pad_token = model.eos_token_idx, model.pad_token_idx

sos_token = tokenizer.convert_tokens_to_ids('<|startoftext|>')

# The tokenizer was built without a dedicated context token, so the id of the
# '<|unknown|>' token is reused as the context marker.
context_token = tokenizer.convert_tokens_to_ids('<|unknown|>')

print("e, p, s, c : ", eos_token, pad_token, sos_token, context_token)

e, p, s, c :  50256 50258 50257 50259


In [None]:
# Snapshot every model parameter before SFT so the post-SFT values can be
# compared against these copies and verified as unchanged.

original_weights = dict()

for name, param in model.named_parameters():

    # detach + clone yields an independent copy that is cut off from autograd
    frozen_copy = param.detach().clone()
    original_weights[name] = frozen_copy
    print("copied layer: ", name, "of size: ", frozen_copy.shape)


In [None]:
# create the class for LoRA parameterization of a general matrix like the vocab layer or the attention matrices

class LoRA_Parameterization(nn.Module):
    """Low-rank additive update (LoRA) for a weight matrix.

    `forward(W)` returns `W + (lora_B @ lora_A) * scale`, with the product
    viewed to `W.shape`, or `W` unchanged when `enabled` is False.

    Shapes:
        lora_B: [in_features, rank]   (initialised to zeros)
        lora_A: [rank, out_features]  (initialised from N(0, 1))
        scale:  alpha / rank

    Because `lora_B` starts at zero, the update is exactly zero until training
    changes it.

    Args:
        in_features: first dimension of the low-rank product.
        out_features: second dimension of the low-rank product.
        rank: inner dimension of the factorisation; must be >= 1.
        alpha: numerator of the scale factor `alpha / rank`.
        device: device the two LoRA matrices are allocated on.

    Raises:
        ValueError: if `rank` is smaller than 1.
    """

    def __init__(self, in_features: int, out_features: int, rank: int = 1, alpha: int = 1, device: str = 'cpu'):

        super().__init__()

        # Fail early with a clear message: rank == 0 would otherwise surface as a
        # bare ZeroDivisionError from alpha / rank, and a negative rank as an
        # opaque tensor-dimension error.
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}")

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha
        self.device = device

        # delta_W = (B @ A) * scale, where scale = alpha / rank
        self.scale = self.alpha / self.rank

        # Allocate directly on the target device (no intermediate CPU tensor).
        # A is registered before B so the parameter order stays A, B.
        self.lora_A = nn.Parameter(torch.zeros((self.rank, self.out_features), device=self.device))

        # B stays all zeros so that delta_W is zero at the start of training
        self.lora_B = nn.Parameter(torch.zeros((self.in_features, self.rank), device=self.device))

        # A is drawn from a Gaussian with mean 0 and std 1
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # flag to switch the LoRA update on/off on demand
        self.enabled = True

    def forward(self, original_weights: torch.Tensor):
        """Return `original_weights` plus the scaled low-rank update.

        When `enabled` is False the input tensor is returned as-is.
        """

        if not self.enabled:
            return original_weights

        delta = torch.matmul(self.lora_B, self.lora_A) * self.scale

        # view() only reinterprets the [in_features, out_features] product in the
        # shape of original_weights; it does not transpose.
        # NOTE(review): confirm callers pass in_features/out_features in the same
        # order as the weight's own dimensions.
        return original_weights + delta.view(original_weights.shape)
        