In [1]:
import json
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup

In [2]:
# Prefer the second GPU (the originally configured device), but fall back so the
# notebook still runs under Restart & Run All on machines with fewer CUDA devices.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

In [3]:
# Tokenizer and weights must come from the same checkpoint so token ids line up.
model_name = 'gpt2-medium'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

In [4]:
# Register a dedicated [PAD] token, then grow the embedding matrix so the new
# token id has a row. The resized embedding is the cell's displayed value.
special_tokens = {'pad_token': '[PAD]'}
tokenizer.add_special_tokens(special_tokens)
vocab_size = len(tokenizer)
model.resize_token_embeddings(vocab_size)

Embedding(50258, 1024)

In [5]:
# Sanity check: bos and eos are the same string here, and the pad id is the freshly added token.
print(f"bos_token: {tokenizer.bos_token}")
print(f"eos_token: {tokenizer.eos_token}")
print(f"pad_token_id: {tokenizer.pad_token_id}")

bos_token: <|endoftext|>
eos_token: <|endoftext|>
pad_token_id: 50257


In [6]:
# Load the WebNLG training split. The `with` block closes the file on exit, so
# no explicit close() is needed. Decode as UTF-8 explicitly: the text may contain
# non-ASCII characters, and the platform default encoding is not always UTF-8.
train_path = './dataset/webnlg/train.json'
with open(train_path, 'r', encoding='utf-8') as f:
    dict_train = json.load(f)

In [7]:
def linearize_triples(triples):
    """Flatten a triple set into consecutive '| subject : property : object ' segments.

    The trailing space after each segment is intentional and is kept verbatim,
    because it is part of the model's input format.
    """
    return "".join(
        "| {} : {} : {} ".format(t['subject'], t['property'], t['object'])
        for t in triples
    )


data_triple = []
data_text = []

# Each element of 'entries' is a dict keyed by its 1-based position ("1", "2", ...).
for entry_no, entry in enumerate(dict_train['entries'], start=1):
    record = entry[str(entry_no)]
    linearized = linearize_triples(record['modifiedtripleset'])

    # Every lexicalisation annotated as "good" becomes its own (triples, text) pair.
    for lex in record['lexicalisations']:
        if lex['comment'] == "good":
            data_triple.append(linearized)
            data_text.append(lex['lex'])

print(len(data_triple), "Triples")
print(len(data_text), "Texts")

18025 Triples
18025 Texts


In [8]:
# Seed torch's global RNG (all devices) so DataLoader shuffling and other
# torch-driven randomness are reproducible across Restart & Run All.
seed=42
torch.manual_seed(seed)

# Training hyperparameters. Gradients are accumulated over `accumulation_steps`
# batches, so the effective batch size is batch_size * accumulation_steps.
batch_size=1
accumulation_steps=6
epochs=10
lr=1e-5

In [9]:
class D2TDataset(Dataset):
    """Data-to-text dataset of tokenized ``triples + bos + text + eos`` sequences.

    Each sample is a pair ``(input_ids, labels)`` of equal-length token-id lists.
    ``labels`` is a copy of ``input_ids`` with every position up to and including
    the first ``bos_token_id`` replaced by -100, so only the text (and the
    closing eos) contributes to the language-modelling loss.

    Args:
        tokenizer: object exposing ``encode``, ``bos_token``, ``eos_token`` and
            ``bos_token_id`` (in this notebook, a GPT-2 tokenizer).
        data_triple: linearized triple strings, one per sample.
        data_text: target texts, aligned index-by-index with ``data_triple``.

    Raises:
        ValueError: if ``data_triple`` and ``data_text`` differ in length, or if an
            encoded sample contains no ``bos_token_id``.
    """

    def __init__(self, tokenizer, data_triple, data_text):
        if len(data_triple) != len(data_text):
            raise ValueError(
                f"data_triple ({len(data_triple)}) and data_text ({len(data_text)}) "
                "must have the same length"
            )

        self.data = []
        self.label = []

        for triple, text in zip(data_triple, data_text):
            # Tokenize once; the label sequence is derived from a copy of the inputs.
            data = tokenizer.encode(triple + tokenizer.bos_token + text + tokenizer.eos_token)
            self.data.append(data)

            label = list(data)
            # The first bos occurrence separates the triples from the text. When bos
            # and eos share an id, the first hit is still the separator, assuming the
            # triple string itself never contains that token.
            sep = label.index(tokenizer.bos_token_id) + 1
            label[:sep] = [-100] * sep
            self.label.append(label)

        print(len(self.data), "Data")

    def __getitem__(self, idx):
        # Return copies so a collate function that pads in place cannot corrupt
        # the cached samples (which would otherwise grow across epochs).
        return list(self.data[idx]), list(self.label[idx])

    def __len__(self):
        return len(self.data)

