In [None]:
from pathlib import Path
import nltk
# nltk.data.path.remove("/Users/ams557-macos/nltk_data")
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import nltk.tokenize.punkt as pkt
import pandas as pd
import pickle as pkl
from collections import defaultdict, Counter
from itertools import permutations, combinations
from functools import reduce
import numpy as np
import os,sys, io
from io import FileIO
import fnmatch
import re, string
import csv
from utils.helpers import *
from tqdm import tqdm
from transformers import LukeTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from transformers import LukeForEntityPairClassification, AdamW
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping
import pytorch_lightning as pl
import wandb
# Authenticate with Weights & Biases from the WANDB_API_KEY environment variable
# instead of embedding a secret in the notebook. With key=None, wandb.login falls
# back to its own credential lookup / interactive prompt.
# TODO(security): rotate any W&B API key that was ever committed in this notebook.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Seed shared by every train/validation/test split in this notebook.
RANDOM_STATE = 42

# Data locations (string paths with trailing slashes; later cells concatenate onto them).
DATA_DIR = '../DATA/'
TRAIN_DIR = f'{DATA_DIR}train/training_20180910/'
TEST_DIR = f'{DATA_DIR}test/test_data_Task2/'

# Make sure all three directories exist (parents first, so creation order matters).
for _directory in (DATA_DIR, TRAIN_DIR, TEST_DIR):
    Path(_directory).mkdir(parents=True, exist_ok=True)
del _directory

In [None]:
class InvalidAnnotationError(ValueError):
    """Signals a line in a BRAT ``.ann`` file that does not follow the expected format."""

def BRATtoDFconvert(path):
    """Load every BRAT ``.ann`` file in ``path`` into a single DataFrame.

    Files are processed in sorted filename order so the output row order is
    reproducible regardless of what ``os.listdir`` returns.

    Args:
        path: directory holding the ``.ann`` files and, for each one, the
            matching ``NAME.txt`` article text.

    Returns:
        If at least one relation line was parsed: one row per relation whose two
        argument entities could both be resolved (inner joins), with columns
        ``end_idx``, ``entities``, ``entity_spans``, ``match``,
        ``original_article``, ``sentences``, ``start_idx`` and ``string_id``
        (all object dtype). Otherwise: the concatenated entity table.
    """
    ann_files = sorted(f for f in os.listdir(path) if f.endswith('.ann'))

    entity_frames, relation_frames = [], []
    for ann_file in ann_files:
        # One parse per file yields both its entities and its relations.
        parsed = process_annotation(os.path.join(path, ann_file))
        entity_frames.append(parsed['entities'])
        relation_frames.append(parsed['relations'])

    # Concatenate once at the end instead of growing the frames inside the loop.
    entities = pd.concat(entity_frames, ignore_index=True) if entity_frames else pd.DataFrame()
    relations = pd.concat(relation_frames, ignore_index=True) if relation_frames else pd.DataFrame()

    if relations.empty:
        return entities
    return _build_relation_frame(path, relations, entities)


def _build_relation_frame(path, relations, entities):
    """Resolve each relation's argument tags to entity spans and derive the model-input columns."""
    entity_lookup = entities[['file', 'tag', 'entity_span', 'entity']]

    # Arg1: relation_start holds an entity tag; replace it by that entity's span and text.
    df = pd.merge(relations.drop(columns=['tag']), entity_lookup,
                  left_on=['file', 'relation_start'], right_on=['file', 'tag'])
    df = df.drop(columns=['tag', 'relation_start']).rename(
        columns={'entity_span': 'relation_start', 'entity': 'start_entity', 'relation_name': 'string_id'})

    # Arg2: the same lookup for relation_end.
    df = pd.merge(df, entity_lookup, left_on=['file', 'relation_end'], right_on=['file', 'tag'])
    df = df.drop(columns=['tag', 'relation_end']).rename(
        columns={'entity_span': 'relation_end', 'entity': 'end_entity'})

    df['entities'] = [[start, end] for start, end in zip(df['start_entity'], df['end_entity'])]
    df = df.drop(columns=['start_entity', 'end_entity'])

    # Read each article once, rather than once per relation that points into it.
    # read_file / find_* / norm_list are not defined in this notebook
    # (presumably provided by the utils.helpers star-import).
    articles = {name: read_file(os.path.join(path, name + '.txt')) for name in df['file'].unique()}
    df['original_article'] = df['file'].map(articles)

    df['start_idx'] = df.apply(lambda row: find_smallest_first_element(row, 'relation_start', 'relation_end'), axis=1)
    df['end_idx'] = df.apply(lambda row: find_largest_last_element(row, 'relation_start', 'relation_end'), axis=1)
    df['match'] = df.apply(lambda row: row['original_article'][row['start_idx']:row['end_idx']], axis=1)
    df['sentences'] = df.apply(lambda row: find_sentences_around_match(text=row['original_article'], begin=row['start_idx'], end=row['end_idx']), axis=1)
    df['BOS_idx'] = df.apply(lambda row: find_BOS_index(row['original_article'], row['start_idx']), axis=1)
    # Entity spans re-expressed relative to BOS_idx via norm_list.
    df['entity_spans'] = df.apply(lambda row: np.array([norm_list(row['relation_start'], row['BOS_idx']), norm_list(row['relation_end'], row['BOS_idx'])], dtype=object), axis=1)

    cols = ['end_idx', 'entities', 'entity_spans', 'match', 'original_article', 'sentences', 'start_idx', 'string_id']
    return df[cols].astype(object)

