In [1]:
import json

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup

In [2]:
device=torch.device('cuda:3')

In [3]:
tokenizer=GPT2Tokenizer.from_pretrained('gpt2-large')
pretrained=GPT2LMHeadModel.from_pretrained('gpt2-large')

In [4]:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
pretrained.resize_token_embeddings(len(tokenizer))

Embedding(50258, 1280)

In [5]:
print("bos_token:", tokenizer.bos_token)
print("eos_token:", tokenizer.eos_token)
print("pad_token_id:", tokenizer.pad_token_id)

bos_token: <|endoftext|>
eos_token: <|endoftext|>
pad_token_id: 50257


In [6]:
with open('./dataset/webnlg/train.json', 'r') as f:
    dict_train=json.load(f)
    f.close()

In [7]:
data_category=[]
data_triple=[]
data_text=[]

for index, data in enumerate(dict_train['entries']):
    triples=data[str(index+1)]['modifiedtripleset']
    triple_proc=""
    for triple in triples:
        subj, prop, obj=triple['subject'], triple['property'], triple['object']
        triple_proc+="| {} : {} : {} ".format(subj, prop, obj)
        
    texts=data[str(index+1)]['lexicalisations']
    for text in texts:
        if text['comment']!="good": continue
            
        data_category.append(data[str(index+1)]['category'])
        data_triple.append(triple_proc)
        data_text.append(text['lex'])
        
print(len(data_category), 'Categories')
print(len(data_triple), "Triples")
print(len(data_text), "Texts")

18025 Categories
18025 Triples
18025 Texts


In [8]:
# General Prefix Length
gen_seqlen=10

# Hyperparams
batch_size=5
accumulation_steps=1
epochs=10
lr=7e-5

In [9]:
class D2TDataset(Dataset):
    def __init__(self, tokenizer, data_category, data_triple, data_text):
        
        self.category=data_category
        self.data=[]
        self.label=[]
        
        for index, triple in enumerate(data_triple):
            data=tokenizer.encode(triple+tokenizer.bos_token+data_text[index]+tokenizer.eos_token)
            self.data.append(data)
            
            label=tokenizer.encode(triple+tokenizer.bos_token+data_text[index]+tokenizer.eos_token)
            sep=label.index(tokenizer.bos_token_id)+1
            label[:sep]=[-100]*sep
            self.label.append(label)
            
        print(len(self.data), "Data")
    
    def __getitem__(self, idx):
        return self.category[idx], self.data[idx], self.label[idx]
    
    def __len__(self):
        return len(self.data)

In [10]:
def collate_fn(batch):
    max_len=0
    for _, data, _ in batch:
        if len(data)>max_len: max_len=len(data)
            
    categories=[]
    datas=[]
    labels=[]
    for category, data, label in batch:
        categories.append(category)
        
        data.extend([tokenizer.pad_token_id]*(max_len-len(data)))
        datas.append(data)
        
        label.extend([tokenizer.pad_token_id]*(max_len-len(label)))
        labels.append(label)
        
    return categories, torch.tensor(datas), torch.tensor(labels)

In [11]:
dataset=D2TDataset(tokenizer=tokenizer, data_category=data_category, data_triple=data_triple, data_text=data_text)
dataloader=DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

18025 Data


