In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from multi_stage import LLM
from prep_data import get_eng_hi_dataset
from transformers import GPT2LMHeadModel, MT5Tokenizer

device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
device
# device = torch.device('cpu')

# torch.backends.cudnn.enabled = True
# torch.backends.cudnn.benchmark = True

device(type='cuda', index=4)

<h2>Obtain parallel data (EN-HI)</h2>
<h5>Data is in the form of dictionary with 'en' and 'hi' keys corresponding to english and hindi sentences respectively</h5>

In [2]:
train_data, val_data, test_data = get_eng_hi_dataset()
train_data = train_data[40000:]

Found cached dataset parquet (/root/.cache/huggingface/datasets/cfilt___parquet/cfilt--iitb-english-hindi-911387c6837f8b91/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
len(train_data), len(val_data), len(test_data)

(1611623, 520, 2507)

In [4]:
class ParallelCorpus(Dataset):
    def __init__(self, data, src_lang='en', tgt_lang='hi') -> None:
        super(ParallelCorpus, self).__init__()
        self.src = []
        self.tgt = []
        for pair in data:
            self.src.append(pair[src_lang])
            self.tgt.append(pair[tgt_lang])
        
    def __len__(self):
        return len(self.src)
    
    def __getitem__(self, index):
        return self.src[index], self.tgt[index]

train_pc = ParallelCorpus(train_data, src_lang='en', tgt_lang='hi')
test_pc = ParallelCorpus(test_data, src_lang='en', tgt_lang='hi')
val_pc = ParallelCorpus(val_data, src_lang='en', tgt_lang='hi')

<h2>Hyperparameters</h2>

In [5]:
len_prefix = 100
lr = 2e-4
beta1 = 0.9
beta2 = 0.98
batch_size = 8
num_epochs = 1
token_limit = (1023 - len_prefix) // 2

In [6]:
train_loader = DataLoader(dataset=train_pc, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(dataset=val_pc, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_pc, batch_size=1, shuffle=False)
len(train_loader)

201453

In [7]:
tokenizer = MT5Tokenizer.from_pretrained("THUMT/mGPT")
model = GPT2LMHeadModel.from_pretrained("THUMT/mGPT")
for param in model.parameters():
    param.requires_grad_(False)

MT_model = LLM(model, len_prefix).to(device)
optimizer = torch.optim.Adam(params=MT_model.prefix.parameters(),lr=lr, betas=(beta1, beta2), eps=1e-9)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'T5Tokenizer'.


<h2>Training</h2>

In [None]:
import time

@torch.no_grad()
def validation():
    total_loss = 0
    for i, (src, tgt) in enumerate(val_loader):
        max_src_len = min(token_limit, max([len(s) for s in src]))
        max_tgt_len = min(token_limit, max([len(s) for s in tgt]))
        inputs = tokenizer(src, padding='max_length', truncation=True, max_length=max_src_len)
        targets = tokenizer(tgt, padding='max_length', truncation=True, max_length=max_tgt_len)
        input_ids, input_masks = inputs['input_ids'], inputs['attention_mask']
        target_ids, target_masks = targets['input_ids'], targets['attention_mask']
        for j in range(len(target_ids)):
            target_ids[j].insert(0, 1)
            target_masks[j].insert(0, 1)
        input_ids, input_masks = torch.tensor(input_ids).to(device), torch.tensor(input_masks).to(device)
        target_ids, target_masks = torch.tensor(target_ids).to(device), torch.tensor(target_masks).to(device)
        loss = MT_model(input_ids, input_masks, target_ids, target_masks)
        total_loss += loss
    return total_loss / len(val_loader)
        

min_val_loss = 10000
PATH = 'saved_models/prefix.pt'
for epoch in range(num_epochs):
    print(f"------------------------EPOCH {epoch + 1}-------------------------------")
    t1 = time.time()
    for i, (src, tgt) in enumerate(train_loader):
        MT_model.zero_grad()
        
        max_src_len = min(token_limit, max([len(s) for s in src]))
        max_tgt_len = min(token_limit, max([len(s) for s in tgt]))
        inputs = tokenizer(src, padding='max_length', truncation=True, max_length=max_src_len)
        targets = tokenizer(tgt, padding='max_length', truncation=True, max_length=max_tgt_len)
        input_ids, input_masks = inputs['input_ids'], inputs['attention_mask']
        target_ids, target_masks = targets['input_ids'], targets['attention_mask']
#         print(len(input_ids[0]))
        for j in range(len(target_ids)):
            target_ids[j].insert(0, 1)
            target_masks[j].insert(0, 1)
#         print(len(input_ids[0]))
#         print(MT_model._model.config.max_position_embeddings)
#         print(tgt[0], target_ids[0])
        
        input_ids, input_masks = torch.tensor(input_ids).to(device), torch.tensor(input_masks).to(device)
        target_ids, target_masks = torch.tensor(target_ids).to(device), torch.tensor(target_masks).to(device)
        loss = MT_model(input_ids, input_masks, target_ids, target_masks)
        loss.backward()
        optimizer.step()
        if (i+1)%250 == 0:
            t2 = time.time()
            val_loss = validation()
            if val_loss.item() < min_val_loss:
                torch.save(MT_model.prefix.state_dict(), PATH)
                min_val_loss = val_loss
            print(f'Step {i+1} | Val Loss: {val_loss.item():.5f}| Best val loss: {min_val_loss:.5f} | Time: {(t2-t1)/3600 : .4f} hrs')

------------------------EPOCH 1-------------------------------
Step 250 | Val Loss: 3.84863| Best val loss: 3.84863 | Time:  0.1632 hrs


In [None]:
PATH = "saved_models/prefix.pt"
MT_model.prefix.load_state_dict(torch.load(PATH))

In [None]:
@torch.no_grad()
def validation():
    total_loss = 0
    for i, (src, tgt) in enumerate(val_loader):
        if (i+1)%100 == 0:
            print(f"Processing {i+1}....")
        max_src_len = min(token_limit, max([len(s) for s in src]))
        max_tgt_len = min(token_limit, max([len(s) for s in tgt]))
        inputs = tokenizer(src, padding='max_length', truncation=True, max_length=max_src_len)
        targets = tokenizer(tgt, padding='max_length', truncation=True, max_length=max_tgt_len)
        input_ids, input_masks = inputs['input_ids'], inputs['attention_mask']
        target_ids, target_masks = targets['input_ids'], targets['attention_mask']
        for j in range(len(target_ids)):
            target_ids[j].insert(0, 1)
            target_masks[j].insert(0, 1)
        input_ids, input_masks = torch.tensor(input_ids).to(device), torch.tensor(input_masks).to(device)
        target_ids, target_masks = torch.tensor(target_ids).to(device), torch.tensor(target_masks).to(device)
        loss = MT_model(input_ids, input_masks, target_ids, target_masks)
        total_loss += loss
    return total_loss / len(val_loader)

validation()

In [None]:
@torch.no_grad()
def greedy_translate(model, device, tokenizer, input_sent):
    tok_output = tokenizer(input_sent)
    input_ids = tok_output['input_ids']
    input_mask = tok_output['attention_mask']
    input_ids = torch.tensor(input_ids, device=device)
    input_mask = torch.tensor(input_mask, device=device)
    target_mask = torch.ones([1, 1], device=device)
    past_key_values = model.encode(input_ids, input_mask)
    print(past_key_values[0][0].shape)
    start = [1]
    gen = []
    curr_token = None
    while curr_token != 1:
        tgt = torch.tensor(start, device=device)    
        logits, past_key_values = model.decode(tgt, input_mask, target_mask, past_key_values)
        print(past_key_values[0][0].shape)
        logits = model.logSoftmax(logits.unsqueeze(0)).squeeze(0)
        value, index = torch.max(logits, dim=1)
        curr_token = index[0].item()
        gen.append(curr_token)
        start = [curr_token]
        print(curr_token, value.item())
        target_mask = torch.cat([target_mask, torch.ones([1, 1], device=device)], dim=1)
    output_sent = tokenizer.decode(gen)
    return output_sent

sent = ['is it okay to have a gold necklace']
greedy_translate(MT_model, device, tokenizer, sent)

In [None]:
tokenizer('</s>')['input_ids']