# nl_tp_m2m


### packages

In [1]:
import random
import os
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AdamW, get_scheduler
from sacrebleu.metrics import BLEU
from tqdm.auto import tqdm
from torchmetrics.functional.text.rouge import rouge_score
from pprint import pprint

### Parameters

In [2]:
model_checkpoint = "facebook/m2m100_418M"
max_input_length = 128
max_target_length = 128
batch_size = 8
learning_rate = 1e-5
epoch_num = 5
src_lang = 'en'
tgt_lang = 'en'
train_set_size = 8213 #0.8
valid_set_size = 1027 #0.1
test_set_size = 1027 #0.1
train_path = '/home/xinming/Text2ESQ/input/v2/easy_train.csv'
test_path = '/home/xinming/Text2ESQ/input/v2/easy_test.csv'
valid_path = '/home/xinming/Text2ESQ/input/v2/easy_valid.csv'

### Load dataset, model, tokenizer

In [5]:
bleu = BLEU()
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, src_lang=src_lang, tgt_lang=tgt_lang)

In [6]:
class NLTP_Dataset(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)
    
    def load_data(self, data_file):
        Data = {}
        df = pd.read_csv(data_file)
        for idx, x in enumerate(df.index):
            Data[idx] = {'en': df['tp'][x].strip(), 'zh_en': df['nl'][x], 'vaers_key':x}
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def preprocess(data):
    inputs = [x['zh_en'] for x in data]
    targets = [x['en'] for x in data]
    model_inputs = tokenizer(inputs, text_target=targets, max_length=max_input_length, truncation=True, padding=True, return_tensors='pt')
    
    end_token_index = torch.where(model_inputs['labels'] == tokenizer.eos_token_id)[1]
    for idx, end_idx in enumerate(end_token_index):
        model_inputs['labels'][idx][end_idx+1:] = -100
    return model_inputs


In [7]:
train_dataloader = DataLoader(NLTP_Dataset(train_path), batch_size=batch_size, shuffle=True, collate_fn=preprocess)
test_dataloader = DataLoader(NLTP_Dataset(test_path), batch_size=batch_size, shuffle=False, collate_fn=preprocess)
valid_dataloader = DataLoader(NLTP_Dataset(valid_path), batch_size=batch_size, shuffle=False, collate_fn=preprocess)

### train

In [9]:
def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    model.train()
    for batch, batch_data in enumerate(dataloader, start=1):
        batch_data = batch_data.to(device)
        outputs = model(**batch_data)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()

    return total_loss

def test_loop(dataloader, model, epoch):
    sources, preds, labels =[], [], []
    
    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = batch_data.to(device)
        with torch.no_grad():
            generated_tokens = model.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
            ).cpu().numpy()
        label_tokens = batch_data["labels"].cpu().numpy()
        
        decoded_sources = tokenizer.batch_decode(
            batch_data["input_ids"].cpu().numpy(), 
            skip_special_tokens=True, 
            use_source_tokenizer=True
        )

        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        sources += [source.strip() for source in decoded_sources]
        preds += [pred.strip() for pred in decoded_preds]
        labels += [[label.strip()] for label in decoded_labels]
    bleu_score = bleu.corpus_score(preds, labels).score
    print(f"BLEU: {bleu_score:>0.2f}\n")

    results = []
    print("saving valid results")
    for source, pred, label in zip(sources, preds, labels):
        results.append({
            "sentence": source, 
            "prediction": pred, 
            "translation": label[0]
        })
    f = open(f'/home/xinming/Text2ESQ/out_files/easy_nltp_weights/easy_nltp_valid_pred_{epoch+1}.json', 'wt')
    json.dump(results, f)
    f.close()   
    
    return bleu_score

optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
best_bleu = 0.
for t in tqdm(range(epoch_num)):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)
    valid_bleu = test_loop(valid_dataloader, model, t)
    if valid_bleu > best_bleu:
        best_bleu = valid_bleu
        print('saving new weights...\n')
        torch.save(
            model.state_dict(), 
            f'/home/xinming/Text2ESQ/out_files/easy_nltp_weights/NLTP_epoch_{t+1}_valid_bleu_{valid_bleu:0.2f}_model_weights.bin'
        )
print("Done!")


  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1/5
