In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from multi_stage import LLM
from prep_data import get_eng_hi_dataset
from transformers import GPT2LMHeadModel, MT5Tokenizer

# Seed the global torch RNG so random initialisation is repeatable.
torch.manual_seed(42)

# Run on GPU index 2 when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:2')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=2)

In [2]:
class Prefix(nn.Module):
    """Trainable key/value prefixes, one key and one value block per layer.

    ``prefixes`` has shape ``(n_prefixes, 2 * n_layers, len_prefix,
    hidden_size)``. Along the second axis even indices hold keys and odd
    indices hold values. ``reparams`` holds one scale factor per prefix
    position.

    Args:
        config: Object exposing ``hidden_size``, ``num_hidden_layers`` and
            ``num_attention_heads``.
        n_prefixes: Number of independent prefixes to allocate.
        len_prefix: Number of positions in each prefix.
    """

    def __init__(self, config, n_prefixes, len_prefix) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.embed_size = config.hidden_size
        self.n_layers = config.num_hidden_layers
        self.n_heads = config.num_attention_heads
        self.head_size = self.embed_size // self.n_heads

        # Per-position scale factors, initialised to 1.
        self.reparams = nn.Parameter(torch.ones(len_prefix))
        prefix_shape = (n_prefixes, 2 * self.n_layers, len_prefix, self.hidden_size)
        self.prefixes = nn.Parameter(torch.empty(prefix_shape))

        with torch.no_grad():
            nn.init.xavier_uniform_(self.prefixes)

    def _tile_and_split(self, offset, scale, batch_size):
        """Scale every second layer slot from ``offset``, tile it and split heads."""
        scaled = self.prefixes[:, offset::2, :, :] * scale
        # A 5-element repeat on a 4-D tensor prepends the batch dimension.
        tiled = scaled.repeat([batch_size, 1, 1, 1, 1])
        return tiled.view(*tiled.shape[:4], self.n_heads, self.head_size)

    def forward(self, batch_size):
        """Return ``(keys, values)`` tiled ``batch_size`` times.

        Both tensors have shape
        ``(batch_size, n_prefixes, n_layers, len_prefix, n_heads, head_size)``.
        """
        # Scale factors are bounded below by 1 (element-wise maximum with ones).
        floor = torch.ones(1, device=self.reparams.device)
        scale = torch.maximum(floor, self.reparams)
        # Broadcast the scale over everything except the prefix-position axis.
        scale = scale.view([1, 1, scale.shape[0], 1])

        keys = self._tile_and_split(0, scale, batch_size)
        values = self._tile_and_split(1, scale, batch_size)
        return keys, values
    