def grab_entity_info(line):
    """Turn one tab-split BRAT entity line into a one-row DataFrame.

    ``line`` is ``[tag, "Type start ... end", text]``. The span is taken from
    the first and last offsets in the middle field, so a discontinuous span
    ("10 15;20 25") is collapsed to its outer bounds.
    """
    fields = line[1].split(" ")
    label = str(fields[0])
    span = np.array([int(fields[1]), int(fields[-1])], dtype=object)
    record = {
        'tag': line[0],
        'entity_name': label,
        'entity_span': [span],
        'entity': line[-1],
    }
    return pd.DataFrame(record, index=[0], dtype=object)

def grab_relation_info(line):
    """Turn one tab-split BRAT relation line into a one-row DataFrame.

    ``line`` is ``[tag, "Name Arg1:Tx Arg2:Ty"]``; the entity tags after the
    colons become ``relation_start`` / ``relation_end``.
    """
    fields = line[1].split(" ")
    assert len(fields) == 3, "Incorrect relation format"
    name, first_arg, second_arg = fields
    return pd.DataFrame(
        {
            'tag': line[0],
            'relation_name': name,
            'relation_start': first_arg.split(':')[1],
            'relation_end': second_arg.split(':')[1],
        },
        index=[0],
        dtype=object,
    )

def process_annotation(path):
    """Parse one BRAT ``.ann`` standoff file into entity and relation tables.

    Args:
        path: path to the ``.ann`` file.

    Returns:
        dict with keys ``'entities'`` and ``'relations'``, each a DataFrame whose
        first column ``'file'`` holds the file's base name without the ``.ann``
        suffix (set on every row). Entity rows come from lines whose id starts
        with ``T`` (via ``grab_entity_info``), relation rows from ids starting
        with ``R`` (via ``grab_relation_info``); other line ids are skipped.

    Raises:
        InvalidAnnotationError: a non-blank, non-comment line contains no tab.
    """
    file_name = os.path.basename(path)
    if file_name.endswith(".ann"):
        # Strip only the suffix, not every ".ann" substring in the name.
        file_name = file_name[:-len(".ann")]

    entity_rows, relation_rows = [], []
    with open(path, 'r') as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if line == "" or line.startswith("#"):
                continue
            if "\t" not in line:
                raise InvalidAnnotationError("Line chunks in ANN files must be separated by tabs (See BRAT Guidelines).")
            fields = line.split("\t")
            if fields[0].startswith('T'):
                entity_rows.append(grab_entity_info(fields))
            elif fields[0].startswith('R'):
                relation_rows.append(grab_relation_info(fields))

    return {
        'entities': _frame_with_file(entity_rows, file_name),
        'relations': _frame_with_file(relation_rows, file_name),
    }


def _frame_with_file(rows, file_name):
    """Concatenate one-row frames once and prepend a constant ``'file'`` column."""
    frame = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()
    frame.insert(0, 'file', file_name)
    return frame

In [None]:
# Parse the BRAT training annotations once. Keep the raw parse as `train_df_load`
# (later cells copy and split it) and start `train_df` from a copy of it.
train_df_load = BRATtoDFconvert(TRAIN_DIR).reset_index(drop=True)
train_df = train_df_load.copy()

In [None]:
# Work on a fresh, re-indexed copy so the raw parse in `train_df_load` is never mutated.
train_df = train_df_load.copy().reset_index(drop=True)

In [None]:
# Column dtypes and non-null counts of the parsed annotation frame.
train_df.info()

In [None]:
# Min/max length (via .str.len()) of the model-input text and of the raw
# entity-to-entity match.
for label, column in (('sentences', 'sentences'), ('matches', 'match')):
    lengths = train_df[column].str.len()
    print(f"{label}: min = {lengths.min()!s}, max = {lengths.max()!s}")

