In [1]:
import json

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup

In [2]:
# Root Path
root_path='/root/research/Graph-To-Text/'

# Device: (Single) GPU
device=torch.device('cuda:2')

# Hyperparams
ctrlseqlen=2 # Control Prefix Length
preseqlen=10 # General Prefix Length
hidden_dim=768 # Reparam Hidden Dimension
batch_size=5
accumulation_steps=1
epochs=5
lr=7e-5

In [3]:
# Pre-Trained Tokenizer, LM
tokenizer=GPT2Tokenizer.from_pretrained('gpt2-large')
pretrained=GPT2LMHeadModel.from_pretrained('gpt2-large').to(device)

# Add PAD Token: [PAD]
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
pretrained.resize_token_embeddings(len(tokenizer))

# Freeze LM
for param in pretrained.parameters():
    param.requires_grad=False

In [4]:
# Special Tokens
print("bos_token:", tokenizer.bos_token)
print("eos_token:", tokenizer.eos_token)
print("pad_token_id:", tokenizer.pad_token_id)

bos_token: <|endoftext|>
eos_token: <|endoftext|>
pad_token_id: 50257


In [5]:
def process_webnlg(dicts):
    """
    Process WebNLG Dataset
    """
    triples=[]
    texts=[]
    categories=[]
    
    for index, dict_ in enumerate(dicts['entries']):
        
        data=dict_[str(index+1)]
        
        # Triple Data
        triple_proc=''
        for triple in data['modifiedtripleset']:
            subj, prop, obj=triple['subject'], triple['property'], triple['object']
            triple_proc+='| {} : {} : {} '.format(subj, prop, obj)
            
        # Text Data
        for text in data['lexicalisations']:
            if text['comment']!='good': continue
                
            triples.append(triple_proc)
            texts.append(text['lex'])
            categories.append(data['category'])
            
    print(len(triples), "data")
    
    return triples, texts, categories

In [6]:
with open(root_path+'dataset/webnlg/train.json', 'r') as f:
    dict_train=json.load(f)
    f.close()
    
# Process Train Set
triples_train, texts_train, categories_train=process_webnlg(dict_train)

with open(root_path+'dataset/webnlg/dev.json', 'r') as f:
    dict_dev=json.load(f)
    f.close()
    
# Process Dev Set
triples_dev, texts_dev, categories_dev=process_webnlg(dict_dev)

18025 data
2258 data


In [7]:
class WebNLGDataset(Dataset):
    """
    PyTorch Dataset Class: WebNLG Dataset
    """
    def __init__(self, tokenizer, triples, texts, categories):
        self.data=[]
        self.label=[]
        self.category=categories
        
        for index, triple in enumerate(triples):
            data=tokenizer.encode(triple+tokenizer.bos_token+texts[index]+tokenizer.eos_token)
            self.data.append(data)
            
            label=tokenizer.encode(triple+tokenizer.bos_token+texts[index]+tokenizer.eos_token)
            sep=label.index(tokenizer.bos_token_id)+1
            label[:sep]=[-100]*sep
            self.label.append(label)
            
        print(len(self.data), "data")
    
    def __getitem__(self, idx):
        return self.data[idx], self.label[idx], self.category[idx]
    
    def __len__(self):
        return len(self.data)

In [8]:
def collate_fn(batch):
    """
    For Same Sequence Length on Same Batch: Padding
    """
    max_len=0
    for data, _, _ in batch:
        if len(data)>max_len: max_len=len(data)
            
    datas=[]
    labels=[]
    categories=[]
    for data, label, category in batch:
        data.extend([tokenizer.pad_token_id]*(max_len-len(data)))
        datas.append(data)
        
        label.extend([tokenizer.pad_token_id]*(max_len-len(label)))
        labels.append(label)
        
        categories.append(category)
        
    return torch.tensor(datas), torch.tensor(labels), categories

