In [1]:
import sys
import logging
import json
import types

import torch
import torch.nn as nn
from transformers import GPT2Tokenizer, GPT2LMHeadModel

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

In [2]:
# Root Path
root_path='/root/research/Graph-To-Text/'
# Model Path
model_path='./model/c-prefixtuned_preseqlen10_batch5_epoch5of5_lr7e-05.pt'
# Device: (Single) GPU
device=torch.device('cuda:3')
# Beam Search
num_beams=6

# Import Modules
sys.path.append(root_path+'transformers/')
from generation_utils import generate, beam_search

# Debugger
logging.disable(logging.CRITICAL)

In [3]:
with open(root_path+'dataset/webnlg/train.json', 'r') as f:
    dict_train=json.load(f)
    f.close()
    
categories_seen=[]

for index, data in enumerate(dict_train['entries']):
    categories_seen.append(data[str(index+1)]['category'])
    
categories_seen=list(set(categories_seen))
print(len(categories_seen), "Categories in Train Set")
print(categories_seen)

10 Categories in Train Set
['Building', 'University', 'ComicsCharacter', 'WrittenWork', 'Monument', 'Astronaut', 'City', 'SportsTeam', 'Food', 'Airport']


In [4]:
with open(root_path+'dataset/webnlg/test.json', 'r') as f:
    dict_test=json.load(f)
    f.close()
    
categories_unseen=[]

controls_seen=[]
triples_seen=[]
refs_seen=[]

controls_unseen=[]
triples_unseen=[]
refs_unseen=[]

for index, data in enumerate(dict_test['entries']):
    data=data[str(index+1)]
    
    triple_proc=""
    for triple in data['modifiedtripleset']:
        subj, prop, obj=triple['subject'], triple['property'], triple['object']
        triple_proc+="| {} : {} : {} ".format(subj, prop, obj)
        
    texts=data['lexicalisations']
    
    if data['category'] not in categories_seen:
        categories_unseen.append(data['category'])
        
        controls_unseen.append(data['category'])
        triples_unseen.append(triple_proc)
        refs_unseen.append([text['lex'] for text in texts])
        continue
        
    controls_seen.append(data['category'])
    triples_seen.append(triple_proc)
    refs_seen.append([text['lex'] for text in texts])
    
categories_unseen=list(set(categories_unseen))
print(len(categories_unseen), "Unseen Categories")
print(categories_unseen)
print("=====")

print(len(triples_seen), "Seen Data")
print(len(triples_unseen), "Unseen Data")

5 Unseen Categories
['CelestialBody', 'Politician', 'Artist', 'Athlete', 'MeanOfTransportation']
=====
971 Seen Data
891 Unseen Data


In [5]:
class ControlPrefixesGPT2(nn.Module):
    """
    Control-Prefixes on GPT2
    """
    def __init__(self, config, categories, ctrlseqlen=2, preseqlen=5, hidden_dim=512):
        super().__init__()
        
        # Config of Pre-Trained LM
        self.config=config
        
        # Control-Prefixes: Attributes
        self.categories=categories
        print(self.categories)
        # Control Prefix Length
        self.ctrlseqlen=ctrlseqlen
        # General Prefix Length
        self.preseqlen=preseqlen
        
        # Embedding
        # Control
        self.wte_ctrl=nn.Embedding(ctrlseqlen*len(categories), self.config.n_embd)
        # General
        self.input_tokens=torch.arange(preseqlen).long()
        self.wte=nn.Embedding(len(categories)+preseqlen, self.config.n_embd)
        
        # Reparam
        self.control_trans=nn.Sequential(
            nn.Linear(self.config.n_embd, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, self.config.n_layer*2*self.config.n_embd)
        )
        #self.dropout=nn.Dropout(p=0.0)
        
        # Func: Get Prompt
        self.get_prompt=self.get_prompt_fn
        
    def get_prompt_fn(self, bsz=None, categories=None):
        # Control Prefix
        controls=[list(range(self.ctrlseqlen*self.categories.index(c), self.ctrlseqlen*(self.categories.index(c)+1))) for c in categories]
        controls=torch.tensor(controls).to(device)
        controls=self.wte_ctrl(controls)
        
        # General Prefix
        input_tokens=self.input_tokens.unsqueeze(0).expand(bsz, -1).to(device)
        input_tokens=self.wte(input_tokens)
        
        # [Control, General]
        input_tokens=torch.cat((controls, input_tokens), dim=1)
        past_key_values=self.control_trans(input_tokens)
        bsz, seqlen, _=past_key_values.shape
        past_key_values=past_key_values.view(bsz, seqlen, 2*self.config.n_layer, self.config.n_head, int(self.config.n_embd/self.config.n_head))
        #past_key_values=self.dropout(past_key_values)
        past_key_values=past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        
        return past_key_values
        
    def forward(self, input_ids, labels, categories):
        bsz=input_ids.shape[0]
        past_key_values_prompt=self.get_prompt(bsz=bsz, categories=categories)
        
        return past_key_values_prompt