In [12]:
class ControlPrefixes(nn.Module):
    
    def __init__(self, pretrained, controls, gen_seqlen=5):
        super().__init__()
        
        # Pre-Trained LM
        self.pretrained=pretrained
        self.config=self.pretrained.config
        for param in self.pretrained.parameters():
            param.requires_grad=False
            
        # Control Prefixes: Attributes
        self.controls=controls
        print(self.controls)
        
        # General Prefix Length
        self.gen_seqlen=gen_seqlen
        
        self.input_tokens=torch.arange(gen_seqlen).long()
        self.wte=nn.Embedding(len(controls)+gen_seqlen, self.config.n_embd)
        self.control_trans=nn.Sequential(
            nn.Linear(self.config.n_embd, 512),
            nn.Tanh(),
            nn.Linear(512, 512),
            nn.Tanh(),
            nn.Linear(512, self.config.n_layer*2*self.config.n_embd)
        )
        self.dropout=nn.Dropout(p=0.0)
        
        self.get_prompt=self.get_prompt_fn
        
    def get_prompt_fn(self, bsz, controls):
        # Control Prefixes
        controls=[self.gen_seqlen+self.controls.index(c) for c in controls]
        controls=torch.tensor(controls).unsqueeze(1)
        # General Prefix
        input_tokens=self.input_tokens.unsqueeze(0).expand(bsz, -1)
        # [Control Prefixes, General Prefix]
        input_tokens=torch.cat((controls, input_tokens), dim=1).to(device)
        #print(input_tokens)
        
        temp_control=self.wte(input_tokens)
        past_key_values=self.control_trans(temp_control)
        bsz, seqlen, _=past_key_values.shape
        past_key_values=past_key_values.view(bsz, seqlen, 2*self.config.n_layer, self.config.n_head, int(self.config.n_embd/self.config.n_head))
        past_key_values=self.dropout(past_key_values)
        past_key_values=past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        
        return past_key_values
        
    def forward(self, controls, input_ids, labels):        
        bsz=input_ids.shape[0]
        past_key_values_prompt=self.get_prompt(bsz=bsz, controls=controls)
        outputs=self.pretrained(input_ids=input_ids, labels=labels, past_key_values=past_key_values_prompt)
        
        return outputs

In [13]:
model=ControlPrefixes(pretrained=pretrained, controls=list(set(data_category)), gen_seqlen=gen_seqlen)

optimizer=AdamW(model.parameters(), lr=lr)
scheduler=get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(0.03*epochs*len(dataset)/batch_size),
    num_training_steps=int(epochs*len(dataset)/(accumulation_steps*batch_size))
)

model.to(device)
model.train()

for epoch in range(epochs):
    loss_total=0
    optimizer.zero_grad()
    for step, (category, data, label) in enumerate(dataloader):
        data=data.to(device)
        label=label.to(device)
        
        outputs=model(controls=category, input_ids=data, labels=label)
        
        loss=outputs[0]/accumulation_steps
        loss.backward()
        
        loss_total+=loss.item()
        
        if (step+1)%accumulation_steps==0:
            if (step+1)%(300*accumulation_steps)==0:
                print(f'epoch {epoch+1} step {(step+1)/accumulation_steps} loss {loss_total:.4f}')
            loss_total=0
            
            optimizer.step()
            scheduler.step()
            
            optimizer.zero_grad()

model.eval()
model.to(torch.device('cpu'))

torch.save(model, './model/'+f'control-prefixes-tuned_Large_preseqlen{gen_seqlen}_batch{int(accumulation_steps*batch_size)}_epoch{epochs}_lr{lr}.pt')

['ComicsCharacter', 'WrittenWork', 'Building', 'Monument', 'SportsTeam', 'Astronaut', 'Food', 'City', 'Airport', 'University']
epoch 1 step 300.0 loss 0.8324
epoch 1 step 600.0 loss 0.5727
epoch 1 step 900.0 loss 0.2231
epoch 1 step 1200.0 loss 0.5914
epoch 1 step 1500.0 loss 0.2977
epoch 1 step 1800.0 loss 0.2371
epoch 1 step 2100.0 loss 0.4438
epoch 1 step 2400.0 loss 0.4288
epoch 1 step 2700.0 loss 0.2354
epoch 1 step 3000.0 loss 0.2000
epoch 1 step 3300.0 loss 0.4080
epoch 1 step 3600.0 loss 0.2607
epoch 2 step 300.0 loss 0.1324
epoch 2 step 600.0 loss 0.3924
epoch 2 step 900.0 loss 0.1692
epoch 2 step 1200.0 loss 0.2124
epoch 2 step 1500.0 loss 0.2115
epoch 2 step 1800.0 loss 0.2914
epoch 2 step 2100.0 loss 0.3101
epoch 2 step 2400.0 loss 0.2478
epoch 2 step 2700.0 loss 0.2288
epoch 2 step 3000.0 loss 0.3006
epoch 2 step 3300.0 loss 0.1820
epoch 2 step 3600.0 loss 0.3325
epoch 3 step 300.0 loss 0.1292
epoch 3 step 600.0 loss 0.1104
epoch 3 step 900.0 loss 0.1022
epoch 3 step 1200.