# Commonsense Path Generator for Connecting Entities
In this notebook, we show how to use our proposed path generator to generate a commonsense relational path
for connecting a pair of entities. You can then use the generator as a plug-in module for providing structured
evidence to any downstream task.

## Preparing generator

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
# assert transformers.__version__ == '2.8.0'
from transformers import GPT2Config, GPT2Tokenizer, GPT2Model

In [2]:
device = 'cuda:0'

In [3]:
# Define the generator model
class Generator(nn.Module):
    def __init__(self, gpt, config, max_len=31):
        super(Generator, self).__init__()
        self.gpt = gpt
        self.config = config
        self.max_len = max_len
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, inputs):
        # input: [batch, seq]
        context_len = inputs.size(1)
        generated = inputs
        next_token = inputs
        past = None
        with torch.no_grad():
            for step in range(self.max_len):
                outputs = self.gpt(next_token, past=past)
                hidden = outputs[0][:, -1]
                past = outputs[1]
                next_token_logits = self.lm_head(hidden)
                next_logits, next_token = next_token_logits.topk(k=1, dim=1)
                generated = torch.cat((generated, next_token), dim=1)
        return generated 

### Download a well-trained generator from this link to your local workspace:
https://drive.google.com/file/d/1dQNxyiP4g4pdFQD6EPMQdzNow9sQevqD/view?usp=sharing

In [4]:
# lm_type = 'gpt2'
model_path = './learning-generator/gpt2/'
config = GPT2Config.from_pretrained(model_path)
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
tokenizer.add_tokens(['<PAD>'])
tokenizer.add_tokens(['<SEP>'])
tokenizer.add_tokens(['<END>'])
gpt = GPT2Model.from_pretrained(model_path)
config.vocab_size = len(tokenizer)
gpt.resize_token_embeddings(len(tokenizer))

pretrain_generator_ckpt = "./learning-generator/ckpt/commonsense-path-generator.ckpt"
# pretrain_generator_ckpt = "/your_path_to_the_download_checkpoint/commonsense-path-generator.ckpt"
generator = Generator(gpt, config)
generator.load_state_dict(torch.load(pretrain_generator_ckpt, map_location='cpu'))

<All keys matched successfully>

In [5]:
generator.to(device)

Generator(
  (gpt): GPT2Model(
    (wte): Embedding(50260, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), 

In [6]:
def prepare_input(head_entity, tail_entity, input_len=16):
    head_entity = head_entity.replace('_', ' ')
    tail_entity = tail_entity.replace('_', ' ')
    input_token = tail_entity + '<SEP>' + head_entity
    input_id = tokenizer.encode(input_token, add_special_tokens=False)[:input_len]
    input_id += [tokenizer.convert_tokens_to_ids('<PAD>')] * (input_len - len(input_id))
    return torch.tensor([input_id], dtype=torch.long)

In [13]:
def connect_entities(head_entity, tail_entity):
    gen_input = prepare_input(head_entity, tail_entity).to(device)
    gen_output = generator(gen_input)
    path = tokenizer.decode(gen_output[0].tolist(), skip_special_tokens=True)
    print('just generated:', path)
    path = ' '.join(path.replace('<PAD>', '').split())
    print('remove <PAD>:', path)
    # return path
    return path[path.index('<SEP>')+6:]  # re

### Usage Example
- Input: A pair of entities you want to connect, expressed in natural language.
- Output: A relational path in the form of (head_entiy, relation1, intermedia_entity1, relation2, ..., tail_entity).

In [15]:
head_entity = 'curiosity'
tail_entity = 'hear_news'
path = connect_entities(head_entity, tail_entity)
print(path)

just generated: hear news <SEP> curiosity <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> causesdesire find information hassubevent read _hasprerequisite hear news <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
remove <PAD>: hear news <SEP> curiosity causesdesire find information hassubevent read _hasprerequisite hear news
hear news <SEP> curiosity causesdesire find information hassubevent read _hasprerequisite hear news


In [9]:
from itertools import permutations

In [10]:
# entities = ['farmland', 'countryside', 'midwest', 'illinois']
entities = ['revolving door', 'security measure', 'bank']
path_list = []
for e1, e2 in permutations(entities, 2):
    path = connect_entities(e1, e2)
    path_list.append(path)

for path in path_list:
    print(path)

security measure <SEP> revolving door <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> atlocation front door isa security measure <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
bank <SEP> revolving door <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> atlocation bank <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
revolving door <SEP> security measure <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> _isa revolving door <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
bank <SEP> security measure <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> _isa check atlocation bank <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <P

In [18]:
# entities = ['geological_feature ', 'fungus', 'cave','grow']
entities = ['overpopulation', 'organism', 'resources', 'ecosystem']
path_list = []
for e1, e2 in permutations(entities, 2):
    path = connect_entities(e1, e2)
    path_list.append(path)

for path in path_list:
    print(path)

overpopulation _causes overpopulation isa organism
overpopulation _causes overpopulation _notdesires person desires resources
overpopulation isa ecosystem
organism _isa species _hascontext overpopulation
organism _isa ecosystem hascontext resources
organism _isa ecosystem
resources _hasprerequisite reproducing causes overpopulation
resources isa organism
resources _hasprerequisite living _usedfor ecosystem
ecosystem _atlocation overpopulation
ecosystem isa organism
ecosystem _atlocation resources
