<a href="https://colab.research.google.com/github/Mozzer2310/text-mining-cwk/blob/sam-experiments/DL_Relation_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Get the Dataset
We need to install the `datasets` module to download the [DialogRE](https://huggingface.co/datasets/dataset-org/dialog_re) dataset.

In [1]:
! pip install datasets -q

Then we can download the dataset.

In [2]:
from datasets import load_dataset

dataset = load_dataset("dataset-org/dialog_re", download_mode="force_redownload", trust_remote_code=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


dialog_re.py:   0%|          | 0.00/4.83k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/7.45k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.34M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/752k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/726k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1073 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/357 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/358 [00:00<?, ? examples/s]

Then view the Dataset and its contents.

In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 1073
    })
    test: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 357
    })
    validation: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 358
    })
})

In [4]:
dataset['train'][0]

{'dialog': ["Speaker 1: It's been an hour and not one of my classmates has shown up! I tell you, when I actually die some people are gonna get seriously haunted!",
  'Speaker 2: There you go! Someone came!',
  "Speaker 1: Ok, ok! I'm gonna go hide! Oh, this is so exciting, my first mourner!",
  'Speaker 3: Hi, glad you could come.',
  'Speaker 2: Please, come in.',
  "Speaker 4: Hi, you're Chandler Bing, right? I'm Tom Gordon, I was in your class.",
  'Speaker 2: Oh yes, yes... let me... take your coat.',
  "Speaker 4: Thanks... uh... I'm so sorry about Ross, it's...",
  'Speaker 2: At least he died doing what he loved... watching blimps.',
  'Speaker 1: Who is he?',
  'Speaker 2: Some guy, Tom Gordon.',
  "Speaker 1: I don't remember him, but then again I touched so many lives.",
  'Speaker 3: So, did you know Ross well?',
  "Speaker 4: Oh, actually I barely knew him. Yeah, I came because I heard Chandler's news. D'you know if he's seeing anyone?",
  'Speaker 3: Yes, he is. Me.',
  'S

## Preprocess the Data
1. Reformat the dataset so each sample (relation) is extracted from each item in the dataset
2. Proprocess each sample getting the tokens and the positional indices of the entities
3. Create a PyTorch dataset for the data

### Reformat the Dataset
Convert the dataset so that each item contains a singular relation.

In [5]:
def reformat_dataset(dataset, add_triggers=True):
    reformatted_dataset = []

    for item in dataset:
        dialog = item['dialog']
        relation_data = item['relation_data']

        # Join the dialog into a single string
        all_dialog = ' '.join(dialog)

        samples = []
        for x, y, r, t in zip(relation_data['x'], relation_data['y'], relation_data['r'], relation_data['t']):
            sample = {'dialog': all_dialog, 'x': x, 'y': y, 'relation': r}
            if add_triggers:
                sample['trigger'] = t
            samples.append(sample)

        reformatted_dataset.extend(samples)

    return reformatted_dataset

In [6]:
reformatted_dataset = {}
for split in dataset.keys():
    reformatted_dataset[split] = reformat_dataset(dataset[split], add_triggers=False)

In [7]:
print(reformatted_dataset['train'][0])

{'dialog': "Speaker 1: It's been an hour and not one of my classmates has shown up! I tell you, when I actually die some people are gonna get seriously haunted! Speaker 2: There you go! Someone came! Speaker 1: Ok, ok! I'm gonna go hide! Oh, this is so exciting, my first mourner! Speaker 3: Hi, glad you could come. Speaker 2: Please, come in. Speaker 4: Hi, you're Chandler Bing, right? I'm Tom Gordon, I was in your class. Speaker 2: Oh yes, yes... let me... take your coat. Speaker 4: Thanks... uh... I'm so sorry about Ross, it's... Speaker 2: At least he died doing what he loved... watching blimps. Speaker 1: Who is he? Speaker 2: Some guy, Tom Gordon. Speaker 1: I don't remember him, but then again I touched so many lives. Speaker 3: So, did you know Ross well? Speaker 4: Oh, actually I barely knew him. Yeah, I came because I heard Chandler's news. D'you know if he's seeing anyone? Speaker 3: Yes, he is. Me. Speaker 4: What? You... You... Oh! Can I ask you a personal question? Ho-how 

### Preprocess each Sample

In [8]:
import re
import torch
from transformers import AutoTokenizer

In [9]:
MODEL_NAME = "FacebookAI/roberta-base"
#MODEL_NAME = "bert-base-uncased"

In [10]:
sequence_length = 1024
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding="max_length", max_length=sequence_length)
SUBJ_TOKEN = '[SUBJ]'
OBJ_TOKEN = '[OBJ]'
special_tokens_dict = {'additional_special_tokens': [SUBJ_TOKEN, OBJ_TOKEN]}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

In [11]:
def preprocess_sample(sample, max_positions=20):
    SEP_TOKEN = '[SEP]'
    dialog = sample['dialog']
    x = sample['x']
    y = sample['y']

    dialog1 = dialog.replace(x, '[SUBJ]')
    dialog2 = dialog1.replace(y, '[OBJ]')

    text = f"{dialog2} {SEP_TOKEN} {x} {SEP_TOKEN} {y}"
    if 'trigger' in sample:
        trigger = ', '.join(sample['trigger'])
        text += f" {SEP_TOKEN} {trigger}"

    tokens = tokenizer(text, padding="max_length", truncation=True, return_tensors="pt", max_length=sequence_length)

    # Find entity positions
    words = tokenizer.tokenize(text)
    e1_positions = [i for i, x in enumerate(words) if x == SUBJ_TOKEN]
    e2_positions = [i for i, x in enumerate(words) if x == OBJ_TOKEN]

    # Pad with zeroes to a fixed length
    e1_positions += [0] * (max_positions - len(e1_positions))
    e2_positions += [0] * (max_positions - len(e2_positions))

    return {
        'tokens': tokens,
        'e1_positions': torch.tensor(e1_positions),
        'e2_positions': torch.tensor(e2_positions)
    }

    # TODO: replace instances of x and y with [SUBJ] and [OBJ] tokens
    # TODO: append [SEP] x [SEP] y {[SEP] trigger}
    # TODO: tokenize with relevant tokenizer add special tokens first
    # TODO: get indices of [SUBJ] and [OBJ] in tokenized dialog

In [12]:
preprocess_sample(reformatted_dataset['train'][0])

{'tokens': {'input_ids': tensor([[    0, 29235,  4218,  ...,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0]])},
 'e1_positions': tensor([ 36,  82, 113, 145, 169, 279, 309,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0]),
 'e2_positions': tensor([97,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0])}

### Convert the Relation Label to Multi-hot Vector

Function to get all the relation labels

In [13]:
def get_labels(dataset):
    all_dataset_labels = set()
    for datapoint in dataset['train']:
        for relation in  [item for sublist in datapoint['relation_data']['r'] for item in sublist]:
            all_dataset_labels.add(relation)
    return list(all_dataset_labels)

## Create PyTorch Dataset

In [14]:
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import MultiLabelBinarizer

In [15]:
# Initialize the binarizer
class_labels = get_labels(dataset)
mlb = MultiLabelBinarizer(classes=class_labels)
# Fit the binarizer to the training labels
mlb.fit([item['relation'] for item in reformatted_dataset['train']])

In [16]:
# Create PyTorch Dataset
class DialogREDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        processed_sample = preprocess_sample(item)
        label = torch.tensor(mlb.fit_transform([item['relation']])[0], dtype=torch.float32)
        return {
            'input_ids': processed_sample['tokens'].input_ids.squeeze(0),
            'attention_mask': processed_sample['tokens'].attention_mask.squeeze(0),
            'e1_positions': processed_sample['e1_positions'],
            'e2_positions': processed_sample['e2_positions'],
            'labels': label
        }

In [17]:
train_dataset = DialogREDataset(reformatted_dataset['train'], tokenizer)
test_dataset = DialogREDataset(reformatted_dataset['test'], tokenizer)
validation_dataset = DialogREDataset(reformatted_dataset['validation'], tokenizer)

In [18]:
train_dataset[0]

{'input_ids': tensor([    0, 29235,  4218,  ...,     1,     1,     1]),
 'attention_mask': tensor([1, 1, 1,  ..., 0, 0, 0]),
 'e1_positions': tensor([ 36,  82, 113, 145, 169, 279, 309,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0]),
 'e2_positions': tensor([97,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]),
 'labels': tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}

In [19]:
import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaConfig # Using BertModel to fetch the base model without the extra head

In [20]:
class MultiLabelBERT(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        config = RobertaConfig.from_pretrained("roberta-base", max_position_embeddings=sequence_length)
        self.bert = RobertaModel(config)
        input_hidden_size = self.bert.config.hidden_size

        # Resize the token embeddings to accommodate the new tokens
        self.bert.resize_token_embeddings(len(tokenizer))

        # Add our own classification head for predicting multiple labels
        self.classifier = nn.Sequential(
            nn.Linear(input_hidden_size * 2, sequence_length), # Multiply input size by 2 as we have two entities
            nn.ReLU(),
            nn.Linear(sequence_length, len(class_labels)), # There are 35 relations in the dataset
            nn.Sigmoid() # Sigmoid activation is used to output probabilities for each label
        )

    def avg_embedding(self, bert_output, entity_positions):
        # Filter out zero positions
        entity_positions_filtered = entity_positions[entity_positions != 0]

        #
        if len(entity_positions_filtered) == 0:
            return torch.zeros((bert_output.shape[-1]), dtype=torch.float32).to(bert_output.device)

        entity_embs = bert_output[:, entity_positions_filtered, :]
        return torch.mean(entity_embs, dim=1)  # Mean over all occurrences

    def forward(self, input_ids, attention_mask, e1_positions, e2_positions):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Extract entity representations and apply averaging for multiple occurences
        e1_embedding = self.avg_embedding(bert_output, e1_positions)
        e2_embedding = self.avg_embedding(bert_output, e2_positions)

        # Concatenate entity embeddings
        combined_embedding = torch.cat([e1_embedding, e2_embedding], dim=-1)

        # Feed into the neural network classifier
        logits = self.classifier(combined_embedding)

        # Return relation probabilities for each label
        return logits

In [21]:
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm

# Create dataset and model
model = MultiLabelBERT(MODEL_NAME)

# Create DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Training loop
num_epochs = 3

# Set up optimizer, scheduler and loss
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
total_steps = len(train_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
loss_fn = nn.BCELoss()  # Binary cross-entropy loss is used for multi-label problems

# Use CUDA
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)

    for batch in progress_bar:
        # Load batches onto device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        e1_positions = batch['e1_positions'].to(device)
        e2_positions = batch['e2_positions'].to(device)
        labels = batch['labels'].to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        logits = model(input_ids, attention_mask, e1_positions, e2_positions)

        # Compute loss
        loss = loss_fn(logits, labels)
        total_loss += loss.item()

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()
        scheduler.step()

        # Update progress bar
        progress_bar.set_postfix({'loss': loss.item()})

    avg_loss = total_loss / len(train_dataloader)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}')

# Save the model
torch.save(model.state_dict(), 'model0.pth')

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
Epoch 1/3:   0%|          | 2/750 [00:04<28:21,  2.28s/it, loss=0.634]Token indices sequence length is longer than the specified maximum sequence length for this model (551 > 512). Running this sequence through the model will result in indexing errors


KeyboardInterrupt: 