In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from prep_data import get_eng_hi_dataset
from transformers import GPT2LMHeadModel, MT5Tokenizer

device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=3)

In [2]:
class LLM(nn.Module):
    def __init__(self, model) -> None:
        super(LLM, self).__init__()
        self._model = model
        self.hidden_size = model.config.hidden_size
        self.embed_size = model.config.hidden_size
        self.n_layers = model.config.num_hidden_layers
        self.n_heads = model.config.num_attention_heads
        self.head_size = self.embed_size // self.n_heads

        self.logSoftmax = nn.LogSoftmax(dim=2)
        self.logSoftmax_1 = nn.LogSoftmax(dim=1)
        self.nll = nn.NLLLoss()


    def encode(self, input_ids, input_mask):
        batch_size = input_ids.shape[0]
        len_sent = input_ids.shape[1]
        attn_mask = input_mask

        outputs = self._model(input_ids, 
                              attention_mask=attn_mask,  
                              use_cache=True)

        #RE-ENCODING
        layer_prefix_list = []
        attn_mask = torch.cat([input_mask, input_mask], dim=1)

        outputs = self._model(input_ids, 
                              past_key_values=outputs.past_key_values, 
                              attention_mask=attn_mask, 
                              use_cache=True)
        
        past_key_values = []
        for (key, value) in outputs.past_key_values:
            k = key[:,:,len_sent:,:]
            v = value[:,:,len_sent:,:]
            past_key_values.append((k, v))

        return past_key_values
    
    def decode(self, target_ids, input_mask, target_mask, past_key_values, mode='train'):
        batch_size = target_ids.shape[0]
        attn_mask = torch.cat([input_mask, target_mask], dim=1)

        outputs = self._model(target_ids, 
                              past_key_values=past_key_values, 
                              attention_mask=attn_mask,  
                              use_cache=True)
        return outputs.logits, outputs.past_key_values
    
    def forward(self, input_ids, input_mask, target_ids, target_mask):
        past_key_values = self.encode(input_ids, input_mask)
        labels = target_ids[:, 1:]
        target_ids = target_ids[:, :-1]
        target_mask = target_mask[:, :-1]
        logits,_ = self.decode(target_ids, input_mask, target_mask, past_key_values)

        # make batch size and sentence length as one dimension
        logprobs = self.logSoftmax(logits)
        logprobs = logprobs.reshape([logprobs.shape[0] * logprobs.shape[1], -1])
        target_mask = target_mask.reshape([target_mask.shape[0] * target_mask.shape[1],])
        labels = labels.flatten()
        loss = -logprobs[torch.arange(logprobs.shape[0], device=labels.device), labels]
#         print(loss.shape, target_mask.shape)
        loss = torch.sum(loss * target_mask) / torch.sum(target_mask)
        return loss

In [3]:
train_data, val_data, test_data = get_eng_hi_dataset()
train_data = train_data[40000:]

