In [55]:
from datasets import load_dataset
from transformers import GPT2Tokenizer,GPT2LMHeadModel
from torch.utils.data import Dataset,Subset,DataLoader
import torch
import torch.nn as nn
from torch.optim import Adam
import math,tqdm

In [56]:
dataset = load_dataset("flytech/python-codes-25k")

In [57]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
separator_token = "<|sep|>"
tokenizer.add_tokens([separator_token])
tokenizer.pad_token="<|endoftext|>"
tokenizer.sep_token="<|sep|>"

In [58]:
print("Special Tokens:")
print("Pad Token:", tokenizer.pad_token)
print("EOS Token:", tokenizer.eos_token)
print("BOS Token:", tokenizer.bos_token)
print("UNK Token:", tokenizer.unk_token)
print("Sep Token:", tokenizer.sep_token)
print("CLS Token:", tokenizer.cls_token)
print("Mask Token:", tokenizer.mask_token)

Special Tokens:
Pad Token: <|endoftext|>
EOS Token: <|endoftext|>
BOS Token: <|endoftext|>
UNK Token: <|endoftext|>
Sep Token: <|sep|>
CLS Token: None
Mask Token: None


In [59]:
train_n=1000
test_n=200
train_data=Subset(dataset['train'],[i for i in range(train_n)])
test_data=Subset(dataset['train'],[i+train_n for i in range(test_n)])

In [60]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [61]:
def concatenate_and_tokenize(examples):
    concatenated = examples['instruction'] + " <|sep|> " + examples['input'] + " <|sep|> " + examples['output'] + " <|endoftext|>"
    return tokenizer(concatenated, padding="max_length", truncation=True, max_length=512)




class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        tokens_dict=concatenate_and_tokenize(self.dataset[idx])
        item = {key: torch.tensor(val).to(device) for key, val in tokens_dict.items()}
        return item

train_dataset = CustomDataset(train_data)
test_dataset = CustomDataset(test_data)

train_loader= DataLoader(train_dataset,batch_size=8,shuffle=True)
test_loader= DataLoader(test_dataset,batch_size=8,shuffle=False)

In [62]:
# train_dataset[0]

In [63]:
class LoraConv1D(nn.Module):
    def __init__(self, c_attn_layer, rank=10):
        super().__init__()
        self.c_attn = c_attn_layer
        self.rank = rank

        self.A = nn.Parameter(torch.randn(self.c_attn.weight.shape[0], rank))
        self.B = nn.Parameter(torch.randn(rank, self.c_attn.weight.shape[1]))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.c_attn.nf,)
        delta_W = self.A @ self.B
        adapted_weight = self.c_attn.weight + delta_W
        x = torch.addmm(self.c_attn.bias, x.view(-1, x.size(-1)), adapted_weight)
        x = x.view(size_out)
        return x


In [64]:
model = GPT2LMHeadModel.from_pretrained("gpt2")
print(model.transformer.wte.weight.data.shape)
print(model.transformer.wte.weight.data[0][0])
print(model.transformer.wte.weight.data[-1][0])
model.resize_token_embeddings(len(tokenizer))
new_embedding_vector = torch.zeros((1, model.config.hidden_size))
model.transformer.wte.weight.data[-1, :] = new_embedding_vector
print(model.transformer.wte.weight.data.shape)
print(model.transformer.wte.weight.data[0][0])
print(model.transformer.wte.weight.data[-1][0])




for param in model.parameters():
    param.requires_grad = False
for index in [10,11]:
    model.transformer.h[index].attn.c_attn = LoraConv1D(model.transformer.h[index].attn.c_attn, rank=10)

model.to(device)

torch.Size([50257, 768])
tensor(-0.1101)
tensor(0.0514)
torch.Size([50258, 768])
tensor(-0.1101)
tensor(0.)


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50258, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-9): 10 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (10-11): 2 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): LoraConv1D(
            (c_attn): Conv1D()
          )
          (c_proj): Conv1D()
          (attn_dropout):

In [65]:
# model

In [66]:
loss_function = torch.nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-5)

In [67]:
def evaluate_perplexity(model=model, test_loader=test_loader, tokenizer=tokenizer):
    model.eval()
    total_loss = 0
    total_words = 0

    with torch.no_grad():
        for batch in test_loader:
            inputs = batch
            outputs = model(**inputs, labels=inputs["input_ids"])
            total_loss += outputs.loss.item() * inputs["input_ids"].size(1)
            total_words += inputs["input_ids"].size(1)

    average_loss = total_loss / total_words
    perplexity = math.exp(average_loss)
    return perplexity

In [68]:
evaluate_perplexity(model,test_loader,tokenizer)

149984020.1826658

In [69]:
def train(model=model,train_loader=train_loader,test_loader=test_loader,
          tokenizer=tokenizer,optimizer=optimizer,loss_function=loss_function, 
          epochs=5):
    for epoch in range(epochs):
        total_loss = 0
        model.train()
        for batch in tqdm.tqdm(train_loader):
            outputs=model(**batch,labels=batch['input_ids'])
            loss=outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        average_loss=total_loss/len(train_loader)
        print(f"EPOCH: {epoch}, loss: {average_loss}, test_perplexity: {evaluate_perplexity()}")


In [70]:
train()

100%|██████████| 125/125 [04:07<00:00,  1.98s/it]


EPOCH: 0, loss: 19.478653106689453, test_perplexity: 48175164.87326563


100%|██████████| 125/125 [04:51<00:00,  2.34s/it]


EPOCH: 1, loss: 18.478799995422364, test_perplexity: 21615301.18786448


100%|██████████| 125/125 [04:41<00:00,  2.26s/it]


EPOCH: 2, loss: 17.566429870605468, test_perplexity: 9585959.299039198


100%|██████████| 125/125 [04:44<00:00,  2.27s/it]


EPOCH: 3, loss: 16.992934974670412, test_perplexity: 5429467.01775179


 83%|████████▎ | 104/125 [04:07<00:49,  2.38s/it]


KeyboardInterrupt: 