In [1]:
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import torch
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df_val = pd.read_csv("../data/dev_original.tsv", sep="\t")

In [4]:
df_val.head(1)

Unnamed: 0,marriage is pase.,Not everyone believes in marriage anymore.,support,(marriage; capable of; deceiving)(deceiving; created by; pase)(pase; used for; everyone)(everyone; capable of; believes)
0,marriage is the best for a family unit.,Marriage is a predictor of health and happiness.,support,(marriage; created by; love)(love; causes; hea...


In [3]:
df_train = pd.read_csv("../data/train_original.tsv", sep="\t")

In [4]:
df_train.columns = ["belief", "argument", "label", "explanation"]
df_val.columns = ["belief", "argument", "label", "explanation"]

In [5]:
df = pd.concat([df_train, df_val], axis=0)

In [7]:
train_val, test = train_test_split(df, random_state=1, test_size=0.1)

In [8]:
train, val = train_test_split(train_val, random_state=1, test_size=0.1)

In [9]:
train.to_csv("../data/train.tsv", sep="\t")

In [10]:
val.to_csv("../data/val.tsv", sep="\t")

In [11]:
test.to_csv("../data/test.tsv", sep="\t")

In [14]:
train

Unnamed: 0,belief,argument,label,explanation
142,urbanization creates high crime.,People migrate to cities in order to make money.,counter,(cities; capable of; job)(job; used for; make ...
1066,executives are not overpaid for the work they do.,Executives work quite hard and deserve their pay.,support,(executives; capable of; work quite hard)(work...
2334,Telemarketers has nothing to offer only to rip...,Not all telemarketers are scammers most have e...,counter,(telemarketers; capable of; enough to offer)(e...
94,Embryotic stem cells can save lives,Embryotic stem cells are something that is in ...,support,(embryotic stem cells; capable of; interest of...
1508,The olympics have lost their luster because wi...,Athletes are tested for drugs and can't compet...,counter,(athletes; capable of; tested for drugs)(teste...
...,...,...,...,...
518,zero tolerance could have deep consequences fo...,Zero tolerance implies harsher penalties.,support,(zero tolerance; causes; harsher punishment)(h...
1236,The government is obliged to ban naturopathy.,Naturopathy is experimental and the government...,support,(naturopathy; is a; experimental)(experimental...
1965,Safe spaces are a redundant and unnecessary pr...,Some people have no support or guidance and ne...,counter,(safe spaces; capable of; support or guidance)...
681,Someone with a history of criminal behavior sh...,Repeat offenders have not learned their lesson...,support,(repeat offenders; has property; criminal beha...


In [19]:
class ExplaGraphs(Dataset):
    def __init__(self, model_name, split="train", use_graphs=True):
        print(f"Use graph explanations = {use_graphs}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        df = pd.read_csv(f"../data/{split}.tsv", sep="\t", header=0, index_col=0)
        premises, arguments, self.labels, explanations = df.to_numpy().T
        self.label_converter = {"counter": 0, "support": 1}
        self.label_inverter = {0: "counter", 1: "support"}
        explanations = [self.clean_string(x) for x in explanations]
        if use_graphs == True:
            self.features = [prem + " [SEP] " + arg + " [SEP] " + exp for prem,arg,exp in zip(premises, arguments, explanations)]
        else:
            self.features = [prem + " [SEP] " + arg for prem,arg in zip(premises, arguments)]
            
        encodings = self.tokenizer(self.features, truncation=True, padding=True)
        self.input_ids, self.attention_masks = encodings["input_ids"], encodings["attention_mask"]
        
    def clean_string(self, x):
        x = x.replace(")(", ", ")
        return x.replace("(", "").replace(")","").replace(";", "")
    
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return torch.LongTensor(self.input_ids[idx]), torch.BoolTensor(self.attention_masks[idx]), self.label_converter[self.labels[idx]]

In [20]:
train = ExplaGraphs("bert-base-uncased", split="train")

Use graph explanations = True


In [21]:
x = train.features[1]

In [22]:
x

'executives are not overpaid for the work they do. [SEP] Executives work quite hard and deserve their pay. [SEP] executives capable of work quite hard, work quite hard capable of deserve their pay, deserve their pay synonym of not overpaid'