class LLM(nn.Module):
    """Frozen causal language model steered by a trainable key/value prefix.

    Every parameter of the wrapped ``model`` has ``requires_grad`` switched
    off; the ``Prefix`` module created here is the only trainable part.

    Args:
        model: Model exposing ``config.hidden_size``,
            ``config.num_hidden_layers`` and ``config.num_attention_heads``.
            It must be callable as ``model(ids, past_key_values=...,
            attention_mask=..., use_cache=True)`` and return an object with
            ``logits`` and ``past_key_values`` (iterable of per-layer
            ``(key, value)`` pairs).
        len_prefix: Number of prefix positions prepended in the first
            encoding stage.
    """

    def __init__(self, model, len_prefix) -> None:
        super().__init__()
        self._model = model
        self.hidden_size = model.config.hidden_size
        self.embed_size = model.config.hidden_size
        self.n_layers = model.config.num_hidden_layers
        self.n_heads = model.config.num_attention_heads
        self.head_size = self.embed_size // self.n_heads
        self.len_prefix = len_prefix

        self.prefix = Prefix(model.config, 1, len_prefix)
        self.logSoftmax = nn.LogSoftmax(dim=2)
        # Not used by any method of this class; kept so existing attribute
        # access keeps working.
        self.logSoftmax_1 = nn.LogSoftmax(dim=1)
        self.nll = nn.NLLLoss()

        # Freeze the wrapped model.
        for param in self._model.parameters():
            param.requires_grad_(False)

    @staticmethod
    def _drop_leading(past_key_values, n_positions):
        """Drop the first ``n_positions`` cached positions of every layer."""
        return [
            (key[:, :, n_positions:, :], value[:, :, n_positions:, :])
            for key, value in past_key_values
        ]

    def encode(self, input_ids, input_mask):
        """Encode the source twice and return the cache of the second pass.

        Stage 1 runs ``input_ids`` with the learned prefix supplied as
        ``past_key_values``. Stage 2 (re-encoding) runs ``input_ids`` again,
        this time attending to the stage-1 cache of the source tokens with
        the prefix positions removed.

        Args:
            input_ids: Source token ids, indexed as ``(batch, len_sent)``.
            input_mask: Attention mask aligned with ``input_ids``.

        Returns:
            List of per-layer ``(key, value)`` pairs holding only the
            positions produced by the re-encoding pass.
        """
        batch_size, len_sent = input_ids.shape[0], input_ids.shape[1]

        # Prefix positions are always attended to.
        prefix_mask = torch.ones([batch_size, self.len_prefix], device=input_mask.device)
        attn_mask = torch.cat([prefix_mask, input_mask], dim=1)

        keys, values = self.prefix(batch_size)
        # (batch, n_prefixes, n_layers, len_prefix, n_heads, head_size)
        #   -> per layer (batch, n_heads, len_prefix, head_size)
        layer_prefix_list = [
            (keys[:, 0, i, :, :, :].permute(0, 2, 1, 3),
             values[:, 0, i, :, :, :].permute(0, 2, 1, 3))
            for i in range(self.n_layers)
        ]

        outputs = self._model(input_ids,
                              past_key_values=layer_prefix_list,
                              attention_mask=attn_mask,
                              use_cache=True)

        # RE-ENCODING: keep the stage-1 source states, drop the prefix states.
        source_cache = self._drop_leading(outputs.past_key_values, self.len_prefix)
        attn_mask = torch.cat([input_mask, input_mask], dim=1)

        outputs = self._model(input_ids,
                              past_key_values=source_cache,
                              attention_mask=attn_mask,
                              use_cache=True)

        # Only the re-encoded states are handed to the decoding stage.
        return self._drop_leading(outputs.past_key_values, len_sent)

    def decode(self, target_ids, input_mask, target_mask, past_key_values, mode='train'):
        """Run the model on ``target_ids`` on top of an encoded source cache.

        Args:
            target_ids: Decoder input token ids.
            input_mask: Attention mask of the source whose cache is passed in.
            target_mask: Attention mask aligned with ``target_ids``.
            past_key_values: Cache returned by ``encode`` (or by a previous
                ``decode`` call).
            mode: Accepted for backward compatibility; not used.

        Returns:
            Tuple ``(logits, past_key_values)`` from the wrapped model.
        """
        # Attend to the cached source states and to the target tokens.
        attn_mask = torch.cat([input_mask, target_mask], dim=1)

        outputs = self._model(target_ids,
                              past_key_values=past_key_values,
                              attention_mask=attn_mask,
                              use_cache=True)
        return outputs.logits, outputs.past_key_values

    def forward(self, input_ids, input_mask, target_ids, target_mask):
        """Return the masked mean negative log-likelihood of the target.

        ``target_ids[:, :-1]`` is fed to the decoder and ``target_ids[:, 1:]``
        is what it must predict. The loss is averaged over the positions
        where the *label* is a real token, i.e. where ``target_mask[:, 1:]``
        is non-zero.
        """
        past_key_values = self.encode(input_ids, input_mask)

        labels = target_ids[:, 1:]
        # The loss mask must line up with the labels, not with the decoder
        # inputs; using the input mask would also score the first padding
        # token after every sequence shorter than the batch maximum.
        label_mask = target_mask[:, 1:]
        decoder_ids = target_ids[:, :-1]
        decoder_mask = target_mask[:, :-1]
        logits, _ = self.decode(decoder_ids, input_mask, decoder_mask, past_key_values)

        # Merge batch and sequence axes into one.
        logprobs = self.logSoftmax(logits)
        logprobs = logprobs.reshape([logprobs.shape[0] * logprobs.shape[1], -1])
        label_mask = label_mask.reshape([label_mask.shape[0] * label_mask.shape[1]])
        labels = labels.flatten()

        rows = torch.arange(logprobs.shape[0], device=labels.device)
        token_nll = -logprobs[rows, labels]
        # clamp guards against 0/0 when a batch has no real label at all.
        n_labels = torch.sum(label_mask).clamp(min=1)
        return torch.sum(token_nll * label_mask) / n_labels