In [9]:
# Train Set
dataset_train=WebNLGDataset(tokenizer=tokenizer, triples=triples_train, texts=texts_train, categories=categories_train)
dataloader_train=DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Dev Set
dataset_dev=WebNLGDataset(tokenizer=tokenizer, triples=triples_dev, texts=texts_dev, categories=categories_dev)
dataloader_dev=DataLoader(dataset_dev, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

18025 data
2258 data


In [10]:
class ControlPrefixesGPT2(nn.Module):
    """
    Control-Prefixes on GPT2
    """
    def __init__(self, config, categories, ctrlseqlen=2, preseqlen=5, hidden_dim=512):
        super().__init__()
        
        # Config of Pre-Trained LM
        self.config=config
        
        # Control-Prefixes: Attributes
        self.categories=categories
        print(self.categories)
        # Control Prefix Length
        self.ctrlseqlen=ctrlseqlen
        # General Prefix Length
        self.preseqlen=preseqlen
        
        # Embedding
        # Control
        self.wte_ctrl=nn.Embedding(ctrlseqlen*len(categories), self.config.n_embd)
        # General
        self.input_tokens=torch.arange(preseqlen).long()
        self.wte=nn.Embedding(len(categories)+preseqlen, self.config.n_embd)
        
        # Reparam
        self.control_trans=nn.Sequential(
            nn.Linear(self.config.n_embd, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, self.config.n_layer*2*self.config.n_embd)
        )
        #self.dropout=nn.Dropout(p=0.0)
        
        # Func: Get Prompt
        self.get_prompt=self.get_prompt_fn
        
    def get_prompt_fn(self, bsz=None, categories=None):
        # Control Prefix
        controls=[list(range(self.ctrlseqlen*self.categories.index(c), self.ctrlseqlen*(self.categories.index(c)+1))) for c in categories]
        controls=torch.tensor(controls).to(device)
        controls=self.wte_ctrl(controls)
        
        # General Prefix
        input_tokens=self.input_tokens.unsqueeze(0).expand(bsz, -1).to(device)
        input_tokens=self.wte(input_tokens)
        
        # [Control, General]
        input_tokens=torch.cat((controls, input_tokens), dim=1)
        past_key_values=self.control_trans(input_tokens)
        bsz, seqlen, _=past_key_values.shape
        past_key_values=past_key_values.view(bsz, seqlen, 2*self.config.n_layer, self.config.n_head, int(self.config.n_embd/self.config.n_head))
        #past_key_values=self.dropout(past_key_values)
        past_key_values=past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        
        return past_key_values
        
    def forward(self, input_ids, labels, categories):
        bsz=input_ids.shape[0]
        past_key_values_prompt=self.get_prompt(bsz=bsz, categories=categories)
        
        return past_key_values_prompt

In [11]:
# Model
model=ControlPrefixesGPT2(
    config=pretrained.config,
    categories=list(set(categories_train)),
    ctrlseqlen=ctrlseqlen,
    preseqlen=preseqlen,
    hidden_dim=hidden_dim
).to(device)

# Optim, Scheduler
optimizer=AdamW(model.parameters(), lr=lr)
scheduler=get_linear_schedule_with_warmup(
    optimizer=optimizer,
    # 3% of Total Steps
    num_warmup_steps=int(0.03*epochs*len(dataset_train)/batch_size),
    num_training_steps=int(epochs*len(dataset_train)/(accumulation_steps*batch_size))
)

# TensorBoard: Logging
writer=SummaryWriter()
step_global=0

for epoch in range(epochs):
    # Train Phase
    model.train()
    model.to(device)
    
    loss_train=0
    optimizer.zero_grad()
    
    for step, (data, label, category) in enumerate(dataloader_train):
        data=data.to(device)
        label=label.to(device)
        
        prompt=model(input_ids=data, labels=label, categories=category)
        outputs=pretrained(data, labels=label, past_key_values=prompt)
        
        loss=outputs[0]/accumulation_steps
        loss.backward()
        
        loss_train+=loss.item()
        
        if (step+1)%accumulation_steps==0:
            step_global+=1
            
            # TensorBoard
            writer.add_scalar(
                f'loss_train/c-prefixtuned_preseqlen{preseqlen}_batch{int(accumulation_steps*batch_size)}_epoch{epochs}_lr{lr}',
                loss_train,
                step_global
            )
            # Console
            if step_global%500==0:
                print(f'epoch {epoch+1} step {step_global} loss_train {loss_train:.4f}')
            # Set Loss to 0
            loss_train=0
            
            optimizer.step()
            scheduler.step()
            
            optimizer.zero_grad()
            
    # Eval Phase
    model.eval()
    
    loss_eval=0
    
    with torch.no_grad():
        for step, (data, label, category) in enumerate(dataloader_dev):
            data=data.to(device)
            label=label.to(device)

            prompt=model(input_ids=data, labels=label, categories=category)
            outputs=pretrained(data, labels=label, past_key_values=prompt)

            loss=outputs[0]
            loss_eval+=loss.item()
        loss_eval=loss_eval/(step+1)
        
        # TensorBoard
        writer.add_scalar(
            f'loss_eval/c-prefixtuned_preseqlen{preseqlen}_batch{int(accumulation_steps*batch_size)}_epoch{epochs}_lr{lr}',
            loss_eval,
            epoch+1
        )
        # Console
        print("=====")
        print(f'epoch {epoch+1} loss_eval {loss_eval:.4f}')
        print("=====")
        
    # Save Model
    model.to(torch.device('cpu'))
    torch.save(model, root_path+f'model/c-prefixtuned_preseqlen{preseqlen}_batch{int(accumulation_steps*batch_size)}_epoch{epoch+1}of{epochs}_lr{lr}.pt')

['Airport', 'WrittenWork', 'ComicsCharacter', 'Food', 'Building', 'City', 'University', 'SportsTeam', 'Astronaut', 'Monument']
epoch 1 step 500 loss_train 0.3020
epoch 1 step 1000 loss_train 0.4284
epoch 1 step 1500 loss_train 0.3429
epoch 1 step 2000 loss_train 0.3193
epoch 1 step 2500 loss_train 0.2812
epoch 1 step 3000 loss_train 0.2324
epoch 1 step 3500 loss_train 0.4601
=====
epoch 1 loss_eval 0.2978
=====
epoch 2 step 4000 loss_train 0.2087
epoch 2 step 4500 loss_train 0.3124
epoch 2 step 5000 loss_train 0.3243
epoch 2 step 5500 loss_train 0.2454
epoch 2 step 6000 loss_train 0.0871
epoch 2 step 6500 loss_train 0.1910
epoch 2 step 7000 loss_train 0.1636
=====
epoch 2 loss_eval 0.1656
=====
epoch 3 step 7500 loss_train 0.0796
epoch 3 step 8000 loss_train 0.1072
epoch 3 step 8500 loss_train 0.0587
epoch 3 step 9000 loss_train 0.1214
epoch 3 step 9500 loss_train 0.0949
epoch 3 step 10000 loss_train 0.1546
epoch 3 step 10500 loss_train 0.1567
=====
epoch 3 loss_eval 0.1131
=====
epoch