-------------------------------


  0%|          | 0/163 [00:00<?, ?it/s]

BLEU: 100.00

saving valid results
saving new weights...

Epoch 2/5
-------------------------------


  0%|          | 0/163 [00:00<?, ?it/s]

BLEU: 100.00

saving valid results
Epoch 3/5
-------------------------------


  0%|          | 0/163 [00:00<?, ?it/s]

BLEU: 100.00

saving valid results
Epoch 4/5
-------------------------------


  0%|          | 0/163 [00:00<?, ?it/s]

BLEU: 100.00

saving valid results
Epoch 5/5
-------------------------------


  0%|          | 0/163 [00:00<?, ?it/s]

BLEU: 100.00

saving valid results
Done!


## test

In [11]:
model.load_state_dict(torch.load('/home/xinming/Text2ESQ/out_files/easy_nltp_weights/NLTP_epoch_1_valid_bleu_100.00_model_weights.bin'))

model.eval()
with torch.no_grad():
    sources, preds, labels = [], [], []
    for batch_data in tqdm(test_dataloader):
        batch_data = batch_data.to(device)
        generated_tokens = model.generate(
            batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            max_length=max_target_length,
        ).cpu().numpy()
        label_tokens = batch_data["labels"].cpu().numpy()

        decoded_sources = tokenizer.batch_decode(
            batch_data["input_ids"].cpu().numpy(), 
            skip_special_tokens=True, 
            use_source_tokenizer=True
        )
        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        sources += [source.strip() for source in decoded_sources]
        preds += [pred.strip() for pred in decoded_preds]
        labels += [[label.strip()] for label in decoded_labels]
    bleu_score = bleu.corpus_score(preds, labels).score
    print(f"Test BLEU: {bleu_score:>0.2f}\n")
    results = []
    print('saving predicted results...')
    for source, pred, label in zip(sources, preds, labels):
        results.append({
            "sentence": source, 
            "prediction": pred, 
            "translation": label[0]
        })
    f = open('/home/xinming/Text2ESQ/out_files/easy_nltp_weights/nltp_m2m_test_pred.json', 'wt')
    json.dump(results, f)
    f.close()

  0%|          | 0/163 [00:00<?, ?it/s]

Test BLEU: 86.94

saving predicted results...


# Rouge

In [12]:
# Valid rouge
file_path = '/home/xinming/Text2ESQ/out_files/easy_nltp_weights/easy_nltp_valid_pred_5.json'
f = open(file_path, 'r')
data = json.load(f)

preds = [x['prediction'] for x in data]
target = [x['translation'] for x in data]

pprint(rouge_score(preds, target))

{'rouge1_fmeasure': tensor(0.9208),
 'rouge1_precision': tensor(0.9285),
 'rouge1_recall': tensor(0.9158),
 'rouge2_fmeasure': tensor(0.8724),
 'rouge2_precision': tensor(0.8796),
 'rouge2_recall': tensor(0.8676),
 'rougeL_fmeasure': tensor(0.9184),
 'rougeL_precision': tensor(0.9260),
 'rougeL_recall': tensor(0.9134),
 'rougeLsum_fmeasure': tensor(0.9185),
 'rougeLsum_precision': tensor(0.9261),
 'rougeLsum_recall': tensor(0.9135)}


In [13]:
# Test rouge
file_path = '/home/xinming/Text2ESQ/out_files/easy_nltp_weights/nltp_m2m_test_pred.json'
f = open(file_path, 'r')
data = json.load(f)

preds = [x['prediction'] for x in data]
target = [x['translation'] for x in data]

pprint(rouge_score(preds, target))

{'rouge1_fmeasure': tensor(0.9028),
 'rouge1_precision': tensor(0.9129),
 'rouge1_recall': tensor(0.8969),
 'rouge2_fmeasure': tensor(0.8428),
 'rouge2_precision': tensor(0.8518),
 'rouge2_recall': tensor(0.8374),
 'rougeL_fmeasure': tensor(0.8985),
 'rougeL_precision': tensor(0.9084),
 'rougeL_recall': tensor(0.8924),
 'rougeLsum_fmeasure': tensor(0.8986),
 'rougeLsum_precision': tensor(0.9085),
 'rougeLsum_recall': tensor(0.8925)}
