<a href="https://colab.research.google.com/github/Mozzer2310/text-mining-cwk/blob/sam-experiments/DL_Relation_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Get the Dataset
We need to install the `datasets` module to download the [DialogRE](https://huggingface.co/datasets/dataset-org/dialog_re) dataset.

In [1]:
! pip install datasets -q

Then we can download the dataset.

In [2]:
from datasets import load_dataset

dataset = load_dataset("dataset-org/dialog_re", download_mode="force_redownload", trust_remote_code=True)


Then view the dataset and its contents.

In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 1073
    })
    test: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 357
    })
    validation: Dataset({
        features: ['dialog', 'relation_data'],
        num_rows: 358
    })
})

In [4]:
dataset['train'][0]

{'dialog': ["Speaker 1: It's been an hour and not one of my classmates has shown up! I tell you, when I actually die some people are gonna get seriously haunted!",
  'Speaker 2: There you go! Someone came!',
  "Speaker 1: Ok, ok! I'm gonna go hide! Oh, this is so exciting, my first mourner!",
  'Speaker 3: Hi, glad you could come.',
  'Speaker 2: Please, come in.',
  "Speaker 4: Hi, you're Chandler Bing, right? I'm Tom Gordon, I was in your class.",
  'Speaker 2: Oh yes, yes... let me... take your coat.',
  "Speaker 4: Thanks... uh... I'm so sorry about Ross, it's...",
  'Speaker 2: At least he died doing what he loved... watching blimps.',
  'Speaker 1: Who is he?',
  'Speaker 2: Some guy, Tom Gordon.',
  "Speaker 1: I don't remember him, but then again I touched so many lives.",
  'Speaker 3: So, did you know Ross well?',
  "Speaker 4: Oh, actually I barely knew him. Yeah, I came because I heard Chandler's news. D'you know if he's seeing anyone?",
  'Speaker 3: Yes, he is. Me.',
  'S

## Defining Constants
Constants used throughout: the pretrained checkpoint to fine-tune and the maximum input length in tokens.

In [5]:
# MODEL_NAME = "FacebookAI/roberta-base"
MODEL_NAME = "bert-base-uncased"
SEQ_LENGTH = 512

## Preprocess the Data
1. Reformat the dataset so that each sample holds a single relation (one `x`, `y` pair) extracted from a dialog
2. Preprocess each sample, obtaining its tokens and the positions of the entity marker tokens
3. Create a PyTorch dataset from the processed samples

### Reformat the Dataset
Convert the dataset so that each item contains a single relation: one subject `x`, one object `y`, and that pair's (possibly multiple) relation labels.

In [6]:
def reformat_dataset(dataset, add_triggers=True):
    reformatted_dataset = []

    for item in dataset:
        dialog = item['dialog']
        relation_data = item['relation_data']

        # Join the dialog into a single string
        all_dialog = ' '.join(dialog)

        samples = []
        for x, y, r, t in zip(relation_data['x'], relation_data['y'], relation_data['r'], relation_data['t']):
            sample = {'dialog': all_dialog, 'x': x, 'y': y, 'relation': r}
            if add_triggers:
                sample['trigger'] = t
            samples.append(sample)

        reformatted_dataset.extend(samples)

    return reformatted_dataset
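To see what `reformat_dataset` produces, here is a minimal standalone sketch that applies the same flattening to a hypothetical toy item shaped like a DialogRE record. The utterances and labels are invented for illustration, and `flatten_relations` is a hypothetical re-implementation, not part of the notebook.

```python
# A hypothetical toy item in the same shape as a DialogRE record: a list of
# utterances plus parallel lists x, y, r, t under 'relation_data'.
toy_item = {
    'dialog': ["Speaker 1: Hi, I'm Anna.", "Speaker 2: Nice to meet you, Anna."],
    'relation_data': {
        'x': ['Speaker 1', 'Speaker 1'],
        'y': ['Anna', 'Speaker 2'],
        'r': [['per:alternate_names'], ['unanswerable']],
        't': [[''], ['']],
    },
}


def flatten_relations(items, add_triggers=True):
    """Turn each (x, y) pair of every dialog into its own sample."""
    samples = []
    for item in items:
        all_dialog = ' '.join(item['dialog'])  # one string per dialog
        rd = item['relation_data']
        for x, y, r, t in zip(rd['x'], rd['y'], rd['r'], rd['t']):
            sample = {'dialog': all_dialog, 'x': x, 'y': y, 'relation': r}
            if add_triggers:
                sample['trigger'] = t
            samples.append(sample)
    return samples


flat = flatten_relations([toy_item], add_triggers=False)
print(len(flat))            # 2
print(flat[0]['relation'])  # ['per:alternate_names']
```

Each dialog with *n* annotated pairs therefore contributes *n* samples, all sharing the same joined dialog string.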

In [7]:
reformatted_dataset = {}
for split in dataset.keys():
    reformatted_dataset[split] = reformat_dataset(dataset[split], add_triggers=False)

In [8]:
print(reformatted_dataset['train'][0])

{'dialog': "Speaker 1: It's been an hour and not one of my classmates has shown up! I tell you, when I actually die some people are gonna get seriously haunted! Speaker 2: There you go! Someone came! Speaker 1: Ok, ok! I'm gonna go hide! Oh, this is so exciting, my first mourner! Speaker 3: Hi, glad you could come. Speaker 2: Please, come in. Speaker 4: Hi, you're Chandler Bing, right? I'm Tom Gordon, I was in your class. Speaker 2: Oh yes, yes... let me... take your coat. Speaker 4: Thanks... uh... I'm so sorry about Ross, it's... Speaker 2: At least he died doing what he loved... watching blimps. Speaker 1: Who is he? Speaker 2: Some guy, Tom Gordon. Speaker 1: I don't remember him, but then again I touched so many lives. Speaker 3: So, did you know Ross well? Speaker 4: Oh, actually I barely knew him. Yeah, I came because I heard Chandler's news. D'you know if he's seeing anyone? Speaker 3: Yes, he is. Me. Speaker 4: What? You... You... Oh! Can I ask you a personal question? Ho-how 

### Preprocess each Sample

In [9]:
import re
import torch
from transformers import AutoTokenizer

In [10]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, model_max_length=SEQ_LENGTH)  # padding/truncation are set when the tokenizer is called
SUBJ_TOKEN = '[SUBJ]'
OBJ_TOKEN = '[OBJ]'
special_tokens_dict = {'additional_special_tokens': [SUBJ_TOKEN, OBJ_TOKEN]}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

In [11]:
def truncate_dialog_with_entities(tokenizer, text):
    """
    Truncates the dialog text while ensuring that the subject and object tokens are included.

    Args:
        tokenizer: The tokenizer object.
        text: The full text: the dialog followed by '[SEP]'-separated entities (and triggers).

    Returns:
        A tuple containing:
          - The truncated text.
          - List of tokens in the truncated text.
    """

    # Split off the "[SEP] x [SEP] y ..." suffix, which is always kept whole
    parts = text.split('[SEP]')
    dialog = parts[0].strip()
    entities_and_triggers = '[SEP]' + '[SEP]'.join(parts[1:])

    tokens_full = tokenizer.tokenize(dialog)
    tokens_entities_and_triggers = tokenizer.tokenize(entities_and_triggers)

    e1_positions_full = [i for i, token in enumerate(tokens_full) if token == SUBJ_TOKEN]
    e2_positions_full = [i for i, token in enumerate(tokens_full) if token == OBJ_TOKEN]

    # Reserve 2 positions for the [CLS] and final [SEP] tokens added when the text is encoded
    available_length_for_dialog = SEQ_LENGTH - len(tokens_entities_and_triggers) - 2

    if len(tokens_full) <= available_length_for_dialog:
        return text, tokenizer.tokenize(text)  # No truncation needed

    # Span covering every marker found (or the whole dialog if none were found)
    entity_positions_full = e1_positions_full + e2_positions_full
    first_entity_index = min(entity_positions_full) if entity_positions_full else 0
    last_entity_index = max(entity_positions_full) if entity_positions_full else len(tokens_full)

    # Keep a quarter of the budget as context on each side of that span
    start_index = max(0, first_entity_index - (available_length_for_dialog // 4))
    end_index = min(len(tokens_full), last_entity_index + (available_length_for_dialog // 4))

    # If the window is still too long, keep its start and cut from the end
    if end_index - start_index > available_length_for_dialog:
        end_index = min(len(tokens_full), start_index + available_length_for_dialog)

    truncated_tokens_dialog = tokens_full[start_index:end_index]
    truncated_dialog = tokenizer.convert_tokens_to_string(truncated_tokens_dialog)

    truncated_text = f"{truncated_dialog} {entities_and_triggers.strip()}"

    return truncated_text, tokenizer.tokenize(truncated_text)
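The windowing arithmetic in `truncate_dialog_with_entities` can be checked without a tokenizer. The sketch below is a simplified, hypothetical `entity_window` helper that works on a plain list of tokens and a token budget; it centers on whichever markers are present and uses the same quarter-budget margin on each side.

```python
def entity_window(tokens, budget, markers=('[SUBJ]', '[OBJ]')):
    """Return (start, end) of a slice of `tokens` no longer than `budget`."""
    if len(tokens) <= budget:
        return 0, len(tokens)
    positions = [i for i, tok in enumerate(tokens) if tok in markers]
    first = min(positions) if positions else 0
    last = max(positions) if positions else len(tokens)
    margin = budget // 4                      # context kept on each side
    start = max(0, first - margin)
    end = min(len(tokens), last + margin)
    if end - start > budget:                  # still too long: cut from the end
        end = min(len(tokens), start + budget)
    return start, end


tokens = ['w'] * 100
tokens[60] = '[SUBJ]'
tokens[70] = '[OBJ]'
print(entity_window(tokens, budget=40))  # (50, 80)
```

Note that the window here covers only 30 of the 40 available tokens: because only a quarter of the budget is added on each side of the entity span, the window is not always filled up to the budget.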

In [12]:
def preprocess_sample(sample, max_positions=20):
    SEP_TOKEN = '[SEP]'
    dialog = sample['dialog']
    x = sample['x']
    y = sample['y']

    # Replace every mention of the subject/object with its marker token
    dialog1 = dialog.replace(x, SUBJ_TOKEN)
    dialog2 = dialog1.replace(y, OBJ_TOKEN)

    # Append the raw entity strings (and triggers, if present) after [SEP] separators
    text = f"{dialog2} {SEP_TOKEN} {x} {SEP_TOKEN} {y}"
    if 'trigger' in sample:
        trigger = ', '.join(sample['trigger'])
        text += f" {SEP_TOKEN} {trigger}"

    truncated_text, _ = truncate_dialog_with_entities(tokenizer, text)

    # Tokenize the (potentially truncated) text
    tokens = tokenizer(truncated_text, padding="max_length", truncation=True, max_length=SEQ_LENGTH, return_tensors="pt")

    # Find entity positions within the (potentially truncated) text, keeping at most max_positions
    words = tokenizer.convert_ids_to_tokens(tokens.input_ids.squeeze())
    e1_positions = [i for i, tok in enumerate(words) if tok == SUBJ_TOKEN][:max_positions]
    e2_positions = [i for i, tok in enumerate(words) if tok == OBJ_TOKEN][:max_positions]

    # Pad with -1 to a fixed length so samples can be stacked into a batch
    e1_positions += [-1] * (max_positions - len(e1_positions))
    e2_positions += [-1] * (max_positions - len(e2_positions))

    return {
        'tokens': tokens,
        'e1_positions': torch.tensor(e1_positions),
        'e2_positions': torch.tensor(e2_positions)
    }
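The fixed-length position lists matter because PyTorch's default `DataLoader` collation stacks each field across a batch, which requires every sample's tensor to have the same shape. A minimal sketch of the extract, cut, and pad step, using a hypothetical `marker_positions` helper on plain token lists. Note that `[-1] * n` is an empty list for negative `n`, so without the cut a sample with more than `max_positions` mentions would yield a longer list.

```python
def marker_positions(tokens, marker, max_positions=20, pad=-1):
    """Indices of `marker` in `tokens`, cut/padded to exactly `max_positions`."""
    positions = [i for i, tok in enumerate(tokens) if tok == marker]
    positions = positions[:max_positions]          # never exceed the fixed size
    return positions + [pad] * (max_positions - len(positions))


toks = ['[CLS]', '[SUBJ]', 'said', 'hi', 'to', '[OBJ]', '[SUBJ]', '[SEP]']
print(marker_positions(toks, '[SUBJ]', max_positions=4))  # [1, 6, -1, -1]
print(marker_positions(toks, '[OBJ]', max_positions=1))   # [5]
```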

In [13]:
preprocess_sample(reformatted_dataset['train'][0])

{'tokens': {'input_ids': tensor([[  101,  5882,  1015,  1024,  2009,  1005,  1055,  2042,  2019,  3178,
           1998,  2025,  2028,  1997,  2026, 19846,  2038,  3491,  2039,   999,
           1045,  2425,  2017,  1010,  2043,  1045,  2941,  3280,  2070,  2111,
           2024,  6069,  2131,  5667, 11171,   999, 30522,  1024,  2045,  2017,
           2175,   999,  2619,  2234,   999,  5882,  1015,  1024,  7929,  1010,
           7929,   999,  1045,  1005,  1049,  6069,  2175,  5342,   999,  2821,
           1010,  2023,  2003,  2061, 10990,  1010,  2026,  2034,  9587, 21737,
           2099,   999,  5882,  1017,  1024,  7632,  1010,  5580,  2017,  2071,
           2272,  1012, 30522,  1024,  3531,  1010,  2272,  1999,  1012,  5882,
           1018,  1024,  7632,  1010,  2017,  1005,  2128, 30523,  1010,  2157,
           1029,  1045,  1005,  1049,  3419,  5146,  1010,  1045,  2001,  1999,
           2115,  2465,  1012, 30522,  1024,  2821,  2748,  1010,  2748,  1012,
           1012,
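One caveat with marking entities via `str.replace`: it matches substrings, so if one entity string is a prefix of a longer mention (a hypothetical example being `Speaker 1` inside `Speaker 10`), the longer mention is partially overwritten. The sketch below contrasts it with a regular expression that only matches whole mentions; the `mark_entity` helper and its lookaround pattern are an illustrative alternative, not what the notebook uses.

```python
import re


def mark_entity(text, entity, marker):
    """Replace whole-word mentions of `entity` with `marker`."""
    # (?<!\w) / (?!\w): the mention must not be glued to other word characters
    pattern = r"(?<!\w)" + re.escape(entity) + r"(?!\w)"
    return re.sub(pattern, marker, text)


dialog = "Speaker 1: Hi. Speaker 10: Hello, Speaker 1."
print(dialog.replace("Speaker 1", "[SUBJ]"))
# [SUBJ]: Hi. [SUBJ]0: Hello, [SUBJ].
print(mark_entity(dialog, "Speaker 1", "[SUBJ]"))
# [SUBJ]: Hi. Speaker 10: Hello, [SUBJ].
```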

### Convert the Relation Label to Multi-hot Vector

Collect every relation label that appears in the training split:

In [14]:
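def get_labels(dataset):
    all_dataset_labels = set()
    for datapoint in dataset['train']:
        # 'r' is a list of label lists, one per (x, y) pair
        for relation in [item for sublist in datapoint['relation_data']['r'] for item in sublist]:
            all_dataset_labels.add(relation)
    # Sort so the label order (and thus the model's output indices) is the same in every session
    return sorted(all_dataset_labels)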
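`MultiLabelBinarizer` turns each sample's list of relation strings into a multi-hot vector with one slot per class. A plain-Python sketch of the same encoding (the `multi_hot` helper and the three labels are illustrative only) also shows why a fixed class order matters: output index *i* of the classifier only means `classes[i]` if that list is built in the same order in every session.

```python
def multi_hot(labels, classes):
    """Encode a list of label strings as a 0/1 vector over `classes`."""
    index = {label: i for i, label in enumerate(classes)}
    vec = [0] * len(classes)
    for label in labels:
        vec[index[label]] = 1
    return vec


# Sorting gives the same class order in every session; iterating over a set of
# strings does not, because str hashing is randomized per process by default.
classes = sorted({'per:friends', 'per:boss', 'unanswerable'})
print(classes)                                              # ['per:boss', 'per:friends', 'unanswerable']
print(multi_hot(['per:friends', 'unanswerable'], classes))  # [0, 1, 1]
```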

## Create PyTorch Dataset

In [15]:
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import MultiLabelBinarizer

In [16]:
# Initialize the binarizer
class_labels = get_labels(dataset)
mlb = MultiLabelBinarizer(classes=class_labels)
# Fit the binarizer to the training labels
mlb.fit([item['relation'] for item in reformatted_dataset['train']])

In [17]:
# Create PyTorch Dataset
class DialogREDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        processed_sample = preprocess_sample(item)
        label = torch.tensor(mlb.transform([item['relation']])[0], dtype=torch.float32)  # reuse the already-fitted binarizer
        return {
            'input_ids': processed_sample['tokens'].input_ids.squeeze(0),
            'attention_mask': processed_sample['tokens'].attention_mask.squeeze(0),
            'e1_positions': processed_sample['e1_positions'],
            'e2_positions': processed_sample['e2_positions'],
            'labels': label
        }

In [18]:
train_dataset = DialogREDataset(reformatted_dataset['train'], tokenizer)
test_dataset = DialogREDataset(reformatted_dataset['test'], tokenizer)
validation_dataset = DialogREDataset(reformatted_dataset['validation'], tokenizer)

## Create BERT Model for multi-label classification task

In [19]:
import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaConfig, BertModel # Using BertModel to fetch the base model without the extra head

In [20]:
class MultiLabelBERT(nn.Module):
    def __init__(self, model_name):
        super(MultiLabelBERT, self).__init__()
        if model_name == "FacebookAI/roberta-base":
            # Load pretrained weights; RobertaModel(config) alone would be randomly initialized
            self.bert = RobertaModel.from_pretrained(model_name)
        else:
            self.bert = BertModel.from_pretrained(model_name)

        self.loss_fn = nn.BCEWithLogitsLoss()  # Multi-label loss function
        input_hidden_size = self.bert.config.hidden_size

        # Resize the token embeddings to accommodate the new tokens
        self.bert.resize_token_embeddings(len(tokenizer))

        # Add our own classification head for predicting multiple labels
        self.dropout = nn.Dropout(0.1)  # Dropout layer to reduce overfitting
        self.classifier = nn.Linear(input_hidden_size * 3, len(class_labels)) # x3 hidden size as we concat [CLS], e1 and e2 embeddings

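    def avg_embedding(self, bert_output, entity_positions):
        # entity_positions: (batch, max_positions), padded with -1
        if entity_positions.dim() == 1:
            entity_positions = entity_positions.unsqueeze(0)
        mask = (entity_positions >= 0).unsqueeze(-1).to(bert_output.dtype)  # (batch, max_positions, 1)
        index = entity_positions.clamp(min=0).unsqueeze(-1).expand(-1, -1, bert_output.size(-1))
        # Gather each sequence's own entity embeddings: (batch, max_positions, hidden)
        entity_embeddings = torch.gather(bert_output, 1, index)
        # Mean over real occurrences only; clamp avoids division by zero if none survive truncation
        count = mask.sum(dim=1).clamp(min=1)
        return (entity_embeddings * mask).sum(dim=1) / count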

    def forward(self, input_ids, attention_mask, e1_positions, e2_positions, labels=None):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Extract entity representations and apply averaging for multiple occurrences
        e1_embedding = self.avg_embedding(bert_output, e1_positions)
        e2_embedding = self.avg_embedding(bert_output, e2_positions)

        # Concatenate entity embeddings and [CLS] token (provides information about the dialog)
        cls_embedding = bert_output[:, 0, :]
        combined_embedding = torch.cat([cls_embedding, e1_embedding, e2_embedding], dim=-1)

        # Feed into the fully connected layer
        logits = self.classifier(combined_embedding)

        # Return dict with loss and logits (loss=None for inferencing)
        if labels is not None:  # Compute loss during training as labels are provided
            loss = self.loss_fn(logits, labels.float())
            return {"loss": loss, "logits": logits}
        return {"loss": None, "logits": logits}
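The classifier input is the concatenation of three vectors, hence `input_hidden_size * 3`. The sketch below is a hypothetical plain-Python `masked_mean` that shows, for a single sequence, the intended entity pooling: average the hidden states at the marker positions while skipping the `-1` padding. In a batch this has to be done per sequence, so that one sample's positions never select another sample's tokens. Returning a zero vector when no marker survives truncation is an assumption of this sketch.

```python
def masked_mean(hidden, positions):
    """Average hidden[p] over positions p >= 0 (-1 marks padding)."""
    dim = len(hidden[0])
    valid = [p for p in positions if p >= 0]
    if not valid:                                    # entity truncated away
        return [0.0] * dim
    return [sum(hidden[p][d] for p in valid) / len(valid) for d in range(dim)]


# Toy "last_hidden_state" for one sequence: 4 tokens, hidden size 2.
hidden = [[1.0, 1.0], [2.0, 4.0], [0.0, 0.0], [4.0, 8.0]]
cls = hidden[0]                            # [CLS] is always at position 0
e1 = masked_mean(hidden, [1, 3, -1, -1])   # [3.0, 6.0]
e2 = masked_mean(hidden, [2, -1, -1, -1])  # [0.0, 0.0]
combined = cls + e1 + e2                   # list concatenation, length 3 * 2
print(combined)  # [1.0, 1.0, 3.0, 6.0, 0.0, 0.0]
```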

## Train and save the model

In [21]:
from transformers import Trainer, TrainingArguments
#from google.colab import drive
import numpy as np

#drive.mount('/content/drive')
OUTPUT_DIR = 'C:/Users/tonyl/Documents/dialogRE/model'
LOGS_DIR = 'C:/Users/tonyl/Documents/dialogRE/model/logs'

# Define model
model = MultiLabelBERT(MODEL_NAME)

training_args = TrainingArguments(
    optim="adamw_torch",
    output_dir=OUTPUT_DIR,           # output directory
    num_train_epochs=5,              # total number of training epochs
    per_device_train_batch_size=8,   # batch size per device during training
    per_device_eval_batch_size=8,    # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir=LOGS_DIR,            # directory for storing logs
    eval_strategy="epoch",           # evaluate on the validation set after every epoch
    save_strategy="no",              # do not save checkpoints
    learning_rate=5e-5,
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=validation_dataset      # evaluation dataset
)

trainer.train()
trainer.save_model(OUTPUT_DIR)

| Epoch | Training Loss | Validation Loss |
|------:|--------------:|----------------:|
| 1 | 0.1754 | 0.090762 |
| 2 | 0.0672 | 0.061486 |
| 3 | 0.0500 | 0.056491 |
| 4 | 0.0345 | 0.055599 |
| 5 | 0.0263 | 0.055426 |


Token indices sequence length is longer than the specified maximum sequence length for this model (542 > 512). Running this sequence through the model will result in indexing errors
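With `warmup_steps=500`, the Trainer's default `linear` scheduler ramps the learning rate from 0 up to `learning_rate=5e-5` over the first 500 optimizer steps, then decays it linearly to 0 at the final step. A minimal sketch of that schedule follows; `total_steps=2000` is a hypothetical placeholder, since the real value is ceil(number of training samples / 8) × 5 on a single device without gradient accumulation.

```python
def linear_warmup_lr(step, peak_lr=5e-5, warmup_steps=500, total_steps=2000):
    """Learning rate at `step` for linear warm-up followed by linear decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


# Rises to the peak at step 500, halves by step 1250, reaches 0 at the last step.
for s in (0, 250, 500, 1250, 2000):
    print(s, linear_warmup_lr(s))
```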


## Inference

In [33]:
from safetensors.torch import load_model  # The Trainer saved the weights in .safetensors format

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
saved_model = MultiLabelBERT(MODEL_NAME).to(device)
load_model(saved_model, OUTPUT_DIR + '/model.safetensors')  # Load the trained weights
saved_model.eval()  # Disable dropout for deterministic predictions

def predict_relations(sample_data, model, tokenizer, threshold=0.5):
    processed_sample = preprocess_sample(sample_data)
    input_ids = processed_sample['tokens'].input_ids.to(device)
    attention_mask = processed_sample['tokens'].attention_mask.to(device)
    # Add a batch dimension so the positions match the (1, SEQ_LENGTH) shape of input_ids
    e1_positions = processed_sample['e1_positions'].unsqueeze(0).to(device)
    e2_positions = processed_sample['e2_positions'].unsqueeze(0).to(device)

    with torch.no_grad():
        output_dict = model(input_ids, attention_mask, e1_positions, e2_positions)

    # Apply sigmoid to the logits
    relation_probabilities = torch.sigmoid(output_dict['logits'])

    # Apply threshold to get predicted relations
    relation_indices = np.where(relation_probabilities.cpu().numpy() > threshold)[1]

    # Return list of relation indices
    return relation_indices

sample_data = reformatted_dataset['test'][126]
for key in sample_data:
    print(key, sample_data[key])

sample_label = torch.tensor(mlb.transform([sample_data['relation']])[0], dtype=torch.float32)
relation_indices = predict_relations(sample_data, saved_model, tokenizer, threshold=0.5)
predicted_relations = [class_labels[i] for i in relation_indices]
print(predicted_relations)

dialog Speaker 1: Presenting the award for Favorite Supporting Actress is Joey Tribbiani from Days of Our Lives. Speaker 2: Any one of the brilliant actresses nominated for this award tonight deserves to take it home. Unfortunately only one can. The nominees for Best Supporting Actress are from Passions Erin Goff. From One Life to Live Mary Loren Bishop, from All My Children Sarah Mchann, and from Days of Our Lives Jessica Ashley. And the winner is……Jessica Ashley from Days of Our Lives. Uh, unfortunately Jessica couldn’t be with us tonight so I’ll be accepting this award on her behalf. And I’m sure that Jessica would like to thank my parents who always believed in me. She’d also like to thank my friends, Chandler, Monica, Ross, Phoebe, and Rachel who’s sittin’ right there!
x Monica
y Speaker 2
relation ['per:friends']
['per:friends']
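`predict_relations` keeps every class whose sigmoid probability exceeds the threshold, so a sample can receive zero, one, or several relations. A plain-Python sketch of that decoding step, with a hypothetical `decode_relations` helper and made-up logits: since sigmoid(0) = 0.5, a threshold of 0.5 on probabilities is the same as keeping logits strictly greater than 0.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def decode_relations(logits, classes, threshold=0.5):
    """Map one row of logits to the class names whose sigmoid exceeds threshold."""
    return [label for label, z in zip(classes, logits) if sigmoid(z) > threshold]


classes = ['per:boss', 'per:friends', 'unanswerable']
print(decode_relations([-3.0, 2.0, 0.0], classes))  # ['per:friends']
```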