In [3]:
def read_data(PATH):
    """Load a line-aligned English/Hindi corpus into a list of dicts.

    Reads ``PATH + 'filtered.en'`` and ``PATH + 'filtered.hi'``. Line ``i``
    of each file becomes ``{'en': ..., 'hi': ...}`` at index ``i`` of the
    result, with the trailing newline removed.

    Args:
        PATH: String prefix prepended verbatim to the two file names, so a
            directory must include its trailing separator.

    Returns:
        List of dicts, one per line of the English file. If the Hindi file
        has fewer lines, the trailing entries have no ``'hi'`` key.

    Raises:
        IndexError: If the Hindi file has more lines than the English file.
    """
    # Decode explicitly as UTF-8 so Devanagari text does not depend on the
    # platform's default encoding. The context managers close the files
    # even if reading fails part-way.
    with open(PATH + 'filtered.en', 'r', encoding='utf-8') as f_en:
        dataset = [{'en': line.strip('\n')} for line in f_en]

    with open(PATH + 'filtered.hi', 'r', encoding='utf-8') as f_hi:
        for i, line in enumerate(f_hi):
            dataset[i]['hi'] = line.strip('\n')

    return dataset

# Validation and test splits come from the project helper.
val_data, test_data = get_eng_hi_dataset()
# Training pairs are read from the filtered files; keep the first 100000.
train_data = read_data('filtered_data/')[:100000]