In [6]:
tokenizer=GPT2Tokenizer.from_pretrained('gpt2-large')
pretrained=GPT2LMHeadModel.from_pretrained('gpt2-large').to(device)
model=torch.load(root_path+model_path).to(device)

# Add PAD Token: [PAD]
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
print("pad_token_id:", tokenizer.pad_token_id)

# Bind Customized Generation Function
pretrained.generate=types.MethodType(generate, pretrained)
pretrained.beam_search=types.MethodType(beam_search, pretrained)

pad_token_id: 50257


In [7]:
scores_seen=[]
generations=""

print("Seen Categories")

for index, triple in enumerate(triples_seen):
    if (index+1)%100==0: print(index+1)
        
    prefix=model.get_prompt(bsz=num_beams, categories=[controls_seen[index]]*num_beams)
    input_ids=tokenizer.encode(triple+tokenizer.bos_token)
    output=pretrained.generate(torch.tensor([input_ids]).to(device), max_length=500, num_beams=num_beams, early_stoping=True, prefix=prefix)
    cand=tokenizer.decode(output[0], skip_special_tokens=True)[len(triple):]
    generations+=cand+"\n"
    
    bleu_score=sentence_bleu(
        [ref.split() for ref in refs_seen[index]],
        cand.split(),
        smoothing_function=SmoothingFunction().method4
    )
    scores_seen.append(bleu_score)
print("BLEU Score: {:.2f}".format(100*sum(scores_seen)/len(scores_seen)))

with open(root_path+'generation/'+model_path.split("/")[-1][:-3]+'_Seen', 'w') as f:
    f.write(generations)
    f.close()

Seen Categories
100
200
300
400
500
600
700
800
900
BLEU Score: 57.72


In [8]:
controls_mapping={
    'Artist': 'WrittenWork',
    'Athlete': 'SportsTeam',
    'CelestialBody': 'Astronaut',
    'MeanOfTransportation': 'Airport',
    'Politician': 'Monument'
}
controls_unseen=[controls_mapping[c] for c in controls_unseen]

In [9]:
scores_unseen=[]
generations=""

print("Unseen Categories")

for index, triple in enumerate(triples_unseen):
    if (index+1)%100==0: print(index+1)
        
    prefix=model.get_prompt(bsz=num_beams, categories=[controls_unseen[index]]*num_beams)
    input_ids=tokenizer.encode(triple+tokenizer.bos_token)
    output=pretrained.generate(torch.tensor([input_ids]).to(device), max_length=500, num_beams=num_beams, early_stoping=True, prefix=prefix)
    cand=tokenizer.decode(output[0], skip_special_tokens=True)[len(triple):]
    generations+=cand+"\n"
    
    bleu_score=sentence_bleu(
        [ref.split() for ref in refs_unseen[index]],
        cand.split(),
        smoothing_function=SmoothingFunction().method4
    )
    scores_unseen.append(bleu_score)
print("BLEU Score: {:.2f}".format(100*sum(scores_unseen)/len(scores_unseen)))

with open(root_path+'generation/'+model_path.split("/")[-1][:-3]+'_Unseen', 'w') as f:
    f.write(generations)
    f.close()

Unseen Categories
100
200
300
400
500
600
700
800
BLEU Score: 30.97


In [10]:
scores_seen=[]

with open(root_path+'generation/'+model_path.split("/")[-1][:-3]+'_Seen', 'r') as f:
    cands_seen=f.read().split("\n")
    f.close()

print("Seen Categories")

for index, refs in enumerate(refs_seen):
    bleu_score=sentence_bleu(
        [ref.split() for ref in refs],
        cands_seen[index].split(),
        smoothing_function=SmoothingFunction().method4
    )
    scores_seen.append(bleu_score)
print("BLEU Score: {:.2f}".format(100*sum(scores_seen)/len(scores_seen)))

Seen Categories
BLEU Score: 57.72


In [11]:
scores_unseen=[]

with open(root_path+'generation/'+model_path.split("/")[-1][:-3]+'_Unseen', 'r') as f:
    cands_unseen=f.read().split("\n")
    f.close()

print("Unseen Categories")

for index, refs in enumerate(refs_unseen):
    bleu_score=sentence_bleu(
        [ref.split() for ref in refs],
        cands_unseen[index].split(),
        smoothing_function=SmoothingFunction().method4
    )
    scores_unseen.append(bleu_score)
print("BLEU Score: {:.2f}".format(100*sum(scores_unseen)/len(scores_unseen)))

Unseen Categories
BLEU Score: 30.97
