<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/LUKE/Supervised_relation_extraction_with_LukeForEntityPairClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, we are going to fine-tune [`LukeForEntityPairClassification`](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) on a supervised **relation extraction** dataset.

The goal for the model is to predict, given a sentence and the character spans of two entities within the sentence, the relationship between the entities.

The authors of LUKE fine-tuned this model on the [TACRED](https://nlp.stanford.edu/projects/tacred/) dataset, an important supervised relation extraction dataset from Stanford University, and reported state-of-the-art results on it.

* Paper: https://arxiv.org/abs/2010.01057
* Original repository: https://github.com/studio-ousia/luke

In [1]:
!pip install -q transformers 

In [2]:
!pip install -q pytorch-lightning wandb

## Read in data

Let's download the data from the web, hosted on Dropbox.

In [3]:
import requests, zipfile, io

def download_data():
    url = "https://www.dropbox.com/s/izi2x4sjohpzoot/relation_extraction_dataset.zip?dl=1"
    r = requests.get(url)
    r.raise_for_status()  # fail early if the download did not succeed
    z = zipfile.ZipFile(io.BytesIO(r.content))
    z.extractall()

download_data()

Each row in the dataframe consists of a news article and a sentence from it in which a certain relationship was found (such as "invested_in" or "founded_by"). The data was gathered using patterns, so it might contain some noise.

In [5]:
import pandas as pd

df = pd.read_pickle("relation_extraction_dataset.pkl")
df.reset_index(drop=True, inplace=True)
df.head()

Unnamed: 0,end_idx,entities,entity_spans,match,original_article,sentence,start_idx,string_id
0,1024,"[Lilium, Baillie Gifford]","[[3, 9], [151, 166]]",raising $35,Happy Friday!\n\nWe sincerely hope you and you...,"3) Lilium, a German startup that’s making an a...",1013,invested_in
1,1762,"[Facebook ’s, Giphy]","[[92, 102], [148, 153]]",acquisition,Happy Friday!\n\nWe sincerely hope you and you...,"Meanwhile, the UK’s watchdog on Friday announc...",1751,acquired_by
2,2784,"[Global-e, Vitruvian Partners]","[[27, 35], [94, 112]]",raised $60,Happy Friday!\n\nWe sincerely hope you and you...,Israeli e-commerce startup Global-e has raised...,2774,invested_in
3,680,"[Joris Van Der Gucht, Silverfin]","[[0, 19], [35, 44]]",founder,Hg is a leading investor in tax and accounting...,"Joris Van Der Gucht, co-founder at Silverfin c...",673,founded_by
4,2070,"[Tim Vandecasteele, Silverfin]","[[0, 17], [71, 80]]",founder,Hg is a leading investor in tax and accounting...,"Tim Vandecasteele, co-founder added: ""We want ...",2063,founded_by
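The `entity_spans` column holds character offsets into `sentence`, stored as `[start, end]` pairs with an exclusive end. The sketch below shows how such spans map back to the entity strings. It uses a shortened, illustrative version of row 3 (the full sentence is cut off in the output above), and `extract_entities` is a hypothetical helper, not part of the dataset or of Transformers.

```python
def extract_entities(sentence, entity_spans):
    """Return the substring covered by each (start, end) character span."""
    return [sentence[start:end] for start, end in entity_spans]


# Shortened, illustrative sentence; the offsets match row 3 of the dataframe.
sentence = "Joris Van Der Gucht, co-founder at Silverfin"
entity_spans = [[0, 19], [35, 44]]

entities = extract_entities(sentence, entity_spans)
print(entities)
```

Checking a few rows this way is a cheap sanity test before handing the spans to the tokenizer.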


Let's create 2 dictionaries, one that maps each label to a unique integer, and one that does it the other way around.

In [None]:
id2label = dict()
for idx, label in enumerate(df.string_id.value_counts().index):
  id2label[idx] = label

As we can see, there are 7 labels (7 unique relationships):

In [None]:
id2label

{0: 'founded_by',
 1: 'acquired_by',
 2: 'invested_in',
 3: 'CEO_of',
 4: 'subsidiary_of',
 5: 'partners_with',
 6: 'owned_by'}

In [None]:
label2id = {v:k for k,v in id2label.items()}
label2id

{'CEO_of': 3,
 'acquired_by': 1,
 'founded_by': 0,
 'invested_in': 2,
 'owned_by': 6,
 'partners_with': 5,
 'subsidiary_of': 4}
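Because the ids come from `value_counts()`, the most frequent label gets id 0, the next one id 1, and so on. The following sketch reproduces that ordering without pandas, on a made-up list of labels that stands in for the `string_id` column.

```python
from collections import Counter

# Hypothetical toy labels standing in for the `string_id` column.
labels = [
    "founded_by", "invested_in", "founded_by",
    "acquired_by", "founded_by", "invested_in",
]

# Most frequent label first, mirroring the descending order of value_counts().
ordered = [label for label, _ in Counter(labels).most_common()]

id2label = dict(enumerate(ordered))
label2id = {label: idx for idx, label in id2label.items()}

print(id2label)
print(label2id)
```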

In [None]:
df.shape

(12031, 8)

## Define the PyTorch dataset and dataloaders

Next, we define regular PyTorch datasets and corresponding dataloaders. In PyTorch, you need to define a `Dataset` class that inherits from `torch.utils.data.Dataset`, and you need to implement 3 methods: the `__init__` method (for initializing the dataset with data), the `__len__` method (which returns the number of elements in the dataset) and the `__getitem__` method, which returns a single item from the dataset.
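To make the three-method protocol concrete, here is a minimal sketch that does not depend on PyTorch. `ToyDataset` and its rows are made up for illustration; a real dataset would inherit from `torch.utils.data.Dataset` and return tensors.

```python
class ToyDataset:
    """Implements the same three methods a PyTorch Dataset needs."""

    def __init__(self, rows):
        # rows: list of (sentence, label) tuples
        self.rows = rows

    def __len__(self):
        # number of items in the dataset
        return len(self.rows)

    def __getitem__(self, idx):
        # return a single item as a dictionary
        sentence, label = self.rows[idx]
        return {"sentence": sentence, "label": label}


toy = ToyDataset([
    ("A invested in B", "invested_in"),
    ("C founded D", "founded_by"),
])
print(len(toy), toy[1])
```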

In our case, each item of the dataset consists of a sentence, the spans of 2 entities in the sentence, and a label of the relationship. We use `LukeTokenizer` (available in the Transformers library) to turn these into the inputs expected by the model, which are `input_ids`, `entity_ids`, `attention_mask`, `entity_attention_mask` and `entity_position_ids`.

For more information regarding these inputs, refer to the [docs](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) of `LukeForEntityPairClassification`.


In [None]:
from transformers import LukeTokenizer
from torch.utils.data import Dataset, DataLoader
import torch

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification")

class RelationExtractionDataset(Dataset):
    """Relation extraction dataset."""

    def __init__(self, data):
        """
        Args:
            data : Pandas dataframe.
        """
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]

        sentence = item.sentence
        entity_spans = [tuple(x) for x in item.entity_spans]

        encoding = tokenizer(sentence, entity_spans=entity_spans, padding="max_length", truncation=True, return_tensors="pt")

        for k,v in encoding.items():
          encoding[k] = encoding[k].squeeze()

        encoding["label"] = torch.tensor(label2id[item.string_id])

        return encoding

Here we split the dataframe and instantiate the class defined above three times: once for the training set, once for the validation set and once for the test set.

In [None]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, shuffle=True)
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42, shuffle=False)

# define the dataset
train_dataset = RelationExtractionDataset(data=train_df)
valid_dataset = RelationExtractionDataset(data=val_df)
test_dataset = RelationExtractionDataset(data=test_df)
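Applying a 20% split twice leaves roughly 64% of the rows for training, 16% for validation and 20% for testing. The sketch below works out the sizes for the 12,031 rows of this dataframe, assuming the held-out part of each split is rounded up. `split_sizes` is a hypothetical helper, not a scikit-learn function.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_kept, n_held_out), rounding the held-out part up (assumption)."""
    n_held_out = math.ceil(test_size * n_samples)
    return n_samples - n_held_out, n_held_out


n_rows = 12031
n_train_full, n_test = split_sizes(n_rows, 0.2)   # first split: train+val vs. test
n_train, n_val = split_sizes(n_train_full, 0.2)   # second split: train vs. val

print(n_train, n_val, n_test)
```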

In [None]:
train_dataset[0].keys()

dict_keys(['input_ids', 'entity_ids', 'entity_position_ids', 'attention_mask', 'entity_attention_mask', 'label'])

Let's define the corresponding dataloaders (which allow us to iterate over the elements of the dataset):

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=2)
test_dataloader = DataLoader(test_dataset, batch_size=2)

Let's verify an example of a batch:

In [None]:
batch = next(iter(train_dataloader))
tokenizer.decode(batch["input_ids"][1])

'<s>The <ent>  Data Security Council of India <ent>, in partnership with <ent2>  PayPal <ent2>, said in an August report that the number of people shopping online in India grew at 73 per cent for big cities and at 400 per cent in smaller towns.\n\n</s><pad><pad><pad> (remaining `<pad>` tokens omitted; the output is truncated)

In [None]:
id2label[batch["label"][1].item()]

'partners_with'
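The decoded example shows that the tokenizer wraps the first entity in `<ent>` markers and the second in `<ent2>` markers. The string-level sketch below mimics that effect. It is only an illustration: `mark_entities` is a hypothetical helper, and the real `LukeTokenizer` operates on token ids and handles whitespace differently.

```python
def mark_entities(sentence, entity_spans, markers=("<ent>", "<ent2>")):
    """Wrap each character span in its marker (illustration only)."""
    inserts = []
    for (start, end), marker in zip(entity_spans, markers):
        inserts.append((start, marker))
        inserts.append((end, marker))
    # Insert from right to left so that earlier offsets stay valid.
    for position, marker in sorted(inserts, reverse=True):
        sentence = sentence[:position] + marker + sentence[position:]
    return sentence


# Shortened, illustrative sentence with character-based entity spans.
marked = mark_entities(
    "Joris Van Der Gucht, co-founder at Silverfin",
    [(0, 19), (35, 44)],
)
print(marked)
```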

## Define a PyTorch LightningModule

Let's define the model as a PyTorch LightningModule. A `LightningModule` is actually an `nn.Module`, but with some extra functionality.

For more information regarding how to define this, see the [docs](https://pytorch-lightning.readthedocs.io/en/latest/) of PyTorch Lightning.

In [None]:
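from transformers import LukeForEntityPairClassification
# the AdamW class shipped with transformers is deprecated in favor of torch.optim.AdamW
# (note that the default weight_decay differs between the two)
from torch.optim import AdamW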
import pytorch_lightning as pl

class LUKE(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-base", num_labels=len(label2id))

    def forward(self, input_ids, entity_ids, entity_position_ids, attention_mask, entity_attention_mask):     
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, entity_ids=entity_ids, 
                             entity_attention_mask=entity_attention_mask, entity_position_ids=entity_position_ids)
        return outputs
    
    def common_step(self, batch, batch_idx):
        labels = batch['label']
        del batch['label']
        outputs = self(**batch)
        logits = outputs.logits

        criterion = torch.nn.CrossEntropyLoss() # multi-class classification
        loss = criterion(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct/batch['input_ids'].shape[0]

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # log the loss and accuracy of every training step
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     

        return loss

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=5e-5)
        return optimizer

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return valid_dataloader

    def test_dataloader(self):
        return test_dataloader
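The accuracy in `common_step` is the fraction of examples in the batch whose highest-scoring logit matches the label. Here is a dependency-free sketch of that computation on made-up logits (`batch_accuracy` is a hypothetical helper that mirrors `logits.argmax(-1)` followed by the comparison with the labels):

```python
def batch_accuracy(logits, labels):
    """Fraction of rows whose arg-max logit equals the label."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(pred == label) for pred, label in zip(predictions, labels))
    return correct / len(labels)


# Made-up logits for a batch of 4 examples and 3 classes.
logits = [
    [0.1, 2.0, -1.0],
    [1.5, 0.2, 0.3],
    [0.0, 0.1, 3.0],
    [0.9, 0.8, 0.7],
]
labels = [1, 0, 2, 1]

accuracy = batch_accuracy(logits, labels)
print(accuracy)
```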

Let's verify a forward pass on a batch:

In [None]:
batch = next(iter(valid_dataloader))
labels = batch["label"]
batch.keys()

dict_keys(['input_ids', 'entity_ids', 'entity_position_ids', 'attention_mask', 'entity_attention_mask', 'label'])

In [None]:
batch["input_ids"].shape

torch.Size([2, 512])

In [None]:
model = LUKE()
del batch["label"]
outputs = model(**batch)

Some weights of the model checkpoint at studio-ousia/luke-base were not used when initializing LukeForEntityPairClassification: ['embeddings.position_ids']
- This IS expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LukeForEntityPairClassification were not initialized from the model checkpoint at studio-ousia/luke-base and are newly initialized: ['classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


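The classification head is randomly initialized, so the model's predictions start out close to uniform over the classes. The initial loss should therefore be around -ln(1/number of classes) = -ln(1/7) ≈ 1.95: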

In [None]:
criterion = torch.nn.CrossEntropyLoss()

initial_loss = criterion(outputs.logits, labels)
print("Initial loss:", initial_loss)

Initial loss: tensor(1.9616, grad_fn=<NllLossBackward>)
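The expected starting value can be checked by hand: with identical logits for all 7 classes, softmax assigns probability 1/7 to each class, and the cross-entropy is ln(7). A small sketch in plain Python (`cross_entropy` is written here for illustration and handles a single example):

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    max_logit = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[label]


num_classes = 7
uniform_loss = cross_entropy([0.0] * num_classes, label=3)
print(round(uniform_loss, 4))
```

A measured initial loss far from this value usually points at a problem with the labels or the classification head.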


## Train the model

Let's train the model. We also use early stopping, to avoid overfitting the training dataset. We also log everything to Weights and Biases, which will give us beautiful charts of the loss and accuracy plotted over time.
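Early stopping with a `patience` setting can be summarized as: stop once the monitored validation loss has failed to improve for `patience` consecutive checks. The sketch below is a simplified illustration of that rule on made-up loss values. It is not PyTorch Lightning's implementation, and `early_stopping_epoch` is a hypothetical helper.

```python
def early_stopping_epoch(losses, patience=2, min_delta=0.0):
    """Return the epoch index at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            best = loss   # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1     # no improvement during this validation check
            if wait >= patience:
                return epoch
    return None


# Made-up validation losses, one per epoch.
stop_epoch = early_stopping_epoch([1.0, 0.8, 0.9, 0.85, 0.7], patience=2)
print(stop_epoch)
```

In this toy run the loss at epoch 4 would have been a new best, which shows the trade-off: a small `patience` saves compute but can stop before a late improvement.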

If you haven't already, you can create an account on the [website](https://wandb.ai/site), then log in in a web browser, and run the cell below: 

In [None]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mnielsrogge[0m (use `wandb login --relogin` to force relogin)


True

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping

wandb_logger = WandbLogger(name='luke-first-run-12000-articles-bis', project='LUKE')
# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
# the monitored name must match the metric logged in validation_step ("validation_loss")
early_stop_callback = EarlyStopping(
    monitor='validation_loss',
    patience=2,
    strict=False,
    verbose=False,
    mode='min'
)

# pass the callback defined above instead of creating a second, separate one
trainer = Trainer(gpus=1, logger=wandb_logger, callbacks=[early_stop_callback])
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name  | Type                            | Params
----------------------------------------------------------
0 | model | LukeForEntityPairClassification | 274 M 
----------------------------------------------------------
274 M     Trainable params
0         Non-trainable params
274 M     Total params
1,098.052 Total estimated model params size (MB)


In [None]:
trainer.test()

## Evaluation

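Instead of calling `trainer.test()`, we can also manually evaluate the model on the entire test set. The cell below uses `loaded_model`, the checkpoint restored with `LUKE.load_from_checkpoint` in the Inference section, so that cell has to be run first: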

In [None]:
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

loaded_model.model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_model.to(device)

predictions_total = []
labels_total = []
for batch in tqdm(test_dataloader):
    # get the inputs;
    labels = batch["label"]
    del batch["label"]

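    # move everything to the selected device (GPU if available)
    for k,v in batch.items():
      batch[k] = batch[k].to(device)

    # forward pass, without tracking gradients since we are only evaluating
    with torch.no_grad():
        outputs = loaded_model.model(**batch)
    logits = outputs.logits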
    predictions = logits.argmax(-1)
    predictions_total.extend(predictions.tolist())
    labels_total.extend(labels.tolist())


In [None]:
print("Accuracy on test set:", accuracy_score(labels_total, predictions_total))

Accuracy on test set: 0.9713336103032821


## Inference

Here we test the trained model on a new, unseen sentence.

In [None]:
loaded_model = LUKE.load_from_checkpoint(checkpoint_path="/content/drive/Shareddrives/Datascouts/epoch=3-step=7699.ckpt")


In [None]:
test_df.iloc[0].sentence

'Roxanne also co-founded StartHer'

In [None]:
import torch.nn.functional as F

idx = 2
text = test_df.iloc[idx].sentence
entity_spans = test_df.iloc[idx].entity_spans  # character-based entity spans 
entity_spans = [tuple(x) for x in entity_spans]

inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")

outputs = loaded_model.model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Sentence:", text)
print("Ground truth label:", test_df.iloc[idx].string_id)
print("Predicted label:", id2label[predicted_class_idx])
print("Confidence:", F.softmax(logits, -1).max().item())

Sentence: One Twitter user suggested that Sun ask Buffett if he regrets not having invested early on in Apple Microsoft Amazon Google Netflix and
Ground truth label: invested_in
Predicted label: invested_in
Confidence: 0.9971879124641418
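The confidence printed above is the largest softmax probability over the 7 classes. A plain-Python sketch of that last step, using made-up logits rather than real model outputs (`softmax` is written here for illustration and plays the role of `F.softmax(logits, -1)`):

```python
import math


def softmax(logits):
    """Turn a list of logits into probabilities that sum to 1."""
    max_logit = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for the 7 relation classes.
logits = [0.5, -1.2, 6.1, 0.3, -0.7, 0.0, 1.1]
probabilities = softmax(logits)

predicted_class_idx = max(range(len(probabilities)), key=probabilities.__getitem__)
confidence = probabilities[predicted_class_idx]
print(predicted_class_idx, confidence)
```

A high softmax value only means the model strongly prefers one class over the others. It is not a calibrated probability of being correct.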
