<a href="https://colab.research.google.com/github/Mozzer2310/text-mining-cwk/blob/sam-experiments/DL_Relation_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Get the Dataset
We need the Hugging Face `datasets` library to download the [DialogRE](https://huggingface.co/datasets/dataset-org/dialog_re) dataset.

In [129]:
! pip install datasets -q

Then we can download the dataset.

In [130]:
from datasets import load_dataset

dataset = load_dataset("dataset-org/dialog_re", download_mode="force_redownload", trust_remote_code=True)


Next, inspect the dataset and its contents.

In [131]:
dataset

DatasetDict({
    train: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 1073
    })
    test: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 357
    })
    validation: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 358
    })
})

In [132]:
dataset['train'][0]

{'dialog': ["Speaker 1: It's been an hour and not one of my classmates has shown up! I tell you, when I actually die some people are gonna get seriously haunted!",
  'Speaker 2: There you go! Someone came!',
  "Speaker 1: Ok, ok! I'm gonna go hide! Oh, this is so exciting, my first mourner!",
  'Speaker 3: Hi, glad you could come.',
  'Speaker 2: Please, come in.',
  "Speaker 4: Hi, you're Chandler Bing, right? I'm Tom Gordon, I was in your class.",
  'Speaker 2: Oh yes, yes... let me... take your coat.',
  "Speaker 4: Thanks... uh... I'm so sorry about Ross, it's...",
  'Speaker 2: At least he died doing what he loved... watching blimps.',
  'Speaker 1: Who is he?',
  'Speaker 2: Some guy, Tom Gordon.',
  "Speaker 1: I don't remember him, but then again I touched so many lives.",
  'Speaker 3: So, did you know Ross well?',
  "Speaker 4: Oh, actually I barely knew him. Yeah, I came because I heard Chandler's news. D'you know if he's seeing anyone?",
  'Speaker 3: Yes, he is. Me.',
  'S

## Preprocess the Data
1. Reformat the dataset so that each sample holds a single entity pair and its relation labels, extracted from a dialogue
2. Preprocess each sample to obtain the tokens and the positional indices of the entities
3. Create a PyTorch dataset for the data

### Reformat the Dataset
Convert the dataset so that each item contains the full dialogue plus a single entity pair (`x`, `y`) and its relation labels.

In [133]:
def reformat_dataset(dataset, add_triggers=True):
    reformatted_dataset = []

    for item in dataset:
        dialog = item['dialog']
        relation_data = item['relation_data']

        # Join the dialog into a single string
        all_dialog = ' '.join(dialog)

        samples = []
        for x, y, r, t in zip(relation_data['x'], relation_data['y'], relation_data['r'], relation_data['t']):
            sample = {'dialog': all_dialog, 'x': x, 'y': y, 'relation': r}
            if add_triggers:
                sample['trigger'] = t
            samples.append(sample)

        reformatted_dataset.extend(samples)

    return reformatted_dataset
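To make the reshaping concrete, here is a minimal, self-contained sketch of the same flattening applied to a hypothetical toy record. The speaker names, labels and empty triggers below are made up; only the field layout (`x`, `y`, `r`, `t` inside `relation_data`) mirrors what `reformat_dataset` reads. Each entity pair becomes one sample, and because `r` holds a *list* of labels per pair, one sample can carry several relations.

```python
def flatten_relations(item):
    """Turn one dialogue record into one sample per (x, y) entity pair."""
    dialog = ' '.join(item['dialog'])
    rel = item['relation_data']
    return [
        {'dialog': dialog, 'x': x, 'y': y, 'relation': r, 'trigger': t}
        for x, y, r, t in zip(rel['x'], rel['y'], rel['r'], rel['t'])
    ]


# Hypothetical toy record with the same layout as a DialogRE item
toy_item = {
    'dialog': ['Speaker 1: Hi Ross!', 'Speaker 2: Hey, Rachel. Good to see you.'],
    'relation_data': {
        'x': ['Speaker 2', 'Speaker 1'],
        'y': ['Ross', 'Speaker 2'],
        'r': [['per:alternate_names'], ['per:friends', 'per:positive_impression']],
        't': [[''], ['', '']],
    },
}

samples = flatten_relations(toy_item)
for s in samples:
    print(s['x'], '->', s['y'], s['relation'])
```

The second pair yields a single sample with two labels, which is why the labels are later encoded as a multi-hot vector rather than a single class index.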

In [134]:
reformatted_dataset = {}
for split in dataset.keys():
    reformatted_dataset[split] = reformat_dataset(dataset[split], add_triggers=False)

In [135]:
print(reformatted_dataset['train'][0])

{'dialog': "Speaker 1: It's been an hour and not one of my classmates has shown up! I tell you, when I actually die some people are gonna get seriously haunted! Speaker 2: There you go! Someone came! Speaker 1: Ok, ok! I'm gonna go hide! Oh, this is so exciting, my first mourner! Speaker 3: Hi, glad you could come. Speaker 2: Please, come in. Speaker 4: Hi, you're Chandler Bing, right? I'm Tom Gordon, I was in your class. Speaker 2: Oh yes, yes... let me... take your coat. Speaker 4: Thanks... uh... I'm so sorry about Ross, it's... Speaker 2: At least he died doing what he loved... watching blimps. Speaker 1: Who is he? Speaker 2: Some guy, Tom Gordon. Speaker 1: I don't remember him, but then again I touched so many lives. Speaker 3: So, did you know Ross well? Speaker 4: Oh, actually I barely knew him. Yeah, I came because I heard Chandler's news. D'you know if he's seeing anyone? Speaker 3: Yes, he is. Me. Speaker 4: What? You... You... Oh! Can I ask you a personal question? Ho-how 

### Preprocess each Sample

In [136]:
import re
import torch
from transformers import AutoTokenizer

In [137]:
MODEL_NAME = "bert-base-uncased"

In [138]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
SUBJ_TOKEN = '[SUBJ]'
OBJ_TOKEN = '[OBJ]'
special_tokens_dict = {'additional_special_tokens': [SUBJ_TOKEN, OBJ_TOKEN]}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
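Registering `[SUBJ]` and `[OBJ]` as special tokens stops the tokenizer from splitting them into pieces, and they receive the next free ids after the 30,522-token `bert-base-uncased` vocabulary (30522 and 30523, which appear in the encoded output of `preprocess_sample`). The model's embedding matrix does not grow on its own, so whichever model consumes these ids needs `resize_token_embeddings(len(tokenizer))` before training. The sketch below is an assumption-laden illustration on a tiny, randomly initialised `BertModel` (the config sizes are arbitrary, chosen so it runs offline); it shows that resizing appends rows while keeping the existing ones.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny, randomly initialised BERT so the example runs offline (sizes are arbitrary)
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertModel(config)
old_weights = model.get_input_embeddings().weight.detach().clone()

# Two extra rows, standing in for [SUBJ] and [OBJ]; with the real tokenizer
# this call would be model.resize_token_embeddings(len(tokenizer))
model.resize_token_embeddings(config.vocab_size + 2)
new_weights = model.get_input_embeddings().weight

print(old_weights.shape, '->', new_weights.shape)
```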

In [139]:
def preprocess_sample(sample):
    """Mark the entities in a sample, build the model input and tokenize it.

    1. Replace occurrences of x and y in the dialogue with [SUBJ] and [OBJ].
    2. Append [SEP] x [SEP] y (and [SEP] trigger, if the sample has triggers).
    3. Tokenize with the tokenizer that already has the special tokens added.
    4. Find the indices of [SUBJ] and [OBJ] in the encoded input.
    """
    sep_token = tokenizer.sep_token  # '[SEP]' for BERT
    dialog = sample['dialog']
    x = sample['x']
    y = sample['y']

    # Replace the subject and object mentions with the marker tokens
    marked_dialog = dialog.replace(x, SUBJ_TOKEN).replace(y, OBJ_TOKEN)

    text = f"{marked_dialog} {sep_token} {x} {sep_token} {y}"
    if 'trigger' in sample:
        trigger = ', '.join(sample['trigger'])
        text += f" {sep_token} {trigger}"

    tokens = tokenizer(text, padding="max_length", truncation=True, return_tensors="pt")

    # Search the encoded ids rather than tokenizer.tokenize(text): the ids include
    # the leading [CLS] (tokenize() does not, which shifts every index by one)
    # and reflect truncation to max_length
    input_ids = tokens['input_ids'][0]
    subj_id = tokenizer.convert_tokens_to_ids(SUBJ_TOKEN)
    obj_id = tokenizer.convert_tokens_to_ids(OBJ_TOKEN)
    e1_positions = (input_ids == subj_id).nonzero(as_tuple=True)[0]
    e2_positions = (input_ids == obj_id).nonzero(as_tuple=True)[0]

    # Positions are 1-D LongTensors whose length varies between samples
    return tokens, (e1_positions, e2_positions)
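`str.replace` matches raw substrings, so marking `Speaker 1` would also rewrite the start of a `Speaker 10` label, and an entity such as `Ross` would be marked inside a longer word such as `Rossi`. A hedged alternative (not what the cell above does) is a regular expression that only matches the entity when it is not flanked by word characters. `mark_entity` is a hypothetical helper and the dialogue string is made up for illustration.

```python
import re


def mark_entity(text, entity, marker):
    """Replace whole-word occurrences of `entity` in `text` with `marker`."""
    # (?<!\w) / (?!\w) behave like word boundaries but also work when the
    # entity itself starts or ends with punctuation
    pattern = r'(?<!\w)' + re.escape(entity) + r'(?!\w)'
    # A function replacement avoids any backslash handling in `marker`
    return re.sub(pattern, lambda _match: marker, text)


dialog = "Speaker 1: Is Ross here? Speaker 10: No, only Rossi."
print(dialog.replace("Speaker 1", "[SUBJ]"))      # also hits "Speaker 10"
print(mark_entity(dialog, "Speaker 1", "[SUBJ]"))
print(mark_entity(dialog, "Ross", "[OBJ]"))
```

The possessive `Ross's` would still be matched, since `'` is not a word character, which is usually the desired behaviour for marking mentions.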

In [140]:
preprocess_sample(reformatted_dataset['train'][0])

({'input_ids': tensor([[  101,  5882,  1015,  1024,  2009,  1005,  1055,  2042,  2019,  3178,
           1998,  2025,  2028,  1997,  2026, 19846,  2038,  3491,  2039,   999,
           1045,  2425,  2017,  1010,  2043,  1045,  2941,  3280,  2070,  2111,
           2024,  6069,  2131,  5667, 11171,   999, 30522,  1024,  2045,  2017,
           2175,   999,  2619,  2234,   999,  5882,  1015,  1024,  7929,  1010,
           7929,   999,  1045,  1005,  1049,  6069,  2175,  5342,   999,  2821,
           1010,  2023,  2003,  2061, 10990,  1010,  2026,  2034,  9587, 21737,
           2099,   999,  5882,  1017,  1024,  7632,  1010,  5580,  2017,  2071,
           2272,  1012, 30522,  1024,  3531,  1010,  2272,  1999,  1012,  5882,
           1018,  1024,  7632,  1010,  2017,  1005,  2128, 30523,  1010,  2157,
           1029,  1045,  1005,  1049,  3419,  5146,  1010,  1045,  2001,  1999,
           2115,  2465,  1012, 30522,  1024,  2821,  2748,  1010,  2748,  1012,
           1012,  1012,  2
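With `truncation=True` on a single string, tokens are cut from the end, which is exactly where `[SEP] x [SEP] y` (and the trigger) are appended. For a dialogue longer than BERT's 512-token limit the entity names are therefore lost, and some `[SUBJ]`/`[OBJ]` markers may be cut as well. A hedged sketch of a token budget that truncates only the dialogue part, assuming the layout `[CLS] dialogue [SEP] tail [SEP]` where the tail is the tokens of `x [SEP] y`; `fit_dialog` is a hypothetical helper that works on already-tokenized lists.

```python
def fit_dialog(dialog_tokens, tail_tokens, max_length=512):
    """Truncate only the dialogue so the appended entity tail always survives.

    Assumed layout: [CLS] dialogue [SEP] tail [SEP].
    """
    num_special = 3  # [CLS], the [SEP] after the dialogue, and the final [SEP]
    budget = max_length - num_special - len(tail_tokens)
    if budget < 0:
        raise ValueError("entity tail alone exceeds max_length")
    return ['[CLS]'] + dialog_tokens[:budget] + ['[SEP]'] + tail_tokens + ['[SEP]']


dialog_tokens = [f"w{i}" for i in range(600)]  # stand-in for a long dialogue
tail_tokens = ['ross', '[SEP]', 'rachel']
sequence = fit_dialog(dialog_tokens, tail_tokens)
print(len(sequence), sequence[-4:])
```

With a Hugging Face tokenizer, a similar effect is available by passing the marked dialogue and the tail as a text pair, for example `tokenizer(marked_dialog, tail, truncation="only_first", padding="max_length")`, which inserts the separating `[SEP]` itself and truncates only the first text.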

### Convert the Relation Label to Multi-hot Vector

A helper function that collects every relation label in the training split:

In [141]:
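def get_labels(dataset):
    """Return the sorted list of relation labels found in the training split."""
    all_dataset_labels = set()
    for datapoint in dataset['train']:
        # relation_data['r'] holds one list of labels per entity pair
        for labels in datapoint['relation_data']['r']:
            all_dataset_labels.update(labels)
    # Sort so that the label order (and hence each multi-hot index) is stable across runs
    return sorted(all_dataset_labels)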
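If the label list comes straight from a `set`, its order can change between Python processes (string hashing is randomised), so the same relation could map to a different multi-hot index from one run to the next; sorting gives a fixed order. A small sketch with a few illustrative label strings, using the same `MultiLabelBinarizer` class as this notebook:

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Illustrative label strings; the real list comes from get_labels(dataset)
labels = {"unanswerable", "per:friends", "per:girl/boyfriend", "per:alumni"}
classes = sorted(labels)  # deterministic order across runs

mlb = MultiLabelBinarizer(classes=classes)
mlb.fit([classes])  # fit once; the given class order is kept

# One entity pair with two relations -> two 1s in the multi-hot vector
vector = mlb.transform([["per:friends", "per:alumni"]])[0]
print(classes)
print(vector.tolist())  # [1, 1, 0, 0]
```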

### Create a PyTorch Dataset

In [142]:
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import MultiLabelBinarizer

In [143]:
# Initialize the binarizer
mlb = MultiLabelBinarizer(classes=get_labels(dataset))

In [144]:
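# Create PyTorch Dataset
class DialogREDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        # Fit the binarizer once instead of on every __getitem__ call; because it was
        # created with a fixed `classes` list, the label order is the same for every split
        mlb.fit([item['relation'] for item in data])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        tokens, entity_positions = preprocess_sample(item)
        # Multi-hot label vector as float, the target type expected by
        # multi-label losses such as torch.nn.BCEWithLogitsLoss
        label = torch.tensor(mlb.transform([item['relation']])[0], dtype=torch.float)
        return (
            tokens['input_ids'].squeeze(0),
            tokens['attention_mask'].squeeze(0),
            entity_positions,
            label,
        )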

In [145]:
train_dataset = DialogREDataset(reformatted_dataset['train'], tokenizer)
test_dataset = DialogREDataset(reformatted_dataset['test'], tokenizer)
validation_dataset = DialogREDataset(reformatted_dataset['validation'], tokenizer)
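Each item returned by `DialogREDataset` includes entity-position tensors whose length depends on how often each entity is mentioned, so PyTorch's default collate function cannot stack them into a batch (it requires equal-sized tensors). One option, sketched here as an assumption rather than the notebook's own approach, is a custom `collate_fn` that pads the positions with `-1`. `collate_dialogre` is a hypothetical name, and the toy items stand in for real dataset items.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_dialogre(batch):
    """Batch (input_ids, attention_mask, (e1_pos, e2_pos), label) tuples.

    Entity positions are padded with -1. Mask these out (e.g. e1 >= 0) before
    indexing, because -1 would otherwise select the last token.
    """
    input_ids, attention_masks, positions, labels = zip(*batch)
    e1 = pad_sequence([p[0] for p in positions], batch_first=True, padding_value=-1)
    e2 = pad_sequence([p[1] for p in positions], batch_first=True, padding_value=-1)
    return torch.stack(input_ids), torch.stack(attention_masks), (e1, e2), torch.stack(labels)


# Toy batch: sequence length 8, 3 relation classes, different numbers of mentions
item_a = (torch.arange(8), torch.ones(8, dtype=torch.long),
          (torch.tensor([1, 5]), torch.tensor([3])), torch.tensor([1.0, 0.0, 0.0]))
item_b = (torch.arange(8), torch.ones(8, dtype=torch.long),
          (torch.tensor([2]), torch.tensor([4, 6, 7])), torch.tensor([0.0, 1.0, 1.0]))

ids, masks, (e1, e2), labels = collate_dialogre([item_a, item_b])
print(e1.tolist())  # [[1, 5], [2, -1]]
print(e2.tolist())  # [[3, -1, -1], [4, 6, 7]]
```

It would then be passed to a loader, for example `DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_dialogre)`.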

In [146]:
#TODO
#   - Replace entities in dialog with [SUBJ] [OBJ]
#       > [CLS]d*[SEP]e1[SEP]e2[SEP] where d* is as above
#   - Get positional indices of entity1 (SUBJ token) and entity2 (OBJ token) in dialog