In [1]:
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import torch
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the dev split from its tab-separated file.
# NOTE(review): read_csv treats the first line as column names by default; if
# dev.tsv has no header row, its first example is being consumed as the header
# (pass header=None in that case) -- confirm against the raw file.
df_val = pd.read_csv(filepath_or_buffer="../data/dev.tsv", sep="\t")

In [3]:
# Number of rows loaded for the dev split.
len(df_val.index)

397

In [4]:
# Load the training split from its tab-separated file.
# NOTE(review): read_csv treats the first line as column names by default; if
# train.tsv has no header row, its first example is being consumed as the header
# (pass header=None in that case) -- confirm against the raw file.
df_train = pd.read_csv(filepath_or_buffer="../data/train.tsv", sep="\t")

In [5]:
# Number of rows loaded for the training split.
len(df_train.index)

2367

In [6]:
# Pool both splits so they can be re-partitioned below.
# ignore_index=True renumbers the rows 0..n-1; without it the dev rows keep
# their own 0-based labels and collide with the training rows' labels.
df = pd.concat([df_train, df_val], ignore_index=True)

In [7]:
# Row count of the pooled frame (train + dev).
len(df.index)

2764

In [9]:
# Hold out 10% of the pooled rows as the test set; the fixed seed keeps the
# partition repeatable across runs.
train_val, test = train_test_split(df, test_size=0.1, random_state=1)

In [10]:
# Carve a validation set out of the remaining rows: 10% of train_val,
# using the same fixed seed.
train, val = train_test_split(train_val, test_size=0.1, random_state=1)

In [15]:
# NOTE: this is the same path df_train was loaded from, so the raw training
# file is replaced by the re-split subset; re-running the notebook afterwards
# starts from the already-split file.
# index=False keeps the row index out of the file: ExplaGraphs unpacks exactly
# four columns, and an extra leading index column would break that unpack.
train.to_csv("../data/train.tsv", sep="\t", index=False)

In [16]:
# index=False keeps the row index out of the file: ExplaGraphs unpacks exactly
# four columns, and an extra leading index column would break that unpack.
val.to_csv("../data/val.tsv", sep="\t", index=False)

In [17]:
# index=False keeps the row index out of the file: ExplaGraphs unpacks exactly
# four columns, and an extra leading index column would break that unpack.
test.to_csv("../data/test.tsv", sep="\t", index=False)

In [None]:
class ExplaGraphs(Dataset):
    """Pre-tokenized text pairs (optionally with an explanation graph) read from a TSV split.

    The whole file is tokenized once in ``__init__``; indexing returns
    ``(input_ids, attention_mask, label_id)`` for one row, where ``label_id``
    is looked up in ``label_converter`` ("counter" -> 0, "support" -> 1).
    """

    def __init__(self, model_name, split="train", use_graphs=False, data_dir="../data"):
        """Load ``{data_dir}/{split}.tsv`` and tokenize every row.

        Args:
            model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
            split: File stem of the TSV to read (e.g. "train", "val", "test").
            use_graphs: When truthy, the cleaned explanation string is appended
                to each feature as a third ``[SEP]``-separated segment.
            data_dir: Directory that holds the TSV files. Defaults to the
                location this notebook has always used.

        Raises:
            ValueError: If the file does not have exactly four data columns
                (after unnamed index columns are dropped).
        """
        print(f"Use graph explanations = {use_graphs}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(review): the first line of the file is consumed as the header
        # row (read_csv default) -- confirm every split file actually has one.
        df = pd.read_csv(f"{data_dir}/{split}.tsv", sep="\t")
        # Files written by DataFrame.to_csv with its default index=True carry a
        # leading nameless index column, which pandas reads back as
        # "Unnamed: N". Drop such columns so the four-way unpack below sees
        # only the data columns.
        keep = [not str(col).startswith("Unnamed:") for col in df.columns]
        df = df.loc[:, keep]
        # Columns are taken positionally: premise, argument, label, explanation.
        premises, arguments, self.labels, explanations = df.to_numpy().T
        self.label_converter = {"counter": 0, "support": 1}
        # Derived from label_converter so the two mappings cannot drift apart.
        self.label_inverter = {idx: name for name, idx in self.label_converter.items()}

        if use_graphs:
            # Explanations are only cleaned when they are used, so a missing
            # graph cannot crash a run that ignores graphs anyway.
            cleaned = [self.clean_string(exp) for exp in explanations]
            self.features = [
                " [SEP] ".join((prem, arg, exp))
                for prem, arg, exp in zip(premises, arguments, cleaned)
            ]
        else:
            self.features = [
                " [SEP] ".join((prem, arg))
                for prem, arg in zip(premises, arguments)
            ]

        # padding=True pads every row to the longest sequence in this split.
        encodings = self.tokenizer(self.features, truncation=True, padding=True)
        self.input_ids = encodings["input_ids"]
        self.attention_masks = encodings["attention_mask"]

    def clean_string(self, x):
        """Flatten a graph string: ")(" becomes ", ", then "(", ")" and ";" are removed."""
        x = x.replace(")(", ", ")
        return x.replace("(", "").replace(")", "").replace(";", "")

    def __len__(self):
        """Return the number of tokenized rows."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return ``(LongTensor input_ids, BoolTensor attention_mask, int label_id)`` for row ``idx``."""
        input_ids = torch.LongTensor(self.input_ids[idx])
        attention_mask = torch.BoolTensor(self.attention_masks[idx])
        label_id = self.label_converter[self.labels[idx]]
        return input_ids, attention_mask, label_id

In [None]:
# Build a tokenized dataset from the dev file with the BERT uncased tokenizer.
# NOTE(review): the variable is called `train` but loads split="dev", and it
# replaces the DataFrame named `train` from the split cells -- confirm intent.
train = ExplaGraphs(model_name="bert-base-uncased", split="dev")

In [None]:
# Grab the second feature string (index 1) for a quick visual check.
x = train.features[1]

In [None]:
# Display the selected feature string.
x