In [None]:
# Integer class id -> relation name. Ids follow value_counts() order, so the
# most frequent relation type gets id 0.
id2label = dict(enumerate(train_df.string_id.value_counts().index))

In [None]:
# Inspect the class-id -> relation-name mapping.
id2label

In [None]:
# Inverse mapping: relation name -> integer class id.
label2id = dict(zip(id2label.values(), id2label.keys()))
label2id

In [None]:
# (rows, columns) of the full parsed frame, before any train/validation split.
train_df.shape

In [None]:
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification")


class RelationExtractionDataset(Dataset):
    """Torch dataset that turns relation rows into LUKE entity-pair encodings.

    Each item comes from one DataFrame row:
    - its ``sentences`` text;
    - its two ``entity_spans`` (converted to tuples);
    - its ``string_id`` label, mapped through the notebook-level ``label2id``.

    Encoding uses the notebook-level ``tokenizer``.
    """

    def __init__(self, data, max_length=257):
        """
        Args:
            data: pandas DataFrame with ``sentences``, ``entity_spans`` and
                ``string_id`` columns.
            max_length: every encoding is padded/truncated to this many tokens
                (default 257, the value used for training so far).
        """
        self.data = data
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        entity_spans = [tuple(span) for span in item.entity_spans]

        encoding = tokenizer(
            item.sentences,
            entity_spans=entity_spans,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # return_tensors="pt" adds a leading batch axis of size 1. Drop only that
        # axis so the DataLoader can re-batch; squeeze() would also collapse any
        # other size-1 axis.
        for key in list(encoding.keys()):
            encoding[key] = encoding[key].squeeze(0)

        encoding["label"] = torch.tensor(label2id[item.string_id])
        return encoding

In [None]:
# Hold out 20% of the parsed annotations as the test set used by the evaluation
# cells (test_df). The next cell splits the remaining 80% into train/validation,
# giving 64% train / 16% validation / 20% test overall.
train_df, test_df = train_test_split(train_df_load, test_size=0.2, random_state=RANDOM_STATE, shuffle=True)

In [None]:
# Carve a validation split out of the current training frame and wrap both
# splits as torch datasets.
# NOTE: this reads and overwrites `train_df`; re-running it without first
# re-running the previous cell splits the already-reduced frame again.
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=RANDOM_STATE, shuffle=True)
train_dataset, valid_dataset = (RelationExtractionDataset(data=split) for split in (train_df, val_df))
# test_dataset = RelationExtractionDataset(data=test_df)

In [None]:
# Keys of one encoded example: the tokenizer outputs plus "label".
train_dataset[0].keys()

In [None]:
# Shuffle only the training loader; validation keeps a fixed order.
train_dataloader, valid_dataloader = (
    DataLoader(train_dataset, batch_size=4, shuffle=True),
    DataLoader(valid_dataset, batch_size=2),
)
# test_dataloader = DataLoader(test_dataset, batch_size=2)

In [None]:
# Decode the second example of a (shuffled) training batch back to text to
# eyeball what the tokenizer produced.
batch = next(iter(train_dataloader))
tokenizer.decode(batch["input_ids"][1])

In [None]:
# Relation name of that same example (label id mapped back through id2label).
id2label[batch["label"][1].item()]