Found cached dataset parquet (/root/.cache/huggingface/datasets/cfilt___parquet/cfilt--iitb-english-hindi-911387c6837f8b91/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
class ParallelCorpus(Dataset):
    """Dataset of ``(source, target)`` pairs taken from a list of dicts.

    Args:
        data: Iterable of mappings containing both language keys.
        src_lang: Key of the source-side text in each mapping.
        tgt_lang: Key of the target-side text in each mapping.
    """

    def __init__(self, data, src_lang='en', tgt_lang='hi') -> None:
        super().__init__()
        # Single pass over ``data`` so one-shot iterables work too.
        pairs = [(pair[src_lang], pair[tgt_lang]) for pair in data]
        self.src = [source for source, _ in pairs]
        self.tgt = [target for _, target in pairs]

    def __len__(self):
        """Return the number of pairs."""
        return len(self.src)

    def __getitem__(self, index):
        """Return the ``(source, target)`` pair at ``index``."""
        return self.src[index], self.tgt[index]

# Wrap each split as an English -> Hindi dataset (train, test, validation).
train_pc, test_pc, val_pc = (
    ParallelCorpus(split, src_lang='en', tgt_lang='hi')
    for split in (train_data, test_data, val_data)
)

In [5]:
# Optimisation settings.
lr = 1e-4
batch_size = 4
num_epochs = 2

# Number of trainable prefix positions.
len_prefix = 100
# Per-sentence length cap used when tokenizing. The "- 3" leaves slack,
# including the extra position added to accommodate max_len = 1.
# NOTE(review): 1023 presumably derives from the model's position limit — confirm.
token_limit = (1023 - len_prefix) // 2 - 3

In [6]:
# All loaders iterate in dataset order (no shuffling).
train_loader = DataLoader(train_pc, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_pc, batch_size=batch_size, shuffle=False)
# The test set is consumed one sentence pair at a time.
test_loader = DataLoader(test_pc, batch_size=1, shuffle=False)

In [8]:
# Load the tokenizer and the pretrained weights from the "THUMT/mGPT" checkpoint.
tokenizer = MT5Tokenizer.from_pretrained("THUMT/mGPT")
model = GPT2LMHeadModel.from_pretrained("THUMT/mGPT")

# Wrap the pretrained model with a prefix of len_prefix positions and move it to the device.
MT_model = LLM(model, len_prefix).to(device)
# The optimizer is given every parameter of the wrapper.
# NOTE(review): presumably only the prefix parameters receive gradients — confirm against LLM.__init__.
optimizer = torch.optim.Adam(params=MT_model.parameters(),lr=lr, eps=1e-9)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'T5Tokenizer'.


In [None]:
import time

@torch.no_grad()
def validation():
    """Return the mean per-batch loss of ``MT_model`` over ``val_loader``.

    Each batch is tokenized, padded to a per-batch maximum length, and a
    leading token id ``1`` is prepended to every target sequence before the
    model is called. Gradients are disabled by the decorator.
    """
    running_loss = 0
    for src, tgt in val_loader:
        # Padding length: longest string in the batch (measured with len()
        # on the raw strings), capped at token_limit, plus one so that
        # max_len = 1 is accommodated.
        src_len = min(token_limit, max(len(s) for s in src)) + 1
        tgt_len = min(token_limit, max(len(s) for s in tgt)) + 1

        encoded_src = tokenizer(src, padding='max_length', truncation=True, max_length=src_len)
        encoded_tgt = tokenizer(tgt, padding='max_length', truncation=True, max_length=tgt_len)

        # Prepend id 1 (with mask 1) to every target row.
        # NOTE(review): id 1 is presumably the start/end-of-sequence token — confirm.
        tgt_id_rows = [[1] + row for row in encoded_tgt['input_ids']]
        tgt_mask_rows = [[1] + row for row in encoded_tgt['attention_mask']]

        input_ids = torch.tensor(encoded_src['input_ids']).to(device)
        input_masks = torch.tensor(encoded_src['attention_mask']).to(device)
        target_ids = torch.tensor(tgt_id_rows).to(device)
        target_masks = torch.tensor(tgt_mask_rows).to(device)

        running_loss += MT_model(input_ids, input_masks, target_ids, target_masks)
    return running_loss / len(val_loader)
        

import os

# Best validation loss seen so far; always a plain Python number.
min_val_loss = 10000
PATH = 'saved_models/finetune.pt'
# Create the checkpoint directory up front so the first save, which only
# happens after 500 training steps, cannot fail on a missing folder.
os.makedirs(os.path.dirname(PATH) or '.', exist_ok=True)

for epoch in range(num_epochs):
    print(f"------------------------EPOCH {epoch + 1}-------------------------------")
    t1 = time.time()
    for i, (src, tgt) in enumerate(train_loader):
        MT_model.zero_grad()

        # Padding length: longest string in the batch (measured with len()
        # on the raw strings), capped at token_limit, plus one so that
        # max_len = 1 is accommodated.
        max_src_len = min(token_limit, max(len(s) for s in src)) + 1
        max_tgt_len = min(token_limit, max(len(s) for s in tgt)) + 1
        inputs = tokenizer(src, padding='max_length', truncation=True, max_length=max_src_len)
        targets = tokenizer(tgt, padding='max_length', truncation=True, max_length=max_tgt_len)
        input_ids, input_masks = inputs['input_ids'], inputs['attention_mask']
        target_ids, target_masks = targets['input_ids'], targets['attention_mask']

        # Prepend id 1 (with mask 1) to every target row.
        # NOTE(review): id 1 is presumably the start/end-of-sequence token — confirm.
        for ids_row, mask_row in zip(target_ids, target_masks):
            ids_row.insert(0, 1)
            mask_row.insert(0, 1)

        input_ids, input_masks = torch.tensor(input_ids).to(device), torch.tensor(input_masks).to(device)
        target_ids, target_masks = torch.tensor(target_ids).to(device), torch.tensor(target_masks).to(device)
        loss = MT_model(input_ids, input_masks, target_ids, target_masks)
        loss.backward()
        optimizer.step()

        # Every 500 steps: validate and checkpoint when the loss improves.
        if (i+1)%500 == 0:
            t2 = time.time()
            val_loss = validation()
            # Compare and store floats so min_val_loss does not turn into a
            # device tensor after the first improvement.
            val_loss_value = val_loss.item()
            if val_loss_value < min_val_loss:
                torch.save(MT_model.state_dict(), PATH)
                min_val_loss = val_loss_value
            print(f'Step {i+1} | Val Loss: {val_loss_value:.5f}| Best val loss: {min_val_loss:.5f} | Time: {(t2-t1)/3600 : .4f} hrs')

------------------------EPOCH 1-------------------------------
