In [1]:
import json
import logging
import os

import torch
import torch.nn as nn
from transformers import GPT2Tokenizer

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

In [2]:
# Suppress every logging call at CRITICAL level and below (this includes
# library warnings) to keep the notebook output clean.
logging.disable(logging.CRITICAL)

# Use the first GPU when available, otherwise fall back to CPU so the notebook
# still runs on machines without CUDA.
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Checkpoint loaded with torch.load below (a pickled PrefixTuning module); the
# file name appears to encode the training hyperparameters.
model_path='./model/prefixtuned_preseqlen10_batch5_epoch10_lr5e-05.pt'

In [3]:
class PrefixTuning(nn.Module):
    """Prefix-tuning wrapper: a frozen pretrained LM plus a trainable prefix.

    ``preseqlen`` learned embeddings are passed through a small MLP
    (``control_trans``). Its output is reshaped into per-layer key/value
    tensors that are handed to the pretrained model as ``past_key_values``.

    Args:
        pretrained: wrapped model. It must expose ``config.n_embd``,
            ``config.n_layer`` and ``config.n_head``, and accept the keyword
            arguments ``input_ids``, ``labels`` and ``past_key_values``.
        preseqlen: number of prefix positions.
        mid_dim: hidden width of the prefix MLP (512 by default).
        prefix_dropout: dropout probability applied to the generated
            key/values (0.0 disables it).

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, pretrained, preseqlen=5, mid_dim=512, prefix_dropout=0.0):
        super().__init__()

        self.pretrained=pretrained
        self.config=self.pretrained.config
        # Freeze the backbone; only the prefix modules below receive gradients.
        for param in self.pretrained.parameters():
            param.requires_grad=False

        if self.config.n_embd%self.config.n_head!=0:
            raise ValueError(
                "n_embd ({}) must be divisible by n_head ({})".format(
                    self.config.n_embd, self.config.n_head
                )
            )

        # Kept as a plain tensor attribute, not a registered buffer, so the
        # module's state_dict layout is unchanged from earlier checkpoints.
        # It is moved to the parameters' device on every call instead.
        self.input_tokens=torch.arange(preseqlen).long()
        self.wte=nn.Embedding(preseqlen, self.config.n_embd)
        self.control_trans=nn.Sequential(
            nn.Linear(self.config.n_embd, mid_dim),
            nn.Tanh(),
            nn.Linear(mid_dim, mid_dim),
            nn.Tanh(),
            nn.Linear(mid_dim, self.config.n_layer*2*self.config.n_embd)
        )
        self.dropout=nn.Dropout(p=prefix_dropout)

        # forward() calls the prompt builder through this alias; it is kept so
        # that pickled instances of the class keep working.
        self.get_prompt=self.get_prompt_fn

    def get_prompt_fn(self, bsz=None):
        """Build the prefix key/values for a batch.

        Args:
            bsz: batch size; ``None`` is treated as 1.

        Returns:
            Tuple of ``n_layer`` tensors, each shaped
            ``(2, bsz, n_head, preseqlen, n_embd // n_head)``. Dim 0 presumably
            indexes key vs. value, in the stacked per-layer ``past`` format the
            backbone expects.
        """
        if bsz is None:
            bsz=1
        # Follow the module's own device rather than a notebook-global one.
        param_device=self.wte.weight.device
        input_tokens=self.input_tokens.unsqueeze(0).expand(bsz, -1).to(param_device)
        temp_control=self.wte(input_tokens)
        past_key_values=self.control_trans(temp_control)
        bsz, seqlen, _=past_key_values.shape
        past_key_values=past_key_values.view(
            bsz, seqlen, 2*self.config.n_layer, self.config.n_head,
            self.config.n_embd//self.config.n_head,
        )
        past_key_values=self.dropout(past_key_values)
        # (2*n_layer, bsz, n_head, seqlen, head_dim) -> n_layer chunks of size 2 on dim 0.
        past_key_values=past_key_values.permute([2, 0, 3, 1, 4]).split(2)

        return past_key_values

    def forward(self, input_ids, labels=None):
        """Run the frozen backbone with the learned prefix as ``past_key_values``.

        Args:
            input_ids: token id tensor whose first dimension is the batch size.
            labels: optional labels forwarded unchanged to the backbone.

        Returns:
            Whatever the pretrained model returns.
        """
        bsz=input_ids.shape[0]
        past_key_values_prompt=self.get_prompt(bsz=bsz)
        return self.pretrained(input_ids=input_ids, labels=labels, past_key_values=past_key_values_prompt)

In [4]:
tokenizer=GPT2Tokenizer.from_pretrained('gpt2-medium')
# torch.load unpickles a whole nn.Module: the PrefixTuning class above must be
# in scope, and unpickling can execute arbitrary code, so only load trusted
# checkpoints. PyTorch 2.6+ defaults to weights_only=True and then needs
# weights_only=False for a pickled module.
# map_location lets a GPU-saved checkpoint load on the current device.
# eval() disables dropout in case the module was saved in training mode.
model=torch.load(model_path, map_location=device).to(device).eval()

In [5]:
# Register an explicit padding token. Per the recorded output it gets a new id
# (50257). This notebook never resizes the model's embeddings, so that id
# presumably must not be fed to the model.
tokenizer.add_special_tokens(dict(pad_token='[PAD]'))
print("pad_token_id:", tokenizer.pad_token_id)

pad_token_id: 50257


In [6]:
# WebNLG train split; below, only its category labels are used.
# Explicit UTF-8: the data contains non-ASCII text and the platform default
# encoding is not guaranteed to be UTF-8. `with` closes the file itself.
with open('./dataset/webnlg/train.json', 'r', encoding='utf-8') as f:
    dict_train=json.load(f)

In [7]:
# Categories present in the training split; test examples outside this set are
# treated as "unseen". Each entry appears to be a single-key dict keyed by its
# 1-based position (TODO confirm against the dataset schema).
# Sorted so the printed list is identical across runs: the iteration order of a
# set of str depends on per-process hash randomization.
categories_seen=sorted({
    data[str(index+1)]['category']
    for index, data in enumerate(dict_train['entries'])
})
print(len(categories_seen), "Categories in Train Set")
print(categories_seen)

10 Categories in Train Set
['University', 'ComicsCharacter', 'SportsTeam', 'WrittenWork', 'Building', 'Airport', 'Monument', 'Astronaut', 'City', 'Food']


In [8]:
# WebNLG test split: triples, categories and reference lexicalisations.
# Explicit UTF-8 for the same reason as the train file; `with` closes the file.
with open('./dataset/webnlg/test.json', 'r', encoding='utf-8') as f:
    dict_test=json.load(f)

In [9]:
# Accumulators filled by the next cell. Each test example goes to the "seen" or
# "unseen" lists depending on whether its category occurs in the train split.
categories_unseen = []
triples_seen, triples_unseen = [], []
refs_seen, refs_unseen = [], []

In [10]:
# Linearise each test example's triples as "| subject : property : object "
# segments and collect its reference texts. Route the example to the seen or
# unseen lists according to its category.
for position, wrapper in enumerate(dict_test['entries'], start=1):
    entry = wrapper[str(position)]

    triple_proc = "".join(
        "| {} : {} : {} ".format(t['subject'], t['property'], t['object'])
        for t in entry['modifiedtripleset']
    )
    references = [lex_item['lex'] for lex_item in entry['lexicalisations']]

    if entry['category'] in categories_seen:
        triples_seen.append(triple_proc)
        refs_seen.append(references)
    else:
        categories_unseen.append(entry['category'])
        triples_unseen.append(triple_proc)
        refs_unseen.append(references)

In [11]:
# The previous cell appends one category per unseen example; deduplicate it.
# Sorted so the printout does not depend on hash randomization.
categories_unseen=sorted(set(categories_unseen))
print(len(categories_unseen), "Unseen Categories")
print(categories_unseen)
print("=====")

# Every input must have a matching reference list, and together the two splits
# must cover the whole test set.
assert len(triples_seen)==len(refs_seen)
assert len(triples_unseen)==len(refs_unseen)
assert len(triples_seen)+len(triples_unseen)==len(dict_test['entries'])

print(len(triples_seen), "Seen Data")
print(len(triples_unseen), "Unseen Data")

5 Unseen Categories
['Politician', 'Artist', 'CelestialBody', 'MeanOfTransportation', 'Athlete']
=====
971 Seen Data
891 Unseen Data


In [12]:
# Greedy decoding for the seen-category test inputs. Each output is scored with
# sentence-level BLEU (NLTK smoothing method4) and saved one line per example.
# Decoding works on token ids. Decoding one token at a time, appending the text
# and re-encoding it can yield a different token sequence than the model
# produced, and garbles characters split across byte-level BPE tokens.
max_new_tokens=100  # generation length cap
scores_seen=[]
generations=[]

print("Seen Categories")

model.eval()  # ensure dropout is off; the checkpoint may have been saved in training mode
with torch.no_grad():  # inference only: don't build autograd graphs
    for index, triple in enumerate(triples_seen):
        if (index+1)%100==0: print(index+1)

        prompt_ids=tokenizer.encode(triple+tokenizer.bos_token)
        generated_ids=[]
        for _ in range(max_new_tokens):
            output=model(input_ids=torch.tensor([prompt_ids+generated_ids], device=device), labels=None)
            next_id=int(torch.argmax(output.logits[0][-1]))
            if next_id==tokenizer.eos_token_id: break
            generated_ids.append(next_id)

        # Keep each candidate on a single line. The saved file is re-read later
        # with split("\n"), so an embedded newline would shift every later
        # candidate. BLEU is unaffected because str.split() treats all
        # whitespace alike.
        cand=tokenizer.decode(generated_ids).replace("\r", " ").replace("\n", " ")
        generations.append(cand)

        bleu_score=sentence_bleu(
            [ref.split() for ref in refs_seen[index]],
            cand.split(),
            smoothing_function=SmoothingFunction().method4
        )
        scores_seen.append(bleu_score)
print("BLEU Score: {:.2f}".format(100*sum(scores_seen)/max(len(scores_seen), 1)))

os.makedirs('./generation', exist_ok=True)
generation_file=os.path.join('./generation', os.path.splitext(os.path.basename(model_path))[0]+"_Seen")
with open(generation_file, 'w', encoding='utf-8') as f:
    f.write("".join(line+"\n" for line in generations))

Seen Categories
100
200
300
400
500
600
700
800
900
BLEU Score: 50.09


In [13]:
# Greedy decoding for the unseen-category test inputs. Each output is scored
# with sentence-level BLEU (NLTK smoothing method4) and saved one line per
# example. Decoding works on token ids. Decoding one token at a time, appending
# the text and re-encoding it can yield a different token sequence than the
# model produced, and garbles characters split across byte-level BPE tokens.
max_new_tokens=100  # generation length cap
scores_unseen=[]
generations=[]

print("Unseen Categories")

model.eval()  # ensure dropout is off; the checkpoint may have been saved in training mode
with torch.no_grad():  # inference only: don't build autograd graphs
    for index, triple in enumerate(triples_unseen):
        if (index+1)%100==0: print(index+1)

        prompt_ids=tokenizer.encode(triple+tokenizer.bos_token)
        generated_ids=[]
        for _ in range(max_new_tokens):
            output=model(input_ids=torch.tensor([prompt_ids+generated_ids], device=device), labels=None)
            next_id=int(torch.argmax(output.logits[0][-1]))
            if next_id==tokenizer.eos_token_id: break
            generated_ids.append(next_id)

        # Keep each candidate on a single line. The saved file is re-read later
        # with split("\n"), so an embedded newline would shift every later
        # candidate. BLEU is unaffected because str.split() treats all
        # whitespace alike.
        cand=tokenizer.decode(generated_ids).replace("\r", " ").replace("\n", " ")
        generations.append(cand)

        bleu_score=sentence_bleu(
            [ref.split() for ref in refs_unseen[index]],
            cand.split(),
            smoothing_function=SmoothingFunction().method4
        )
        scores_unseen.append(bleu_score)
print("BLEU Score: {:.2f}".format(100*sum(scores_unseen)/max(len(scores_unseen), 1)))

os.makedirs('./generation', exist_ok=True)
generation_file=os.path.join('./generation', os.path.splitext(os.path.basename(model_path))[0]+"_Unseen")
with open(generation_file, 'w', encoding='utf-8') as f:
    f.write("".join(line+"\n" for line in generations))

Unseen Categories
100
200
300
400
500
600
700
800
BLEU Score: 33.25


In [12]:
# Re-score the saved seen-category generations (one candidate per line)
# without re-running the model.
scores_seen=[]

generation_file=os.path.join('./generation', os.path.splitext(os.path.basename(model_path))[0]+"_Seen")
with open(generation_file, 'r', encoding='utf-8') as f:
    cands_seen=f.read().split("\n")
# Every candidate is written newline-terminated, so the split leaves exactly one
# trailing empty string that is not a candidate.
if cands_seen and cands_seen[-1]=="":
    cands_seen.pop()
# Fail loudly instead of silently scoring misaligned candidate/reference pairs.
assert len(cands_seen)==len(refs_seen), (
    "{} candidates for {} reference sets: generation file is misaligned".format(
        len(cands_seen), len(refs_seen)
    )
)

print("Seen Categories")

for refs, cand in zip(refs_seen, cands_seen):
    bleu_score=sentence_bleu(
        [ref.split() for ref in refs],
        cand.split(),
        smoothing_function=SmoothingFunction().method4
    )
    scores_seen.append(bleu_score)
print("BLEU Score: {:.2f}".format(100*sum(scores_seen)/max(len(scores_seen), 1)))

Seen Categories
BLEU Score: 50.09


In [13]:
# Re-score the saved unseen-category generations (one candidate per line)
# without re-running the model.
scores_unseen=[]

generation_file=os.path.join('./generation', os.path.splitext(os.path.basename(model_path))[0]+"_Unseen")
with open(generation_file, 'r', encoding='utf-8') as f:
    cands_unseen=f.read().split("\n")
# Every candidate is written newline-terminated, so the split leaves exactly one
# trailing empty string that is not a candidate.
if cands_unseen and cands_unseen[-1]=="":
    cands_unseen.pop()
# Fail loudly instead of silently scoring misaligned candidate/reference pairs.
assert len(cands_unseen)==len(refs_unseen), (
    "{} candidates for {} reference sets: generation file is misaligned".format(
        len(cands_unseen), len(refs_unseen)
    )
)

print("Unseen Categories")

for refs, cand in zip(refs_unseen, cands_unseen):
    bleu_score=sentence_bleu(
        [ref.split() for ref in refs],
        cand.split(),
        smoothing_function=SmoothingFunction().method4
    )
    scores_unseen.append(bleu_score)
print("BLEU Score: {:.2f}".format(100*sum(scores_unseen)/max(len(scores_unseen), 1)))

Unseen Categories
BLEU Score: 33.25
