<a href="https://colab.research.google.com/github/Mozzer2310/text-mining-cwk/blob/sam-experiments/DL_Relation_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Get the Dataset
We need to install the `datasets` module to download the [DialogRE](https://huggingface.co/datasets/dataset-org/dialog_re) dataset.

In [None]:
! pip install datasets -q

Then we can download the dataset.

In [None]:
from datasets import load_dataset

dataset = load_dataset("dataset-org/dialog_re", download_mode="force_redownload", trust_remote_code=True)

Then view the Dataset and its contents.

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 1073
    })
    test: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 357
    })
    validation: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 358
    })
})

In [None]:
dataset['train'][0]

{'dialog': ["Speaker 1: It's been an hour and not one of my classmates has shown up! I tell you, when I actually die some people are gonna get seriously haunted!",
  'Speaker 2: There you go! Someone came!',
  "Speaker 1: Ok, ok! I'm gonna go hide! Oh, this is so exciting, my first mourner!",
  'Speaker 3: Hi, glad you could come.',
  'Speaker 2: Please, come in.',
  "Speaker 4: Hi, you're Chandler Bing, right? I'm Tom Gordon, I was in your class.",
  'Speaker 2: Oh yes, yes... let me... take your coat.',
  "Speaker 4: Thanks... uh... I'm so sorry about Ross, it's...",
  'Speaker 2: At least he died doing what he loved... watching blimps.',
  'Speaker 1: Who is he?',
  'Speaker 2: Some guy, Tom Gordon.',
  "Speaker 1: I don't remember him, but then again I touched so many lives.",
  'Speaker 3: So, did you know Ross well?',
  "Speaker 4: Oh, actually I barely knew him. Yeah, I came because I heard Chandler's news. D'you know if he's seeing anyone?",
  'Speaker 3: Yes, he is. Me.',
  'S

## Defining Constants
Constants used throughout the notebook: the pretrained checkpoint name and the maximum input sequence length in tokens.

In [None]:
MODEL_NAME = "bert-base-uncased"
SEQ_LENGTH = 512

## Reformat the Dataset
As the output above shows, each item in the dataset is a single dialog annotated with several entity pairs and their relations. We therefore flatten the dataset so that every item is a single sample: one dialog together with one subject (`x`), one object (`y`), and their relation labels. This also lets us keep only the fields we need.

In [None]:
def reformat_dataset(dataset, add_triggers=True):
    """
    Flattens a DialogRE split so that each (subject, object) pair becomes its own sample.

    Args:
        dataset: One split of the Hugging Face DatasetDict (e.g. dataset['train']).
        add_triggers: Whether to keep the trigger words ('t') for each relation.

    Returns:
        list: A list of dicts with keys 'dialog', 'x', 'y', 'relation' (plus 'trigger' if add_triggers is True).
    """
    reformatted_dataset = []

    for item in dataset:
        dialog = item['dialog']
        relation_data = item['relation_data']

        # Join the dialog into a single string
        all_dialog = ' '.join(dialog)

        samples = []
        # Loop over the individual relations and the parts we require
        for x, y, r, t in zip(relation_data['x'], relation_data['y'], relation_data['r'], relation_data['t']):
            sample = {'dialog': all_dialog, 'x': x, 'y': y, 'relation': r}
            if add_triggers:
                sample['trigger'] = t
            samples.append(sample)

        reformatted_dataset.extend(samples)

    return reformatted_dataset
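
A toy illustration of the flattening step, run on a hand-written item shaped like a DialogRE record. The dialog, the labels, and the helper `flatten_item` are hypothetical; the helper is a trimmed copy of the loop in `reformat_dataset` without the trigger option.

```python
# Hypothetical toy item with the same structure as a DialogRE record
toy_item = {
    "dialog": ["Speaker 1: Hi Joey!", "Speaker 2: Hey, Monica."],
    "relation_data": {
        "x": ["Speaker 2", "Speaker 1"],
        "y": ["Joey", "Monica"],
        "r": [["per:alternate_names"], ["per:alternate_names"]],
        "t": [[""], [""]],
    },
}

def flatten_item(item):
    """Return one sample per (x, y) pair, all sharing the joined dialog string."""
    all_dialog = " ".join(item["dialog"])
    rel = item["relation_data"]
    return [
        {"dialog": all_dialog, "x": x, "y": y, "relation": r}
        for x, y, r in zip(rel["x"], rel["y"], rel["r"])
    ]

samples = flatten_item(toy_item)
for s in samples:
    print(s["x"], "->", s["y"], s["relation"])
```

One dialog with two annotated pairs becomes two samples that repeat the same dialog text.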

Apply `reformat_dataset` to each split ('train', 'test', and 'validation'), storing the resulting lists in a dictionary keyed by split name. Triggers are dropped here (`add_triggers=False`).

In [None]:
reformatted_dataset = {}
for split in dataset.keys():
    reformatted_dataset[split] = reformat_dataset(dataset[split], add_triggers=False)

*Example: The first sample of the reformatted training split.*

In [None]:
reformatted_dataset['train'][0]

{'dialog': "Speaker 1: It's been an hour and not one of my classmates has shown up! I tell you, when I actually die some people are gonna get seriously haunted! Speaker 2: There you go! Someone came! Speaker 1: Ok, ok! I'm gonna go hide! Oh, this is so exciting, my first mourner! Speaker 3: Hi, glad you could come. Speaker 2: Please, come in. Speaker 4: Hi, you're Chandler Bing, right? I'm Tom Gordon, I was in your class. Speaker 2: Oh yes, yes... let me... take your coat. Speaker 4: Thanks... uh... I'm so sorry about Ross, it's... Speaker 2: At least he died doing what he loved... watching blimps. Speaker 1: Who is he? Speaker 2: Some guy, Tom Gordon. Speaker 1: I don't remember him, but then again I touched so many lives. Speaker 3: So, did you know Ross well? Speaker 4: Oh, actually I barely knew him. Yeah, I came because I heard Chandler's news. D'you know if he's seeing anyone? Speaker 3: Yes, he is. Me. Speaker 4: What? You... You... Oh! Can I ask you a personal question? Ho-how 

## Preprocessing Steps
The dataset needs to be transformed into input that a BERT model can ingest: the text must be tokenized and, for our model, the positions of the entities in the tokenized text must also be recorded.

### BERT Input Sequence Format
```
[CLS] d* [SEP] e_1 [SEP] e_2 {[SEP] triggers} [SEP]
```
Where:
- `d*` is the *dialog* of the sample with the instances of the entities (`e_1` and `e_2`) replaced with `[SUBJ]` and `[OBJ]`.
- `e_1` is the first entity of the relation.
- `e_2` is the second entity of the relation.
- `{[SEP] triggers}` is **optional**: a separator token followed by the trigger words for the relation.
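
A minimal, tokenizer-free sketch of how the string for this format can be assembled before tokenization, mirroring what `preprocess_sample` does. The helper `build_input_text` and the toy dialog are hypothetical. The tokenizer adds `[CLS]` and the final `[SEP]` itself, so they are not part of the string.

```python
SEP = "[SEP]"
SUBJ, OBJ = "[SUBJ]", "[OBJ]"

def build_input_text(dialog, x, y, triggers=None):
    """Mark entity mentions, then append the entity (and optional trigger) segments."""
    # Replace every mention of the subject, then of the object, with its marker
    marked = dialog.replace(x, SUBJ).replace(y, OBJ)
    text = f"{marked} {SEP} {x} {SEP} {y}"
    if triggers:
        text += f" {SEP} {', '.join(triggers)}"
    return text

text = build_input_text("Speaker 1: Hi Joey! Speaker 2: Hey, Monica.", "Speaker 2", "Joey")
print(text)
```

Because this relies on plain `str.replace`, a name is also replaced where it appears inside a longer string (for example, `Speaker 1` inside `Speaker 10`).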

New imports for this part.

In [None]:
import re
import torch
from transformers import AutoTokenizer

Load the tokenizer and register the `[SUBJ]` and `[OBJ]` markers as additional special tokens, so that the tokenizer never splits them into sub-words.

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)  # padding is applied per call in preprocess_sample
SUBJ_TOKEN = '[SUBJ]'
OBJ_TOKEN = '[OBJ]'
special_tokens_dict = {'additional_special_tokens': [SUBJ_TOKEN, OBJ_TOKEN]}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

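This function takes the text of a sample (the dialog followed by the entity and trigger segments), tokenizes it and, if it is too long, truncates the dialog so that the final input fits in `SEQ_LENGTH` (512) tokens, including the `[CLS]` and final `[SEP]` tokens added by the tokenizer.

To avoid cutting out the entities, the function locates the `[SUBJ]` and `[OBJ]` markers in the fully tokenized dialog and keeps a window around them, extended on each side by a quarter of the available token budget. If that window is still too long, tokens are dropped from its end. The entity and trigger segment after the first `[SEP]` is kept in full.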
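
The window heuristic can be checked in isolation. The pure-Python sketch below works on a plain list of token strings instead of a real tokenizer; `entity_window` is a hypothetical helper, and `available` stands for the dialog's token budget (`SEQ_LENGTH - 2` minus the length of the entity/trigger segment). It treats a mention of either entity as an anchor point.

```python
def entity_window(tokens, available, markers=("[SUBJ]", "[OBJ]")):
    """Return (start, end) of a slice of `tokens` that is at most `available` long.

    Spans the first to last entity marker, pads each side by available // 4,
    clamps to the token list, then cuts from the end if still too long.
    """
    if len(tokens) <= available:
        return 0, len(tokens)  # nothing to truncate
    positions = [i for i, tok in enumerate(tokens) if tok in markers]
    first = min(positions) if positions else 0
    last = max(positions) if positions else len(tokens)
    start = max(0, first - available // 4)
    end = min(len(tokens), last + available // 4)
    if end - start > available:
        end = min(len(tokens), start + available)
    return start, end

tokens = ["w"] * 100
tokens[60], tokens[70] = "[SUBJ]", "[OBJ]"
print(entity_window(tokens, 20))
```

If the mentions are spread over more than `available` tokens, the window keeps the earliest mentions and drops later ones, so this heuristic does not guarantee that every mention survives.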

In [None]:
def truncate_dialog_with_entities(tokenizer, text):
    """
    Truncates the dialog text while ensuring that the subject and object tokens are included.

    Args:
        tokenizer: The tokenizer object.
        text: The full text containing the dialog and entities.

    Returns:
         string: The truncated text.
    """
    max_length = SEQ_LENGTH - 2 # Minus 2 for the start and end token that are added
    parts = text.split('[SEP]')
    dialog = parts[0].strip()
    entities_and_triggers = '[SEP]' + '[SEP]'.join(parts[1:])

    tokens_full = tokenizer.tokenize(dialog)
    tokens_entities_and_triggers = tokenizer.tokenize(entities_and_triggers)

    if len(tokens_full) + len(tokens_entities_and_triggers) <= max_length:
        return text  # No truncation needed

    # Positions of every entity marker in the full (untruncated) dialog
    entity_positions = [i for i, token in enumerate(tokens_full) if token in (SUBJ_TOKEN, OBJ_TOKEN)]

    # Span from the first to the last mention of either entity; fall back to the whole dialog if none occur
    first_entity_index = min(entity_positions) if entity_positions else 0
    last_entity_index = max(entity_positions) if entity_positions else len(tokens_full)

    # Token budget left for the dialog once the entity/trigger segment is reserved
    available_length_for_dialog = max_length - len(tokens_entities_and_triggers)

    # Pad the entity span by a quarter of the budget on each side, clamped to the dialog
    start_index = max(0, first_entity_index - (available_length_for_dialog // 4))
    end_index = min(len(tokens_full), last_entity_index + (available_length_for_dialog // 4))

    # If the window is still too long, keep its start and cut from the end
    if end_index - start_index > available_length_for_dialog:
        end_index = min(len(tokens_full), start_index + available_length_for_dialog)

    truncated_tokens_dialog = tokens_full[start_index:end_index]
    truncated_dialog = tokenizer.convert_tokens_to_string(truncated_tokens_dialog)

    truncated_text = f"{truncated_dialog} {entities_and_triggers.strip()}"

    return truncated_text

This function converts a single sample into the tokenized input format described above. It also locates every `[SUBJ]` and `[OBJ]` marker in the tokenized text and returns, for each entity, a fixed-length tensor of the indices at which that entity's marker occurs.
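
The fixed-length position lists can be sketched without a tokenizer. `entity_positions` below is a hypothetical helper operating on a toy token list. It caps the list at `max_positions` before padding with -1, which keeps every position tensor the same length so that a `DataLoader` can stack them into a batch.

```python
def entity_positions(tokens, marker, max_positions=20, pad_value=-1):
    """Indices of `marker` in `tokens`, cut or padded to exactly max_positions."""
    positions = [i for i, tok in enumerate(tokens) if tok == marker]
    positions = positions[:max_positions]  # never exceed the fixed length
    return positions + [pad_value] * (max_positions - len(positions))

# Toy token sequence (hypothetical), as produced after marker replacement
tokens = ["[CLS]", "[SUBJ]", ":", "hi", "[OBJ]", "!", "[SUBJ]", ":", "hey",
          "[SEP]", "speaker", "2", "[SEP]", "joey", "[SEP]"]
print(entity_positions(tokens, "[SUBJ]", max_positions=4))
print(entity_positions(tokens, "[OBJ]", max_positions=4))
```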

In [None]:
def preprocess_sample(sample, max_positions=20):
    """
    Converts a single sample to a tokenized form and extracts entity positions.

    Args:
        sample: The sample to preprocess.
        max_positions: The maximum number of positions to store for each entity.

    Returns:
        dict: 'tokens' (the tokenizer output for the potentially truncated text, padded to SEQ_LENGTH),
        plus 'e1_positions' and 'e2_positions' (tensors of length max_positions, padded with -1).
    """
    SEP_TOKEN = '[SEP]'
    dialog = sample['dialog']
    x = sample['x']
    y = sample['y']

    # Replace every mention of the subject and object with the marker tokens
    dialog1 = dialog.replace(x, SUBJ_TOKEN)
    dialog2 = dialog1.replace(y, OBJ_TOKEN)

    text = f"{dialog2} {SEP_TOKEN} {x} {SEP_TOKEN} {y}"
    if 'trigger' in sample:
        trigger = ', '.join(sample['trigger'])
        text += f" {SEP_TOKEN} {trigger}"

    truncated_text = truncate_dialog_with_entities(tokenizer, text)

    # Tokenize the (potentially truncated) text; truncation=True is a safety net in case
    # re-tokenizing the truncated string yields slightly more than SEQ_LENGTH tokens
    tokens = tokenizer(truncated_text, padding="max_length", truncation=True,
                       max_length=SEQ_LENGTH, return_tensors="pt")

    # Find entity positions within the (potentially truncated) text, keeping at most max_positions
    words = tokenizer.convert_ids_to_tokens(tokens.input_ids.squeeze(0))
    e1_positions = [i for i, tok in enumerate(words) if tok == SUBJ_TOKEN][:max_positions]
    e2_positions = [i for i, tok in enumerate(words) if tok == OBJ_TOKEN][:max_positions]

    # Pad with -1 to a fixed length so that samples can be batched together
    e1_positions += [-1] * (max_positions - len(e1_positions))
    e2_positions += [-1] * (max_positions - len(e2_positions))

    return {
        'tokens': tokens,
        'e1_positions': torch.tensor(e1_positions),
        'e2_positions': torch.tensor(e2_positions)
    }

*Example: The first sample after preprocessing.*

In [None]:
preprocess_sample(reformatted_dataset['train'][0])

{'tokens': {'input_ids': tensor([[  101,  5882,  1015,  1024,  2009,  1005,  1055,  2042,  2019,  3178,
           1998,  2025,  2028,  1997,  2026, 19846,  2038,  3491,  2039,   999,
           1045,  2425,  2017,  1010,  2043,  1045,  2941,  3280,  2070,  2111,
           2024,  6069,  2131,  5667, 11171,   999, 30522,  1024,  2045,  2017,
           2175,   999,  2619,  2234,   999,  5882,  1015,  1024,  7929,  1010,
           7929,   999,  1045,  1005,  1049,  6069,  2175,  5342,   999,  2821,
           1010,  2023,  2003,  2061, 10990,  1010,  2026,  2034,  9587, 21737,
           2099,   999,  5882,  1017,  1024,  7632,  1010,  5580,  2017,  2071,
           2272,  1012, 30522,  1024,  3531,  1010,  2272,  1999,  1012,  5882,
           1018,  1024,  7632,  1010,  2017,  1005,  2128, 30523,  1010,  2157,
           1029,  1045,  1005,  1049,  3419,  5146,  1010,  1045,  2001,  1999,
           2115,  2465,  1012, 30522,  1024,  2821,  2748,  1010,  2748,  1012,
           1012,

## Create PyTorch Dataset Object


In [None]:
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import MultiLabelBinarizer

### Convert Relations to Multi-Hot Vectors
The last step in preparing the dataset for training is to convert the relations into multi-hot vectors. Multi-hot vectors are needed because some entity pairs have multiple relations in the dataset.

Function to collect all the relation labels from the training and validation splits.

In [None]:
def get_labels():
    """
    Gets a sorted list of all unique relation labels in the non-test splits.

    Returns:
        list: A sorted list of all the unique relation labels.
    """
    # 'split' avoids shadowing the global `dataset` DatasetDict
    relations = {relation
                 for split in reformatted_dataset if split != 'test'
                 for sample in reformatted_dataset[split]
                 for relation in sample['relation']}

    return sorted(relations)

Instantiate a `MultiLabelBinarizer` to convert each sample's relations into a multi-hot vector.

In [None]:
# Initialize the binarizer
class_labels = get_labels()
mlb = MultiLabelBinarizer(classes=class_labels)
# Fit the binarizer to the training labels
mlb.fit([item['relation'] for item in reformatted_dataset['train']])
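
For reference, `MultiLabelBinarizer(classes=class_labels)` produces, for each sample, a 0/1 vector with a 1 in the column of every relation that sample carries, and it ignores (with a warning) labels it was not given. A dependency-free sketch of the same encoding, using a hypothetical `multi_hot` helper and a handful of DialogRE-style labels:

```python
def multi_hot(relations, classes):
    """Encode a list of relation labels as a 0/1 vector over `classes`.

    Labels that are not in `classes` are silently skipped.
    """
    index = {label: i for i, label in enumerate(classes)}
    vector = [0] * len(classes)
    for label in relations:
        if label in index:
            vector[index[label]] = 1
    return vector

classes = ["per:alternate_names", "per:friends", "per:girl/boyfriend", "per:siblings"]
print(multi_hot(["per:friends", "per:girl/boyfriend"], classes))
```

Each column is then an independent binary target, which is what `BCEWithLogitsLoss` in the model expects.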

### Define Dataset Class
This is a custom dataset class that inherits from PyTorch's `Dataset` class, so it can be used with a `DataLoader` and with the Hugging Face `Trainer`. Each sample is preprocessed on the fly when it is accessed.

In [None]:
# Create PyTorch Dataset
class DialogREDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Args:
            idx: The index of the sample to retrieve.

        Returns:
            dict: A dictionary containing the sample.
        """
        item = self.data[idx]
        # Preprocess the sample to get the tokens and entity position
        processed_sample = preprocess_sample(item)
        # Get the multi-hot vector for the sample's relations (transform only: the binarizer is already fitted)
        label = torch.tensor(mlb.transform([item['relation']])[0], dtype=torch.float32)
        return {
            'input_ids': processed_sample['tokens'].input_ids.squeeze(0),
            'attention_mask': processed_sample['tokens'].attention_mask.squeeze(0),
            'e1_positions': processed_sample['e1_positions'],
            'e2_positions': processed_sample['e2_positions'],
            'labels': label
        }

Create one `DialogREDataset` for each of the three splits.

In [None]:
train_dataset = DialogREDataset(reformatted_dataset['train'], tokenizer)
test_dataset = DialogREDataset(reformatted_dataset['test'], tokenizer)
validation_dataset = DialogREDataset(reformatted_dataset['validation'], tokenizer)

## Create BERT Model for multi-label classification task

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel

class MultiLabelBERT(nn.Module):
    def __init__(self, model_name):
        super(MultiLabelBERT, self).__init__()

        # BERT initial setup
        self.bert = BertModel.from_pretrained(model_name)
        self.loss_fn = nn.BCEWithLogitsLoss()  # Multi-label loss function
        self.hidden_size = self.bert.config.hidden_size
        self.bert.resize_token_embeddings(len(tokenizer))

        # Define normalisation layers
        self.dropout = nn.Dropout(0.5)  # Dropout to reduce overfitting
        self.layer_norm = nn.LayerNorm(self.hidden_size)  # LayerNorm to be used after attention calculations

        # Classification layer
        self.classifier = nn.Linear(self.hidden_size * 3, len(class_labels))

    def get_entity_embeddings(self, bert_output, entity_positions):
        """
        Retrieve the embeddings for every occurrence of an entity from BERT's last hidden state.

        Args:
            bert_output: BERT's last hidden state, shape (batch, seq_len, hidden)
            entity_positions: tensor of entity positions, shape (batch, max_positions), padded with -1

        Returns:
            entity_embeddings: tensor of shape (batch, max_positions, hidden)
            mask: boolean tensor of shape (batch, max_positions), True where a position is real
        """
        # Accept an unbatched (max_positions,) tensor as a batch of one
        if entity_positions.dim() == 1:
            entity_positions = entity_positions.unsqueeze(0)

        # Mark real positions, then replace the -1 padding with 0 so it is a valid index
        mask = entity_positions != -1
        safe_positions = entity_positions.clamp(min=0)

        # Gather each sample's own entity embeddings (no mixing of positions across the batch)
        index = safe_positions.unsqueeze(-1).expand(-1, -1, bert_output.size(-1))
        entity_embeddings = torch.gather(bert_output, 1, index)
        return entity_embeddings, mask

    def compute_query_key_attention(self, cls_embedding, entity_embeddings, mask):
        """
        Compute attention weights for each occurrence of the entity embedding
        using [CLS] as query and entity embeddings as keys/values. The weights are then
        applied to each entity embedding.

        Args:
            cls_embedding: [CLS] token embedding, shape (batch, hidden)
            entity_embeddings: entity embeddings, shape (batch, max_positions, hidden)
            mask: boolean tensor, shape (batch, max_positions), True where a position is real

        Returns:
            weighted_embedding: a single embedding per sample that is a weighted average of its entity embeddings
        """

        # Add an extra dimension at index 1 to allow for batch matrix multiplication
        query = cls_embedding.unsqueeze(1)

        # Entity embeddings serve as both keys and values
        keys = entity_embeddings
        values = entity_embeddings

        # Compute attention scores using Q * K^T, shape (batch, 1, max_positions)
        attention_scores = torch.bmm(query, keys.transpose(1, 2))

        # Scaled dot-product attention
        attention_scores = attention_scores / (self.hidden_size ** 0.5)

        # Exclude padded positions from the softmax
        key_mask = mask.unsqueeze(1)
        attention_scores = attention_scores.masked_fill(~key_mask, torch.finfo(attention_scores.dtype).min)

        # Apply softmax to get weights; an entity with no mentions ends up with all-zero weights
        attention_weights = nn.functional.softmax(attention_scores, dim=-1) * key_mask

        # Compute weighted sum
        weighted_embedding = torch.bmm(attention_weights, values)

        # Remove the extra dimension
        weighted_embedding = weighted_embedding.squeeze(1)

        return weighted_embedding

    def forward(self, input_ids, attention_mask, e1_positions, e2_positions, labels=None):

        # Retrieve BERT's last hidden state
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        bert_output = outputs.last_hidden_state

        # Extract [CLS] embedding
        cls_embedding = bert_output[:, 0, :]

        # Retrieve each sample's entity embeddings and padding masks
        entity1_embeddings, entity1_mask = self.get_entity_embeddings(bert_output, e1_positions)
        entity2_embeddings, entity2_mask = self.get_entity_embeddings(bert_output, e2_positions)

        # Compute attention-weighted entity representations using [CLS] as query
        entity1_weighted = self.compute_query_key_attention(cls_embedding, entity1_embeddings, entity1_mask)
        entity2_weighted = self.compute_query_key_attention(cls_embedding, entity2_embeddings, entity2_mask)

        # Apply layer normalisation to entity embeddings after attention
        entity1_weighted = self.layer_norm(entity1_weighted)
        entity2_weighted = self.layer_norm(entity2_weighted)

        # Concatenate [CLS] and entity embeddings
        combined_embedding = torch.cat([cls_embedding, entity1_weighted, entity2_weighted], dim=-1)

        # Apply dropout layer
        combined_embedding = self.dropout(combined_embedding)

        # Feed into the classifier
        logits = self.classifier(combined_embedding)

        # Return dict with loss and logits (loss=None for inferencing)
        if labels is not None:  # Compute loss during training as labels are provided
            loss = self.loss_fn(logits, labels.float())
            return {"loss": loss, "logits": logits}
        return {"loss": None, "logits": logits}
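
To make the [CLS]-query attention concrete, here is a single-example NumPy sketch of the pooling that `compute_query_key_attention` is meant to perform (NumPy instead of PyTorch; the function name `cls_attention_pool` and the toy vectors are hypothetical). It scores each entity mention against the [CLS] vector with a scaled dot product, applies a softmax over the real (non-padding) mentions only, and returns the weighted average. An entity with no mentions maps to a zero vector.

```python
import numpy as np

def cls_attention_pool(cls_vec, hidden_states, positions):
    """Attention-pool the rows of `hidden_states` listed in `positions`.

    cls_vec: (H,) query; hidden_states: (T, H); positions: ints, -1 = padding.
    """
    valid = [p for p in positions if p != -1]
    if not valid:
        return np.zeros_like(cls_vec)                    # no mentions -> zero vector
    keys = hidden_states[valid]                          # (K, H), also used as values
    scores = keys @ cls_vec / np.sqrt(cls_vec.shape[0])  # scaled dot product, (K,)
    weights = np.exp(scores - scores.max())              # numerically stable softmax
    weights /= weights.sum()
    return weights @ keys                                # weighted average, (H,)

hidden = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
cls_vec = np.array([1.0, 0.0])
pooled = cls_attention_pool(cls_vec, hidden, [1, 2, -1])
print(pooled)
```

The mention whose embedding aligns more closely with the [CLS] vector (row 1 here) receives the larger weight.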

## Train and save the model

In [None]:
from transformers import Trainer, TrainingArguments
#from google.colab import drive
#drive.mount('/content/drive')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUTPUT_DIR = './' # current working directory

# Define model
model = MultiLabelBERT(MODEL_NAME).to(device)

# Define training arguments
training_args = TrainingArguments(
    optim="adamw_torch",
    output_dir=OUTPUT_DIR,           # output directory
    num_train_epochs=5,              # total number of training epochs
    per_device_train_batch_size=8,   # batch size per device during training
    per_device_eval_batch_size=8,    # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    eval_strategy="epoch",           # evaluate on the validation set after every epoch
    save_strategy="no",              # do not save checkpoints
    learning_rate=5e-5,
)

trainer = Trainer(
    model=model,                         # the custom MultiLabelBERT model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=validation_dataset      # evaluation dataset
)

trainer.train()
torch.save(model.state_dict(), OUTPUT_DIR+'model4.pth') # Save model weights only

| Epoch | Training Loss | Validation Loss |
|------:|--------------:|----------------:|
| 1 | 0.1614 | 0.082651 |
| 2 | 0.066 | 0.058626 |
| 3 | 0.0482 | 0.052344 |
| 4 | 0.0319 | 0.05111 |
| 5 | 0.0233 | 0.050347 |
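
With `warmup_steps=500` and the default `lr_scheduler_type` of `"linear"`, the learning rate rises linearly from 0 to `learning_rate` over the first 500 optimizer steps and then decays linearly to 0 by the final step. A small sketch of that schedule follows; the helper name and the `total_steps` placeholder are assumptions, since the real total depends on the number of training samples, the batch size, and the number of epochs.

```python
def linear_warmup_lr(step, base_lr=5e-5, warmup_steps=500, total_steps=10_000):
    """Learning rate at `step` for linear warmup followed by linear decay to 0.

    total_steps is a placeholder; in practice it is roughly
    ceil(len(train_dataset) / batch_size) * num_train_epochs.
    """
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

for step in (0, 250, 500, 5_000, 10_000):
    print(step, linear_warmup_lr(step))
```

If an epoch has fewer than 500 steps, the whole first epoch runs below the peak learning rate.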


## Inferencing

In [None]:
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
saved_model = MultiLabelBERT(MODEL_NAME).to(device)

saved_model.load_state_dict(torch.load('model4.pth', map_location=device, weights_only=True))  # weights saved by the training cell above
saved_model.eval()

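def predict_relations(sample_data, model, tokenizer, threshold=0.5):
    """
    Predicts the relations for a single reformatted sample.

    Args:
        sample_data: A sample from reformatted_dataset.
        model: The trained MultiLabelBERT model.
        tokenizer: The tokenizer (preprocess_sample uses the global tokenizer).
        threshold: Probability above which a relation is predicted.

    Returns:
        numpy.ndarray: Indices (into class_labels) of the predicted relations.
    """
    processed_sample = preprocess_sample(sample_data)
    input_ids = processed_sample['tokens'].input_ids.to(device)
    attention_mask = processed_sample['tokens'].attention_mask.to(device)
    # Add a batch dimension so the positions have shape (1, max_positions), matching input_ids
    e1_positions = processed_sample['e1_positions'].unsqueeze(0).to(device)
    e2_positions = processed_sample['e2_positions'].unsqueeze(0).to(device)

    with torch.no_grad():
        output_dict = model(input_ids, attention_mask, e1_positions, e2_positions)

    # Apply sigmoid to the logits
    relation_probabilities = torch.sigmoid(output_dict['logits'])

    # Apply threshold to get predicted relations
    relation_indices = np.where(relation_probabilities.cpu().numpy() > threshold)[1]

    # Return list of relation indices
    return relation_indices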

sample_data = reformatted_dataset['test'][226]
for key in sample_data:
  print(key, sample_data[key])

relation_indices = predict_relations(sample_data, saved_model, tokenizer, threshold=0.3)
predicted_relations = [class_labels[i] for i in relation_indices]
print(predicted_relations)

dialog Speaker 1: The basket is totally empty! My God, the neighbors ate all the candy! Speaker 2: Well, either that or uh… Speaker 1: Joey!! Speaker 3: Yeah? Speaker 1: Did you eat all the neighbor candy?! Speaker 3: Uh well yeah, that was the plan, but by the time I got to it there was only a couple of pieces left! Speaker 4: Yeah, and they’ve been coming by all day. They love it! Speaker 1: They love my candy? Oh man!!! I’ve gotta go make more!! Speaker 3: Hey Mon, you might wanna make some more lasagna too, because something might’ve happened to a huge chunk of it. Speaker 1: Ross! The neighbors ate all my candy!! Speaker 5: Mine stole my newspaper! It’s like a crime wave!! Pheebs, you uh, you got a second. Speaker 4: Sure! Speaker 5: Yeah, ever since you uh, told me that story about that bike I-I couldn’t stop thinking about it. I mean, everyone should have a-a first bike, so… Speaker 4: Oh my God Ross!! Speaker 5: You like it? Speaker 4: I love it!! Speaker 5: Yeah? Speaker 4: Oh
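
`predict_relations` turns logits into labels by applying a sigmoid to each class independently and keeping every class above the threshold. A NumPy sketch of that step on made-up logits (the helper `probs_to_labels` and the values are hypothetical) shows how lowering the threshold from 0.5 to 0.3, as in the call above, admits more relations; this generally raises recall at the cost of precision.

```python
import numpy as np

def probs_to_labels(logits, class_labels, threshold=0.5):
    """Sigmoid one sample's logits and return labels whose probability exceeds threshold."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))
    return [label for label, p in zip(class_labels, probs) if p > threshold]

labels = ["per:alternate_names", "per:friends", "per:siblings"]
logits = [2.0, -0.5, -3.0]  # made-up values; sigmoid gives ~0.88, ~0.38, ~0.05
print(probs_to_labels(logits, labels, threshold=0.5))
print(probs_to_labels(logits, labels, threshold=0.3))
```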

## Evaluation

In [None]:
from sklearn.metrics import f1_score

# Define function for batch inference & evaluation
def evaluate_model(dataloader, model, threshold=0.5):
    """
    Evaluates model performance on a dataset by computing sample-based F1 score

    Args:
        dataloader: DataLoader object containing the test dataset
        model: trained model
        threshold: probability threshold for label selection

    Returns:
        sample_f1: sample-based F1 Score
        y_true_full: array of true labels
        y_pred_full: array of predicted labels
    """

    # Load model to device and place in evaluation mode
    model.to(device)
    model.eval()

    all_true_labels = []
    all_pred_labels = []

    with torch.no_grad():
        for batch in dataloader:

            # Fetch each item
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            e1_positions = batch['e1_positions'].to(device)
            e2_positions = batch['e2_positions'].to(device)
            labels = batch['labels'].to(device)

            # Store true labels
            all_true_labels.append(labels.cpu().numpy())

            # Calculate logits
            logits = model(input_ids, attention_mask, e1_positions, e2_positions)['logits']

            # Calculate probabilities for each class
            probs = torch.sigmoid(logits)

            # Mark every relation whose probability exceeds the threshold as 1, and 0 otherwise
            preds = (probs > threshold).cpu().numpy().astype(int)

            all_pred_labels.append(preds)

    # Convert list of arrays into a single array
    y_true_full = np.vstack(all_true_labels)
    y_pred_full = np.vstack(all_pred_labels)

    # Compute sample-based F1 Score
    sample_f1 = f1_score(y_true_full, y_pred_full, average="samples")

    return sample_f1, y_true_full, y_pred_full

test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)

f1, y_true_all, y_pred_all = evaluate_model(test_dataloader, saved_model, threshold=0.3)

print(f1)
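
`f1_score(..., average="samples")` computes an F1 score for each sample's set of labels and then averages over samples. For one sample with true labels T and predicted labels P, that F1 equals 2|T ∩ P| / (|T| + |P|). A NumPy sketch of the same computation on toy matrices follows; the helper `samples_f1` is hypothetical, and scoring a sample with no true and no predicted labels as 0 is an assumption (sklearn controls that case through its `zero_division` argument).

```python
import numpy as np

def samples_f1(y_true, y_pred):
    """Mean over samples of 2*|T & P| / (|T| + |P|) for 0/1 label matrices."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    overlap = (y_true & y_pred).sum(axis=1)            # |T & P| per sample
    denom = y_true.sum(axis=1) + y_pred.sum(axis=1)    # |T| + |P| per sample
    # Samples with empty T and P score 0 here (assumption, see lead-in)
    per_sample = np.where(denom > 0, 2 * overlap / np.maximum(denom, 1), 0.0)
    return per_sample.mean()

y_true = [[1, 0, 1], [0, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 0]]
print(samples_f1(y_true, y_pred))  # (2/3 + 1) / 2
```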