In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from prep_data import get_eng_hi_dataset
from transformers import GPT2LMHeadModel, MT5Tokenizer

torch.manual_seed(42)

device = torch.device('cuda:6' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=6)

In [2]:
class LLM(nn.Module):
    def __init__(self, model) -> None:
        super(LLM, self).__init__()
        self._model = model
        self.hidden_size = model.config.hidden_size
        self.embed_size = model.config.hidden_size
        self.n_layers = model.config.num_hidden_layers
        self.n_heads = model.config.num_attention_heads
        self.head_size = self.embed_size // self.n_heads

        self.logSoftmax = nn.LogSoftmax(dim=2)
        self.logSoftmax_1 = nn.LogSoftmax(dim=1)
        self.nll = nn.NLLLoss()


    def encode(self, input_ids, input_mask):
        batch_size = input_ids.shape[0]
        len_sent = input_ids.shape[1]
        attn_mask = input_mask

        outputs = self._model(input_ids, 
                              attention_mask=attn_mask,  
                              use_cache=True)

        #RE-ENCODING
        layer_prefix_list = []
        attn_mask = torch.cat([input_mask, input_mask], dim=1)

        outputs = self._model(input_ids, 
                              past_key_values=outputs.past_key_values, 
                              attention_mask=attn_mask, 
                              use_cache=True)
        
        past_key_values = []
        for (key, value) in outputs.past_key_values:
            k = key[:,:,len_sent:,:]
            v = value[:,:,len_sent:,:]
            past_key_values.append((k, v))

        return past_key_values
    
    def decode(self, target_ids, input_mask, target_mask, past_key_values, mode='train'):
        batch_size = target_ids.shape[0]
        attn_mask = torch.cat([input_mask, target_mask], dim=1)

        outputs = self._model(target_ids, 
                              past_key_values=past_key_values, 
                              attention_mask=attn_mask,  
                              use_cache=True)
        return outputs.logits, outputs.past_key_values
    
    def forward(self, input_ids, input_mask, target_ids, target_mask):
        past_key_values = self.encode(input_ids, input_mask)
        labels = target_ids[:, 1:]
        target_ids = target_ids[:, :-1]
        target_mask = target_mask[:, :-1]
        logits,_ = self.decode(target_ids, input_mask, target_mask, past_key_values)

        # make batch size and sentence length as one dimension
        logprobs = self.logSoftmax(logits)
        logprobs = logprobs.reshape([logprobs.shape[0] * logprobs.shape[1], -1])
        target_mask = target_mask.reshape([target_mask.shape[0] * target_mask.shape[1],])
        labels = labels.flatten()
        loss = -logprobs[torch.arange(logprobs.shape[0], device=labels.device), labels]
#         print(loss.shape, target_mask.shape)
        loss = torch.sum(loss * target_mask) / torch.sum(target_mask)
        return loss

In [3]:
def read_data(PATH):
    dataset = []
    f_en = open(PATH + 'filtered.en', 'r')
    for line in f_en.readlines():
        line = line.strip('\n')
        entry = {'en': line}
        dataset.append(entry)
    f_en.close()
    
    f_hi = open(PATH + 'filtered.hi', 'r')
    for i, line in enumerate(f_hi.readlines()):
        line = line.strip('\n')
        dataset[i]['hi'] = line
    f_hi.close()
    return dataset

val_data, test_data = get_eng_hi_dataset()
train_data = read_data('filtered_data/')
train_data = train_data[:100000]

