In [1]:
from datasets import load_dataset
import random

In [2]:
# Load only the first 1000 examples of the train split.
data = load_dataset("simple_questions_v2", split="train[:1000]")

In [3]:
# Load only the first 1000 examples of the test split (bound to `valid`).
valid = load_dataset("simple_questions_v2", split="test[:1000]")

In [4]:
# Inspect one randomly chosen training example.
# NOTE(review): random.randint is inclusive on both ends, so only indices
# 0..899 are ever drawn although the split slice asks for 1000 rows — confirm
# that the 899 bound is intentional.
i = random.randint(0, 899)
data[i]

{'id': '9',
 'subject_entity': 'www.freebase.com/m/0mgb6cl',
 'relationship': 'www.freebase.com/music/release_track/release',
 'object_entity': 'www.freebase.com/m/0f4zk3j',
 'question': 'What album was tibet released on\n'}

In [5]:
from transformers import AutoTokenizer

In [6]:
# Tokenizer for the "distilbert-base-uncased" checkpoint; `preprocess` reads it as a global.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [7]:
import json

In [8]:
# Load the Freebase mapping (its keys are entity paths such as "/m/04j2svn",
# as looked up later in the notebook). The encoding is given explicitly so the
# file is decoded as UTF-8 regardless of the platform's default locale.
with open("simple_questions_v2_freebase_simple_mapping.json", "r", encoding="utf-8") as f:
    freebase_mappings = json.load(f)

In [63]:
def tokenize_relationship(relationship: str) -> str:
    """Flatten a relationship path into a space-separated phrase.

    Everything up to and including the first "/" is discarded. In what is
    left, every "/", "_" and "-" becomes a single space. A string without any
    "/" yields the empty string.
    """
    _, _, path = relationship.partition("/")
    separators_to_space = str.maketrans({"/": " ", "_": " ", "-": " "})
    return path.translate(separators_to_space)

def map_entity(entity: str, mapping: dict) -> str:
    """Resolve an entity URL through ``mapping``.

    The text before the first "/" of ``entity`` is stripped and the remainder
    is re-prefixed with "/" to form the lookup key (for example
    "www.freebase.com/m/abc" becomes "/m/abc").

    Returns the mapped value, or the literal string "unknown" when the key is
    not present in ``mapping``.
    """
    _, _, tail = entity.partition("/")
    key = "/" + tail
    return mapping[key] if key in mapping else "unknown"

def to_sentence(sbj_entity: str, rel: str, obj_entity: str, freebase_mappings: dict) -> str:
    """Render a (subject, relationship, object) triple as one line of text.

    Both entities are resolved with ``map_entity`` against
    ``freebase_mappings`` and the relationship is flattened with
    ``tokenize_relationship``. The three pieces are joined by single spaces,
    in subject / relationship / object order.
    """
    pieces = (
        map_entity(sbj_entity, freebase_mappings),
        tokenize_relationship(rel),
        map_entity(obj_entity, freebase_mappings),
    )
    return "{} {} {}".format(*pieces)

class Mapper:
    """Two-way lookup between entity strings and integer label ids.

    Ids are assigned by position while iterating ``entities`` (0, 1, 2, ...).
    One extra entry, ``unk_label`` <-> ``unk_id``, is registered afterwards and
    is what both lookups return for values they do not know.

    NOTE(review): the default ``unk_id`` of -1 is kept for backward
    compatibility. A negative id is presumably not usable as a class index by
    a classification head; pass e.g. ``unk_id=len(entities)`` if unknown
    entities must be a trainable class — confirm against the training setup.
    """

    def __init__(self, entities: list, unk_label: str = "UNK", unk_id: int = -1):
        """Build both lookup tables.

        Args:
            entities: Entity strings to index; any iterable works (the
                notebook passes the keys view of a dict).
            unk_label: Label returned by ``to_entity`` for unknown ids.
            unk_id: Id returned by ``to_id`` for unknown entities.

        Raises:
            ValueError: If ``unk_id`` is already assigned to a real entity,
                which would silently drop that entity from the reverse map.
        """
        self.unk_label = unk_label
        self.unk_id = unk_id
        self.entity_to_id_map = {e: i for i, e in enumerate(entities)}
        self.id_to_entity_map = {i: e for e, i in self.entity_to_id_map.items()}
        if self.unk_id in self.id_to_entity_map:
            raise ValueError(
                f"unk_id {self.unk_id} collides with the id of an existing entity"
            )
        # Register the unknown entry in both directions under the same id.
        self.entity_to_id_map[self.unk_label] = self.unk_id
        self.id_to_entity_map[self.unk_id] = self.unk_label

    def to_id(self, entity: str) -> int:
        """Return the id for ``entity``, or ``unk_id`` when it is not indexed."""
        return self.entity_to_id_map.get(entity, self.unk_id)

    def to_entity(self, id: int) -> str:
        """Return the entity string for ``id``, or ``unk_label`` when unknown."""
        return self.id_to_entity_map.get(id, self.unk_label)
    
# Label ids follow the iteration order of the mapping's keys; `preprocess` reads `mapper` as a global.
mapper = Mapper(freebase_mappings.keys())