In [None]:
class LUKE(pl.LightningModule):
    """Lightning wrapper that fine-tunes LUKE for entity-pair relation classification.

    The underlying ``LukeForEntityPairClassification`` is loaded from
    ``studio-ousia/luke-base`` with one output per entry of the notebook-level
    ``label2id``. Batches are dicts of tokenizer outputs plus a ``label``
    tensor, as produced by ``RelationExtractionDataset``.

    Args:
        lr: AdamW learning rate (default 5e-5).
    """

    def __init__(self, lr=5e-5):
        super().__init__()
        self.lr = lr
        self.model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-base", num_labels=len(label2id))

    def forward(self, input_ids, entity_ids, entity_position_ids, attention_mask, entity_attention_mask):
        """Run the wrapped model; the returned output object exposes ``.logits``."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, entity_ids=entity_ids,
                             entity_attention_mask=entity_attention_mask, entity_position_ids=entity_position_ids)
        return outputs

    def common_step(self, batch, batch_idx):
        """Return ``(loss, accuracy)`` for one batch.

        ``batch`` is not modified: the model inputs are every key except ``label``.
        """
        labels = batch['label']
        inputs = {key: value for key, value in batch.items() if key != 'label'}
        logits = self(**inputs).logits

        loss = F.cross_entropy(logits, labels)  # multi-class classification
        predictions = logits.argmax(-1)
        accuracy = (predictions == labels).sum().item() / labels.shape[0]

        return loss, accuracy

    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        # Logged per training step (Lightning's default inside training_step).
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        # Aggregated over the epoch; "validation_loss" is what early stopping monitors.
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        # Log so that trainer.test() actually reports the test metrics.
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_accuracy", accuracy, on_epoch=True)

        return loss

    def configure_optimizers(self):
        # torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
        # weight_decay are set to that class's defaults (1e-6, 0.0) so the
        # optimisation behaves the same.
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=1e-6, weight_decay=0.0)

    # The loaders are the notebook-level objects built in earlier cells.
    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return valid_dataloader

    def test_dataloader(self):
        return test_dataloader

In [None]:
# Take one validation batch and keep its gold labels separately; the following
# cells run a forward pass on it and compute an initial loss.
batch = next(iter(valid_dataloader))
labels = batch["label"]
batch.keys()

In [None]:
batch["input_ids"].shape

In [None]:
model = LUKE()
# Forward pass on the validation batch. Labels were saved separately in the
# previous cell; filter them out here rather than `del`-ing them, so this cell
# can be re-run without a KeyError.
inputs = {k: v for k, v in batch.items() if k != "label"}
outputs = model(**inputs)

In [None]:
# Sanity check: cross-entropy of the freshly constructed model (classification
# head not yet fine-tuned) on the validation batch taken above.
# NOTE(review): for C roughly balanced classes this is typically near ln(C) — confirm.
criterion = torch.nn.CrossEntropyLoss()

initial_loss = criterion(outputs.logits, labels)
print("Initial loss:", initial_loss)

In [None]:
wandb_logger = WandbLogger(name='luke-first-run-12000-articles-bis', project='LUKE')
# For early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
# The monitored name must match the metric logged in LUKE.validation_step.
early_stop_callback = EarlyStopping(
    monitor='validation_loss',
    patience=2,
    strict=True,  # fail loudly if the monitored metric is never logged
    verbose=False,
    mode='min'
)

trainer = Trainer(logger=wandb_logger, callbacks=[early_stop_callback])
trainer.fit(model)

In [None]:
# Run LUKE.test_step over LUKE.test_dataloader(), which returns the global
# `test_dataloader`.
# NOTE(review): that name is only defined when the commented-out test loader
# line in the DataLoader cell is enabled — confirm before running.
trainer.test()

In [None]:
# Evaluate a loaded checkpoint on the held-out test loader.
# NOTE: `loaded_model` is created by the checkpoint-loading cell further down,
# and `test_dataloader` must also exist; run those first.
loaded_model.model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_model.to(device)

predictions_total = []
labels_total = []
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        # Keep the labels on CPU; send only the model inputs to the device
        # (without mutating `batch`).
        labels = batch["label"]
        inputs = {k: v.to(device) for k, v in batch.items() if k != "label"}

        # forward pass
        outputs = loaded_model.model(**inputs)
        logits = outputs.logits
        predictions = logits.argmax(-1)
        predictions_total.extend(predictions.tolist())
        labels_total.extend(labels.tolist())

In [None]:
print("Accuracy on test set:", accuracy_score(labels_total, predictions_total))

In [None]:
# Checkpoint to evaluate. Override it with the LUKE_CHECKPOINT environment
# variable when not running on the original Colab / shared-drive setup.
CHECKPOINT_PATH = os.environ.get(
    "LUKE_CHECKPOINT",
    "/content/drive/Shareddrives/Datascouts/epoch=3-step=7699.ckpt",
)
loaded_model = LUKE.load_from_checkpoint(checkpoint_path=CHECKPOINT_PATH)

In [None]:
# Text of the first test example (the column produced by BRATtoDFconvert is `sentences`).
test_df.iloc[0].sentences

In [None]:
idx = 2
example = test_df.iloc[idx]
text = example.sentences
entity_spans = [tuple(x) for x in example.entity_spans]  # character-based entity spans

# Tokenize with the same 257-token truncation used for training, and put the
# tensors on whatever device the model currently lives on.
inputs = tokenizer(text, entity_spans=entity_spans, truncation=True, max_length=257, return_tensors="pt")
inputs = inputs.to(loaded_model.device)

with torch.no_grad():
    outputs = loaded_model.model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Sentence:", text)
print("Ground truth label:", example.string_id)
print("Predicted label:", id2label[predicted_class_idx])
print("Confidence:", F.softmax(logits, -1).max().item())