Found cached dataset parquet (/root/.cache/huggingface/datasets/cfilt___parquet/cfilt--iitb-english-hindi-911387c6837f8b91/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
class ParallelCorpus(Dataset):
    def __init__(self, data, src_lang='en', tgt_lang='hi') -> None:
        super(ParallelCorpus, self).__init__()
        self.src = []
        self.tgt = []
        for pair in data:
            self.src.append(pair[src_lang])
            self.tgt.append(pair[tgt_lang])
        
    def __len__(self):
        return len(self.src)
    
    def __getitem__(self, index):
        return self.src[index], self.tgt[index]

train_pc = ParallelCorpus(train_data, src_lang='en', tgt_lang='hi')
test_pc = ParallelCorpus(test_data, src_lang='en', tgt_lang='hi')
val_pc = ParallelCorpus(val_data, src_lang='en', tgt_lang='hi')

In [5]:
len_prefix = 100
lr = 1e-4
batch_size = 4
num_epochs = 2
token_limit = ((1023 - len_prefix) // 2) - 3  #to accomodate extra one token if max_len=1

In [6]:
train_loader = DataLoader(dataset=train_pc, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(dataset=val_pc, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_pc, batch_size=1, shuffle=False)

In [7]:
tokenizer = MT5Tokenizer.from_pretrained("THUMT/mGPT")
model = GPT2LMHeadModel.from_pretrained("THUMT/mGPT")

MT_model = LLM(model).to(device)
optimizer = torch.optim.Adam(params=MT_model.parameters(),lr=lr, eps=1e-9)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'T5Tokenizer'.


In [8]:
import time

@torch.no_grad()
def validation():
    total_loss = 0
    for i, (src, tgt) in enumerate(val_loader):
        max_src_len = min(token_limit, max([len(s) for s in src])) + 1   #need this to accomodate max_len = 1
        max_tgt_len = min(token_limit, max([len(s) for s in tgt])) + 1
        inputs = tokenizer(src, padding='max_length', truncation=True, max_length=max_src_len)
        targets = tokenizer(tgt, padding='max_length', truncation=True, max_length=max_tgt_len)
        input_ids, input_masks = inputs['input_ids'], inputs['attention_mask']
        target_ids, target_masks = targets['input_ids'], targets['attention_mask']
        for j in range(len(target_ids)):
            target_ids[j].insert(0, 1)
            target_masks[j].insert(0, 1)
        input_ids, input_masks = torch.tensor(input_ids).to(device), torch.tensor(input_masks).to(device)
        target_ids, target_masks = torch.tensor(target_ids).to(device), torch.tensor(target_masks).to(device)
        loss = MT_model(input_ids, input_masks, target_ids, target_masks)
        total_loss += loss
    return total_loss / len(val_loader)
        

min_val_loss = 10000
PATH = 'saved_models/finetune.pt'
for epoch in range(num_epochs):
    print(f"------------------------EPOCH {epoch + 1}-------------------------------")
    t1 = time.time()
    for i, (src, tgt) in enumerate(train_loader):
        MT_model.zero_grad()
        
        max_src_len = min(token_limit, max([len(s) for s in src])) + 1   #need this to accomodate max_len = 1
        max_tgt_len = min(token_limit, max([len(s) for s in tgt])) + 1
        inputs = tokenizer(src, padding='max_length', truncation=True, max_length=max_src_len)
        targets = tokenizer(tgt, padding='max_length', truncation=True, max_length=max_tgt_len)
        input_ids, input_masks = inputs['input_ids'], inputs['attention_mask']
        target_ids, target_masks = targets['input_ids'], targets['attention_mask']
#         print(len(input_ids[0]))
        for j in range(len(target_ids)):
            target_ids[j].insert(0, 1)
            target_masks[j].insert(0, 1)
#         print(len(input_ids[0]))
#         print(MT_model._model.config.max_position_embeddings)
#         print(tgt[0], target_ids[0])
        
        input_ids, input_masks = torch.tensor(input_ids).to(device), torch.tensor(input_masks).to(device)
        target_ids, target_masks = torch.tensor(target_ids).to(device), torch.tensor(target_masks).to(device)
        loss = MT_model(input_ids, input_masks, target_ids, target_masks)
        loss.backward()
        optimizer.step()
        if (i+1)%500 == 0:
            t2 = time.time()
            val_loss = validation()
            if val_loss.item() < min_val_loss:
                torch.save(MT_model.state_dict(), PATH)
                min_val_loss = val_loss
            print(f'Step {i+1} | Val Loss: {val_loss.item():.5f}| Best val loss: {min_val_loss:.5f} | Time: {(t2-t1)/3600 : .4f} hrs')

------------------------EPOCH 1-------------------------------
Step 500 | Val Loss: 3.35807| Best val loss: 3.35807 | Time:  0.0575 hrs
Step 1000 | Val Loss: 3.44138| Best val loss: 3.35807 | Time:  0.1227 hrs
Step 1500 | Val Loss: 3.28490| Best val loss: 3.28490 | Time:  0.1880 hrs
Step 2000 | Val Loss: 3.36685| Best val loss: 3.28490 | Time:  0.2560 hrs
Step 2500 | Val Loss: 3.20250| Best val loss: 3.20250 | Time:  0.3220 hrs
Step 3000 | Val Loss: 3.32183| Best val loss: 3.20250 | Time:  0.3891 hrs
Step 3500 | Val Loss: 3.18532| Best val loss: 3.18532 | Time:  0.4561 hrs
Step 4000 | Val Loss: 3.06976| Best val loss: 3.06976 | Time:  0.5239 hrs
Step 4500 | Val Loss: 2.97176| Best val loss: 2.97176 | Time:  0.5931 hrs
Step 5000 | Val Loss: 3.05036| Best val loss: 2.97176 | Time:  0.6623 hrs
Step 5500 | Val Loss: 2.95531| Best val loss: 2.95531 | Time:  0.7316 hrs
Step 6000 | Val Loss: 2.93756| Best val loss: 2.93756 | Time:  0.8049 hrs
Step 6500 | Val Loss: 2.93002| Best val loss: 2.93

In [8]:
MT_model.load_state_dict(torch.load('saved_models/finetune.pt'))

<All keys matched successfully>

In [21]:
@torch.no_grad()
def greedy_translate(model, device, tokenizer, input_sent):
    tok_output = tokenizer(input_sent, truncation=True, max_length=token_limit)
    input_ids = tok_output['input_ids']
    input_mask = tok_output['attention_mask']
    input_ids = torch.tensor(input_ids, device=device)
    input_mask = torch.tensor(input_mask, device=device)
    target_mask = torch.ones([1, 1], device=device)
    past_key_values = model.encode(input_ids, input_mask)
#     print(past_key_values[0][0].shape)
    start = [1]
    gen = []
    curr_token = None
    while curr_token != 1 and len(gen) < token_limit:
        tgt = torch.tensor(start, device=device)    
        logits, past_key_values = model.decode(tgt, input_mask, target_mask, past_key_values)
#         print(past_key_values[0][0].shape)
        logits = model.logSoftmax(logits.unsqueeze(0)).squeeze(0)
        value, index = torch.max(logits, dim=1)
        curr_token = index[0].item()
        gen.append(curr_token)
        start = [curr_token]
#         print(curr_token, value.item())
        target_mask = torch.cat([target_mask, torch.ones([1, 1], device=device)], dim=1)
    output_sent = tokenizer.decode(gen)
    return output_sent

sent = ['I am going to school.']
greedy_translate(MT_model, device, tokenizer, sent)

'मैं स्कूल जा रही हूँ।</s>'

In [10]:
import sacrebleu

candidates = []
references = []
for i, (src, tgt) in enumerate(test_loader):
    references.append(tgt[0])
    candidate = greedy_translate(MT_model, device, tokenizer, [src[0]])
    candidates.append(candidate[:-4])
    if (i+1)%10 == 0:
        print(f'{i+1} sentences processed.')

bleu = sacrebleu.corpus_bleu(candidates, [references])
print(f'BLEU score = {bleu}')

# ref_file = 'path/to/reference/translations.txt'
# with open(ref_file, 'r') as f:
#     refs = [line.strip() for line in f.readlines()]

# hyp_file = 'path/to/candidate/translations.txt'
# with open(hyp_file, 'r') as f:
#     hyps = [line.strip() for line in f.readlines()]

# bleu = sacrebleu.corpus_bleu(hyps, [refs])
# bleu

10 sentences processed.
20 sentences processed.
30 sentences processed.
40 sentences processed.
50 sentences processed.
60 sentences processed.
70 sentences processed.
80 sentences processed.
90 sentences processed.
100 sentences processed.
110 sentences processed.
120 sentences processed.
130 sentences processed.
140 sentences processed.
150 sentences processed.
160 sentences processed.
170 sentences processed.
180 sentences processed.
190 sentences processed.
200 sentences processed.
210 sentences processed.
220 sentences processed.
230 sentences processed.
240 sentences processed.
250 sentences processed.
260 sentences processed.
270 sentences processed.
280 sentences processed.
290 sentences processed.
300 sentences processed.
310 sentences processed.
320 sentences processed.
330 sentences processed.
340 sentences processed.
350 sentences processed.
360 sentences processed.
370 sentences processed.
380 sentences processed.
390 sentences processed.
400 sentences processed.
410 sente

In [19]:
chrF3 = sacrebleu.corpus_chrf(candidates, [references], char_order=3, word_order=2)
print(f'{bleu}\n{chrF3}')

BLEU = 6.90 33.1/11.1/4.1/1.5 (BP = 0.991 ratio = 0.991 hyp_len = 60289 ref_len = 60821)
chrF2++ = 35.38