def preprocess(batch):
    """Convert a batch of raw rows into model inputs plus integer labels.

    Reads the module-level ``tokenizer`` and ``mapper``.

    Args:
        batch: Column-oriented batch providing "subject_entity" and
            "question", each a sequence with one item per row.

    Returns:
        A dict with "input_ids" and "attention_mask" taken from tokenizing the
        questions (with truncation), and "labels" holding the ``mapper`` id of
        each row's subject entity.
    """
    raw_entities = batch["subject_entity"]
    questions = batch["question"]

    encoded = tokenizer(text=questions, truncation=True)
    token_ids = encoded["input_ids"]
    mask = encoded["attention_mask"]

    label_ids = []
    for raw_entity in raw_entities:
        # Strip the part before the first "/" so the key has the "/m/..."
        # form that `mapper` is queried with elsewhere in the notebook.
        entity_key = "/" + "/".join(raw_entity.split("/")[1:])
        label_ids.append(mapper.to_id(entity_key))

    return {
        "input_ids": token_ids,
        "attention_mask": mask,
        "labels": label_ids,
    }

In [64]:
# Sanity check: look up the integer id assigned to one entity key.
mapper.to_id("/m/04j2svn")

1

In [65]:
# Sanity check: the reverse lookup for id 1.
mapper.to_entity(1)

'/m/04j2svn'

In [66]:
# Apply `preprocess` to the training subset in batches of 4 rows; the returned
# "input_ids", "attention_mask" and "labels" are added next to the original columns.
preprocessed_data = data.map(
    preprocess, 
    batched=True, 
    batch_size=4)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [67]:
# Spot-check one processed row.
preprocessed_data[13]

{'id': '13',
 'subject_entity': 'www.freebase.com/m/0hzqmtk',
 'relationship': 'www.freebase.com/film/film/language',
 'object_entity': 'www.freebase.com/m/03k50',
 'question': 'what is the language in which mera shikar was filmed in\n',
 'input_ids': [101,
  2054,
  2003,
  1996,
  2653,
  1999,
  2029,
  21442,
  2050,
  11895,
  6673,
  2001,
  6361,
  1999,
  102],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'labels': 58101}

In [68]:
# Identity map whose only effect comes from `remove_columns`: the raw id/text
# columns are dropped so that only "input_ids", "attention_mask" and "labels"
# remain.
cleaned_data = preprocessed_data.map(
    lambda batch: batch, 
    batched=True, 
    batch_size=4, 
    remove_columns=[
        "id",
        "subject_entity",
        "relationship",
        "object_entity",
        "question"
    ])

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [69]:
# Confirm that only the model-facing columns are left.
cleaned_data[13]

{'input_ids': [101,
  2054,
  2003,
  1996,
  2653,
  1999,
  2029,
  21442,
  2050,
  11895,
  6673,
  2001,
  6361,
  1999,
  102],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'labels': 58101}

In [70]:
import evaluate

# Accuracy metric object; `compute_metrics` reads it as a global.
accuracy = evaluate.load("accuracy")

In [71]:
import numpy as np


def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``accuracy`` metric.

    Args:
        eval_pred: A pair that unpacks into (predictions, labels). The
            predictions are reduced with an argmax over axis 1 — assumes a
            2-D array of per-class scores, one row per example; TODO confirm.

    Returns:
        Whatever ``accuracy.compute`` returns for the argmax class ids
        compared against the labels.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)

In [73]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding

# Sequence-classification model on top of "distilbert-base-uncased".
# num_labels is the number of mapping keys plus one, matching the Mapper,
# which indexes every key and then registers a single extra UNK entry.
# id2label maps integer id -> label string and label2id maps label string ->
# integer id, so they take the Mapper's id->entity and entity->id dicts
# respectively.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(freebase_mappings) + 1,
    id2label=mapper.id_to_entity_map,
    label2id=mapper.entity_to_id_map,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [74]:
# Fine-tuning configuration: evaluate and checkpoint once per epoch, and load
# the best checkpoint back when training ends. Nothing is pushed to the Hub.
training_args = TrainingArguments(
    output_dir="test_sbj_enty_class",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    push_to_hub=False,
)

# NOTE(review): eval_dataset is the same object as train_dataset, so the
# reported validation loss/accuracy are measured on the training examples; the
# `valid` split loaded earlier is not used anywhere in the visible notebook.
# The collator pads each batch with the tokenizer at collation time.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=cleaned_data,
    eval_dataset=cleaned_data,
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)

# Run the training loop; the returned TrainOutput is displayed as the cell result.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,7.425558,0.002
2,No log,6.910871,0.007
3,No log,6.258064,0.007
4,No log,5.734398,0.006
5,No log,5.552683,0.007


TrainOutput(global_step=315, training_loss=6.650751023065476, metrics={'train_runtime': 96.0358, 'train_samples_per_second': 52.064, 'train_steps_per_second': 3.28, 'total_flos': 46494741949632.0, 'train_loss': 6.650751023065476, 'epoch': 5.0})