In [10]:
def collate_fn(batch, pad_token_id=None, label_pad_id=-100):
    """Right-pad a batch of ``(input_ids, labels)`` lists into two equal-width tensors.

    Args:
        batch: sequence of ``(input_ids, labels)`` pairs of token-id lists, where
            each ``labels`` list has the same length as its ``input_ids``.
        pad_token_id: id used to pad ``input_ids``; defaults to the notebook's
            ``tokenizer.pad_token_id`` when omitted (DataLoader passes only ``batch``).
        label_pad_id: id used to pad ``labels``. It defaults to -100, the index the
            model's loss ignores, so padded positions never contribute to the loss.

    Returns:
        ``(input_ids, labels)`` as ``torch.long`` tensors of shape ``(len(batch), max_len)``.

    The input lists are not modified; padded copies are built instead.
    NOTE(review): no attention mask is returned. With right padding and a causal LM,
    real tokens should not attend to the trailing pads, but confirm this if
    batch_size > 1 is used.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    max_len = max((len(data) for data, _ in batch), default=0)

    datas = [data + [pad_token_id] * (max_len - len(data)) for data, _ in batch]
    labels = [label + [label_pad_id] * (max_len - len(label)) for _, label in batch]

    return torch.tensor(datas), torch.tensor(labels)

In [11]:
# Tokenize every (triples, text) pair once, then serve shuffled mini-batches
# that collate_fn pads to a common length.
dataset = D2TDataset(tokenizer, data_triple, data_text)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

18025 Data


In [12]:
# Fine-tune with gradient accumulation: one optimizer update per
# `accumulation_steps` batches, under a linear warmup-then-decay LR schedule.
warmup_steps = 1000
log_every = 300  # optimizer updates between loss printouts

# Updates per epoch. The final accumulation window of an epoch may be shorter
# than `accumulation_steps`; it still gets its own update instead of being discarded.
batches_per_epoch = len(dataloader)
updates_per_epoch = (batches_per_epoch + accumulation_steps - 1) // accumulation_steps

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set explicitly to that optimizer's defaults (1e-6, 0.0).
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=epochs * updates_per_epoch,
)

# Create the output directory before training, so a missing folder cannot make
# the final save fail after all epochs have already run.
save_dir = './model'
os.makedirs(save_dir, exist_ok=True)

model.to(device)
model.train()

for epoch in range(epochs):
    optimizer.zero_grad()
    window_loss = 0.0
    update_step = 0
    for step, (data, label) in enumerate(dataloader):
        data = data.to(device)
        label = label.to(device)

        # Divide by the size of the window this batch belongs to, so each update
        # averages its batches' losses, including the shorter last window.
        window_start = step - step % accumulation_steps
        window_size = min(accumulation_steps, batches_per_epoch - window_start)

        outputs = model(data, labels=label)
        loss = outputs[0] / window_size
        loss.backward()
        window_loss += loss.item()

        if (step + 1) % accumulation_steps == 0 or step + 1 == batches_per_epoch:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            update_step += 1

            # The printed value is the mean loss over the current window only.
            if update_step % log_every == 0:
                print(f'epoch {epoch+1} step {update_step} loss {window_loss:.4f}')
            window_loss = 0.0

model.eval()
model.to(torch.device('cpu'))

# NOTE(review): torch.save(model) pickles the whole module object, so loading it
# requires compatible transformers classes; model.save_pretrained() is more portable.
torch.save(model, os.path.join(save_dir, f'finetuned_batch{int(accumulation_steps*batch_size)}_epoch{epochs}_lr{lr}.pt'))

epoch 1 step 300.0 loss 3.1598
epoch 1 step 600.0 loss 1.8543
epoch 1 step 900.0 loss 1.5883
epoch 1 step 1200.0 loss 1.5418
epoch 1 step 1500.0 loss 1.7015
epoch 1 step 1800.0 loss 1.1925
epoch 1 step 2100.0 loss 1.1479
epoch 1 step 2400.0 loss 1.1715
epoch 1 step 2700.0 loss 1.0070
epoch 1 step 3000.0 loss 0.8128
epoch 2 step 300.0 loss 0.7740
epoch 2 step 600.0 loss 1.0153
epoch 2 step 900.0 loss 0.8562
epoch 2 step 1200.0 loss 0.9356
epoch 2 step 1500.0 loss 0.9994
epoch 2 step 1800.0 loss 0.5988
epoch 2 step 2100.0 loss 0.7066
epoch 2 step 2400.0 loss 0.8386
epoch 2 step 2700.0 loss 0.8179
epoch 2 step 3000.0 loss 0.7492
epoch 3 step 300.0 loss 0.5637
epoch 3 step 600.0 loss 1.1142
epoch 3 step 900.0 loss 0.8546
epoch 3 step 1200.0 loss 0.9596
epoch 3 step 1500.0 loss 0.6909
epoch 3 step 1800.0 loss 0.4799
epoch 3 step 2100.0 loss 0.7348
epoch 3 step 2400.0 loss 0.9001
epoch 3 step 2700.0 loss 0.8188
epoch 3 step 3000.0 loss 0.7017
epoch 4 step 300.0 loss 0.7981
epoch 4 step 600.0