Found cached dataset parquet (/root/.cache/huggingface/datasets/cfilt___parquet/cfilt--iitb-english-hindi-911387c6837f8b91/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
class ParallelCorpus(Dataset):
    def __init__(self, data, src_lang='en', tgt_lang='hi') -> None:
        super(ParallelCorpus, self).__init__()
        self.src = []
        self.tgt = []
        for pair in data:
            self.src.append(pair[src_lang])
            self.tgt.append(pair[tgt_lang])
        
    def __len__(self):
        return len(self.src)
    
    def __getitem__(self, index):
        return self.src[index], self.tgt[index]

train_pc = ParallelCorpus(train_data, src_lang='en', tgt_lang='hi')
test_pc = ParallelCorpus(test_data, src_lang='en', tgt_lang='hi')
val_pc = ParallelCorpus(val_data, src_lang='en', tgt_lang='hi')

In [5]:
len_prefix = 100
lr = 1e-4
batch_size = 4
num_epochs = 1
token_limit = (1023 - len_prefix) // 2

In [6]:
train_loader = DataLoader(dataset=train_pc, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(dataset=val_pc, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_pc, batch_size=1, shuffle=False)

In [7]:
tokenizer = MT5Tokenizer.from_pretrained("THUMT/mGPT")
model = GPT2LMHeadModel.from_pretrained("THUMT/mGPT")

MT_model = LLM(model).to(device)
optimizer = torch.optim.Adam(params=MT_model.parameters(),lr=lr, eps=1e-9)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'T5Tokenizer'.


In [None]:
import time

@torch.no_grad()
def validation():
    total_loss = 0
    for i, (src, tgt) in enumerate(val_loader):
        max_src_len = min(token_limit, max([len(s) for s in src]))
        max_tgt_len = min(token_limit, max([len(s) for s in tgt]))
        inputs = tokenizer(src, padding='max_length', truncation=True, max_length=max_src_len)
        targets = tokenizer(tgt, padding='max_length', truncation=True, max_length=max_tgt_len)
        input_ids, input_masks = inputs['input_ids'], inputs['attention_mask']
        target_ids, target_masks = targets['input_ids'], targets['attention_mask']
        for j in range(len(target_ids)):
            target_ids[j].insert(0, 1)
            target_masks[j].insert(0, 1)
        input_ids, input_masks = torch.tensor(input_ids).to(device), torch.tensor(input_masks).to(device)
        target_ids, target_masks = torch.tensor(target_ids).to(device), torch.tensor(target_masks).to(device)
        loss = MT_model(input_ids, input_masks, target_ids, target_masks)
        total_loss += loss
    return total_loss / len(val_loader)
        

min_val_loss = 10000
PATH = 'saved_models/finetune.pt'
for epoch in range(num_epochs):
    print(f"------------------------EPOCH {epoch + 1}-------------------------------")
    t1 = time.time()
    for i, (src, tgt) in enumerate(train_loader):
        MT_model.zero_grad()
        
        max_src_len = min(token_limit, max([len(s) for s in src]))
        max_tgt_len = min(token_limit, max([len(s) for s in tgt]))
        inputs = tokenizer(src, padding='max_length', truncation=True, max_length=max_src_len)
        targets = tokenizer(tgt, padding='max_length', truncation=True, max_length=max_tgt_len)
        input_ids, input_masks = inputs['input_ids'], inputs['attention_mask']
        target_ids, target_masks = targets['input_ids'], targets['attention_mask']
#         print(len(input_ids[0]))
        for j in range(len(target_ids)):
            target_ids[j].insert(0, 1)
            target_masks[j].insert(0, 1)
#         print(len(input_ids[0]))
#         print(MT_model._model.config.max_position_embeddings)
#         print(tgt[0], target_ids[0])
        
        input_ids, input_masks = torch.tensor(input_ids).to(device), torch.tensor(input_masks).to(device)
        target_ids, target_masks = torch.tensor(target_ids).to(device), torch.tensor(target_masks).to(device)
        loss = MT_model(input_ids, input_masks, target_ids, target_masks)
        loss.backward()
        optimizer.step()
        if (i+1)%500 == 0:
            t2 = time.time()
            val_loss = validation()
            if val_loss.item() < min_val_loss:
                torch.save(MT_model.state_dict(), PATH)
                min_val_loss = val_loss
            print(f'Step {i+1} | Val Loss: {val_loss.item():.5f}| Best val loss: {min_val_loss:.5f} | Time: {(t2-t1)/3600 : .4f} hrs')

------------------------EPOCH 1-------------------------------
Step 500 | Val Loss: 4.29350| Best val loss: 4.29350 | Time:  0.0406 hrs
Step 1000 | Val Loss: 3.50777| Best val loss: 3.50777 | Time:  0.0847 hrs
Step 1500 | Val Loss: 3.48872| Best val loss: 3.48872 | Time:  0.1283 hrs
Step 2000 | Val Loss: 3.70152| Best val loss: 3.48872 | Time:  0.1720 hrs
Step 2500 | Val Loss: 3.20066| Best val loss: 3.20066 | Time:  0.2121 hrs
Step 3000 | Val Loss: 3.70233| Best val loss: 3.20066 | Time:  0.2541 hrs
Step 3500 | Val Loss: 4.37849| Best val loss: 3.20066 | Time:  0.2943 hrs
Step 4000 | Val Loss: 3.22966| Best val loss: 3.20066 | Time:  0.3329 hrs
Step 4500 | Val Loss: 3.56074| Best val loss: 3.20066 | Time:  0.3717 hrs
Step 5000 | Val Loss: 3.14477| Best val loss: 3.14477 | Time:  0.4102 hrs
Step 5500 | Val Loss: 4.06344| Best val loss: 3.14477 | Time:  0.4510 hrs
Step 6000 | Val Loss: 3.07499| Best val loss: 3.07499 | Time:  0.4890 hrs
Step 6500 | Val Loss: 4.30031| Best val loss: 3.07