In [1]:
import json

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.utils.data import Dataset, DataLoader
import wandb

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
device

'cuda'

In [6]:
class MyDataset(Dataset):
    def __init__(self, data_path, tokenizer):
        self.tokenizer = tokenizer
        self.tokenizer.add_tokens(['<H>', '<R>', '<T>'])    # try with special tokens later
        self.data_path = data_path
        self.data = []
        with open(self.data_path, 'r') as f:
            data = json.load(f)
            for entry in data['entries']:
                for value in entry.values():
                    triple_set = []
                    for triple in value['modifiedtripleset']:
                        triple_set.append('<H>')
                        triple_set.append(triple['subject'])
                        triple_set.append('<R>')
                        triple_set.append(triple['property'])
                        triple_set.append('<T>')
                        triple_set.append(triple['object'])
                    for lex in value['lexicalisations']:
                        self.data.append((' '.join(triple_set), lex['lex']))
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        triple_set = self.data[idx][0]
        lex = self.data[idx][1]
        source = self.tokenizer(triple_set, return_tensors='pt', padding='max_length', truncation=True)
        target = self.tokenizer(lex, return_tensors='pt', padding='max_length', truncation=True)

        return {'source_input_ids': source['input_ids'].squeeze(), 'source_attention_mask': source['attention_mask'].squeeze(), 'target_input_ids': target['input_ids'].squeeze(), 'target_attention_mask': target['attention_mask'].squeeze()}

In [7]:
def train(model, dataloader, tokenizer, optimizer):
    model.train()
    for batch in dataloader:
        source_input_ids = batch['source_input_ids'].to(device)
        source_attention_mask = batch['source_attention_mask'].to(device)
        labels = batch['target_input_ids'].to(device)
        labels[labels == tokenizer.pad_token_id] = -100
        
        outputs = model(source_input_ids, attention_mask=source_attention_mask, labels=labels)
        
        loss = outputs[0]
        wandb.log({'loss': loss})
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


In [8]:
def validate(model, dataloader, tokenizer):
    actual, preds = [], []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            source_input_ids = batch['source_input_ids'].to(device)
            source_attention_mask = batch['source_attention_mask'].to(device)
            outputs = model.generate(source_input_ids, attention_mask=source_attention_mask)
            preds.append([tokenizer.decode(output, skip_special_tokens=True) for output in outputs])
            actual.append([tokenizer.decode(batch['target_input_ids'][i], skip_special_tokens=True) for i in range(batch['target_input_ids'].shape[0])])
    return actual, preds

In [9]:
if __name__=='__main__':
    wandb.init(project='t5-webnlg')
    wandb.WANDB_NOTEBOOK_NAME = 't5-webnlg'
    config = wandb.config
    config.epochs = 1
    
    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    
    train_data_path = './train.json'
    train_dataset = MyDataset(train_data_path, tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

    val_data_path = './dev.json'
    val_dataset = MyDataset(val_data_path, tokenizer)
    val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    model = T5ForConditionalGeneration.from_pretrained('t5-small')
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    wandb.watch(model, log='all')

    for epoch in range(config.epochs):
        train(model, train_dataloader, tokenizer, optimizer)

    actual, preds = validate(model, val_dataloader, tokenizer)
    json.dump({'actual': actual, 'preds': preds}, open('preds.json', 'w'))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshivprasad[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [21]:
with open('preds.json', 'r') as f:
    data = json.load(f)
    actual = data['actual']
    preds = data['preds']
    for id, (a, p) in enumerate(zip(actual, preds)):
        if id==301:
            for i in range(8):
                print(a[i], '---', p[i])
                print('*'*20)
            break

Bhajji comes from the country of India, where two of the leaders are, T. S. Thakur and Narendra Modi. --- Bhajji is a leader in India, which is led by Narendra Mod
********************
Bhajji originates from India, where two of the leaders are Narendra Modi and T.S. Thakur. --- Bhajji is a leader in India, which is led by Narendra Mod
********************
The dish bhajji originates in India where T.S. Thakur and Narendra Modi are leaders. --- Bhajji is a leader in India, which is led by Narendra Mod
********************
Bhajji originates from the Karnataka region and the main ingredients are vegetables and gram flour. --- Bhajji is a main ingredient in Karnataka. It is
********************
The main ingredients in Bhajji are gram flour and vegetables, this comes from the Karnataka region. --- Bhajji is a main ingredient in Karnataka. It is
********************
Bhajji are found in the region of Karnataka, its main ingredients are gram flour and vegetables. --- Bhajji is a main ingredient