# Commonsense Path Generator for Connecting Entities
In this notebook, we show how to use our proposed path generator to generate a commonsense relational path
connecting a pair of entities. The generator can then serve as a plug-in module that provides structured
evidence to any downstream task.

## Preparing the generator

In [2]:
import inspect

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
# assert transformers.__version__ == '2.8.0'
from transformers import GPT2Config, GPT2Tokenizer, GPT2Model

In [3]:
# Define the generator model
class Generator(nn.Module):
    """Greedy GPT-2 decoder that extends an entity-pair prompt into a relational path.

    Wraps a headless GPT-2 backbone with its own (untied) language-model head, so a
    checkpoint saved from this module (keys under ``gpt.*`` and ``lm_head.*``) can be
    restored with ``load_state_dict``.

    Args:
        gpt: transformer backbone; called as ``gpt(input_ids, <cache kwarg>=...)`` and
            expected to return ``(hidden_states, cache, ...)``.
        config: object exposing ``n_embd`` (hidden size) and ``vocab_size``.
        max_len: number of tokens generated after the prompt.
    """

    def __init__(self, gpt, config, max_len=31):
        super(Generator, self).__init__()
        self.gpt = gpt
        self.config = config
        self.max_len = max_len
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def _cache_kwarg(self):
        """Return the keyword the backbone's ``forward`` uses for its key/value cache.

        Older transformers releases (e.g. 2.8.0) call it ``past``; newer ones call it
        ``past_key_values``. Passing the wrong one raises a TypeError.
        """
        params = inspect.signature(self.gpt.forward).parameters
        return 'past_key_values' if 'past_key_values' in params else 'past'

    def forward(self, inputs):
        """Greedily generate ``self.max_len`` tokens after ``inputs`` (no gradients).

        Args:
            inputs: LongTensor of token ids, shape [batch, seq].

        Returns:
            LongTensor of shape [batch, seq + max_len]: the prompt followed by the
            generated ids. Decoding always runs for ``max_len`` steps (no early stop).
        """
        cache_kwarg = self._cache_kwarg()
        generated = inputs
        # Step 0 feeds the whole prompt; later steps feed only the newest token and
        # rely on the cached keys/values for the prefix.
        next_token = inputs
        past = None
        with torch.no_grad():
            for _ in range(self.max_len):
                outputs = self.gpt(next_token, **{cache_kwarg: past})
                hidden = outputs[0][:, -1]  # hidden state at the last position
                past = outputs[1]           # cache reused on the next step
                next_token_logits = self.lm_head(hidden)
                # topk(k=1) indices keep a [batch, 1] shape ready for concatenation.
                next_token = next_token_logits.topk(k=1, dim=1)[1]
                generated = torch.cat((generated, next_token), dim=1)
        return generated

### Download the pretrained generator checkpoint from this link to your local workspace, then set `pretrain_generator_ckpt` below to its path:
https://drive.google.com/file/d/1dQNxyiP4g4pdFQD6EPMQdzNow9sQevqD/view?usp=sharing

In [4]:
lm_type = 'gpt2'
config = GPT2Config.from_pretrained(lm_type)
tokenizer = GPT2Tokenizer.from_pretrained(lm_type)
# Register the generator's marker tokens one at a time and in this order; the
# fine-tuned checkpoint presumably expects exactly these token ids.
tokenizer.add_tokens(['<PAD>'])
tokenizer.add_tokens(['<SEP>'])
tokenizer.add_tokens(['<END>'])
gpt = GPT2Model.from_pretrained(lm_type)
# Size the embeddings and the Generator's LM head (built from config.vocab_size)
# to the enlarged vocabulary so the checkpoint's weight shapes line up.
config.vocab_size = len(tokenizer)
gpt.resize_token_embeddings(len(tokenizer))
# Point this at the checkpoint downloaded above.
pretrain_generator_ckpt = "/your_path_to_the_download_checkpoint/commonsense-path-generator.ckpt"
generator = Generator(gpt, config)
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
generator.load_state_dict(torch.load(pretrain_generator_ckpt, map_location='cpu'))
# Inference only: put the whole module (not just the backbone) in eval mode so
# dropout is disabled and greedy decoding is deterministic.
generator.eval()

OSError: Can't load config for 'gpt2'. Make sure that:

- 'gpt2' is a correct model identifier listed on 'https://huggingface.co/models'

- or 'gpt2' is the correct path to a directory containing a config.json file



In [4]:
def prepare_input(head_entity, tail_entity, input_len=16):
    """Encode the entity pair as the fixed-length generator prompt ``tail<SEP>head``.

    Underscores in entity names are replaced by spaces. The encoded ids are cut to
    ``input_len`` tokens and right-padded with ``<PAD>`` up to exactly ``input_len``.

    Returns:
        LongTensor of shape [1, input_len].
    """
    prompt = '<SEP>'.join(name.replace('_', ' ') for name in (tail_entity, head_entity))
    token_ids = tokenizer.encode(prompt, add_special_tokens=False)[:input_len]
    pad_id = tokenizer.convert_tokens_to_ids('<PAD>')
    padded_ids = token_ids + [pad_id] * (input_len - len(token_ids))
    return torch.tensor([padded_ids], dtype=torch.long)

In [5]:
def connect_entities(head_entity, tail_entity):
    """Generate a commonsense relational path from ``head_entity`` to ``tail_entity``.

    Args:
        head_entity: start entity; underscores are treated as spaces (e.g. 'hear_news').
        tail_entity: end entity, same format.

    Returns:
        The decoded text after the ``<SEP>`` marker, with ``<PAD>`` tokens removed
        and whitespace collapsed to single spaces.

    Raises:
        ValueError: if ``<SEP>`` is absent from the decoded sequence, e.g. because the
            prompt was truncated by ``prepare_input``'s ``input_len``.
    """
    sep = '<SEP>'
    gen_input = prepare_input(head_entity, tail_entity)
    gen_output = generator(gen_input)
    path = tokenizer.decode(gen_output[0].tolist(), skip_special_tokens=True)
    # The markers were registered with add_tokens rather than as special tokens,
    # so <PAD> is stripped explicitly here.
    path = ' '.join(path.replace('<PAD>', '').split())
    sep_pos = path.find(sep)
    if sep_pos < 0:
        raise ValueError(
            "'{}' not found in generated text {!r}; the prompt may have been "
            "truncated by prepare_input(input_len=...)".format(sep, path))
    # Skip the marker by its own length and trim, rather than assuming exactly
    # one space follows it (a fixed +6 would eat a character otherwise).
    return path[sep_pos + len(sep):].strip()

### Usage Example
- Input: A pair of entities you want to connect, expressed in natural language (underscores are treated as spaces, e.g. `hear_news`).
- Output: A relational path in the form of (head_entity, relation1, intermediate_entity1, relation2, ..., tail_entity).

In [6]:
# Example: connect an abstract concept to an event (underscores stand for spaces).
head_entity, tail_entity = 'curiosity', 'hear_news'
path = connect_entities(head_entity=head_entity, tail_entity=tail_entity)
print(path)

curiosity causesdesire find information hassubevent read _hasprerequisite hear news
