In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from multi_stage import LLM
from prep_data import get_eng_hi_dataset, tokenize
from transformers import GPT2LMHeadModel, MT5Tokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

<h2>Obtain parallel data (EN-HI)</h2>
<h5>Data is in the form of dictionary with 'en' and 'hi' keys corresponding to english and hindi sentences respectively</h5>

In [2]:
train_data, test_data = get_eng_hi_dataset()

Found cached dataset parquet (/home1/tejomay/.cache/huggingface/datasets/cfilt___parquet/cfilt--iitb-english-hindi-911387c6837f8b91/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
class ParallelCorpus(Dataset):
    def __init__(self, data, src_lang='en', tgt_lang='hi') -> None:
        super(ParallelCorpus, self).__init__()
        self.src = []
        self.tgt = []
        for pair in data:
            self.src.append(pair[src_lang])
            self.tgt.append(pair[tgt_lang])
        
    def __len__(self):
        return len(self.src)
    
    def __getitem__(self, index):
        return self.src[index], self.tgt[index]

train_pc = ParallelCorpus(train_data, src_lang='en', tgt_lang='hi')
test_pc = ParallelCorpus(test_data, src_lang='en', tgt_lang='hi')

In [6]:
model = GPT2LMHeadModel.from_pretrained("THUMT/mGPT")

<h2>Hyperparameters</h2>

In [7]:
len_prefix = 20
lr = 0.001
beta1 = 0.9
beta2 = 0.98
batch_size = 256
num_epochs = 100

for param in model.parameters():
    param.requires_grad_(False)

In [8]:
train_loader = DataLoader(dataset=train_pc, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_pc, batch_size=batch_size, shuffle=False)

In [9]:
MT_model = LLM(model, len_prefix).to(device)
optimizer = torch.optim.Adam(params=MT_model.prefix.parameters(),lr=lr, betas=(beta1, beta2), eps=1e-9)


<h2>Training</h2>

In [16]:
count = 0
for epoch in range(num_epochs):
    print(f"------------------------EPOCH {epoch + 1}-------------------------------")
    for i, (src, tgt) in enumerate(train_loader):
        input_ids, input_masks = tokenize(src)
        target_ids, target_masks = tokenize(tgt)
        loss = MT_model.forward(input_ids, input_masks, target_ids, target_masks)
        loss.backward()
        optimizer.step()

------------------------EPOCH 1-------------------------------
256
256
256
------------------------EPOCH 2-------------------------------
256
256
256
------------------------EPOCH 3-------------------------------
256
256
256
------------------------EPOCH 4-------------------------------
256
256
256
------------------------EPOCH 5-------------------------------
256
256
256
------------------------EPOCH 6-------------------------------
256
256
256
------------------------EPOCH 7-------------------------------
256
256
256
------------------------EPOCH 8-------------------------------
256
256
256
------------------------EPOCH 9-------------------------------
256
256
256
------------------------EPOCH 10-------------------------------
256
256
256
------------------------EPOCH 11-------------------------------
256
256
256
------------------------EPOCH 12-------------------------------
256
256
256
------------------------EPOCH 13-------------------------------
256
256
256
---------------------

KeyboardInterrupt: 