In [None]:
import ast
import json
import os
from collections import defaultdict
from typing import List, Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import spacy
from datasets import Dataset, DatasetDict
# from datasets import load_dataset
from spacy.tokens import DocBin

In [None]:
# Entity categories annotated in the legal NER data; each one gets a B- and an I- tag.
categories = ['CASE_NUMBER', 'COURT', 'DATE', 'GPE', 'JUDGE', 'LAWYER', 'ORG', 'OTHER_PERSON', 'PETITIONER', 'PRECEDENT', 'PROVISION', 'RESPONDENT', 'STATUTE', 'WITNESS']

B_PREFIX = 'B-'
I_PREFIX = 'I-'
O_TAG = 'O'

# Id layout: 'O' -> 0, then for the k-th category (0-based) 'B-<cat>' -> 2k+1 and
# 'I-<cat>' -> 2k+2.  align_labels_with_tokens depends on this: an odd id is a B- tag
# and its matching I- tag is that id + 1.
bio_tags = [O_TAG] + [prefix + category_name
                      for category_name in categories
                      for prefix in (B_PREFIX, I_PREFIX)]
id2label = dict(enumerate(bio_tags))
label2id = {tag: tag_id for tag_id, tag in id2label.items()}

### Create the BIO-tagged data (skip this section if `./data/roberta/{train,dev,test}.csv` already exist)

In [None]:
# Raw splits: one sentence per row plus its stringified entity annotations.
train, dev, test = map(
    pd.read_csv,
    ("./data/finetuning/train.csv", "./data/finetuning/dev.csv", "./data/finetuning/test.csv"),
)

In [None]:
# Inspect the raw test split; tokenize_and_tag reads its `sentence` and `entities_dict` columns.
test

In [None]:
nlp = spacy.load('en_core_web_sm')


def _parse_literal(value, default):
    """Decode a stringified Python literal; NaN/None -> ``default``; other objects pass through."""
    if isinstance(value, str):
        return ast.literal_eval(value)
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return default
    return value


def tokenize_and_tag(df: pd.DataFrame, categories: List[str]) -> pd.DataFrame:
    """Tokenize each sentence with spaCy and BIO-tag the annotated entities.

    Args:
        df: Frame with a ``sentence`` column and an ``entities_dict`` column
            holding the string repr of a dict ``{category: entities}``. Each
            ``entities`` value is itself the string repr of a list of entity
            strings, which is why it is decoded a second time. Missing (NaN)
            annotations are treated as "no entities".
        categories: Allowed entity categories. They must also be present in
            the module-level ``label2id``.

    Returns:
        A new DataFrame with one row per input row and the columns
        ``tokens`` (list of token strings), ``ner_tags`` (list of label ids
        from ``label2id``) and ``ner_tags_str`` (list of BIO tag strings).
        Every occurrence of an entity in the sentence is tagged. When
        entities overlap, the one processed later overwrites earlier tags.

    Raises:
        ValueError: if an annotation uses a category not in ``categories``.
    """
    B_PREFIX = 'B-'
    I_PREFIX = 'I-'
    O_TAG = 'O'
    allowed = set(categories)

    all_tokens: List[List[str]] = []
    all_tags: List[List[str]] = []

    for sentence, raw_entities in zip(df['sentence'], df['entities_dict']):
        tokens = [token.text for token in nlp(sentence)]
        tags = [O_TAG] * len(tokens)

        entities = _parse_literal(raw_entities, {})
        for category, raw_entity_list in entities.items():
            if category not in allowed:
                raise ValueError(
                    f"Unknown entity category {category!r}; expected one of {sorted(allowed)}")
            for entity in _parse_literal(raw_entity_list, []):
                # An empty entity would match the empty slice at every position
                # and tag the whole sentence, so skip it.
                if not entity.strip():
                    continue
                # Tokenize the entity with the same spaCy tokenizer as the sentence.
                # A plain str.split() would keep punctuation attached ("Act,"),
                # while spaCy splits it off, so such entities would never match.
                entity_tokens = [token.text for token in nlp.tokenizer(entity.strip())]
                span = len(entity_tokens)
                for start in range(len(tokens) - span + 1):
                    if tokens[start:start + span] == entity_tokens:
                        tags[start] = B_PREFIX + category
                        for j in range(start + 1, start + span):
                            tags[j] = I_PREFIX + category

        all_tokens.append(tokens)
        all_tags.append(tags)

    # Build the frame once, after the loop. Rebuilding it per row was quadratic
    # and left it undefined for an empty input.
    return pd.DataFrame({
        'tokens': all_tokens,
        'ner_tags': [[label2id[tag] for tag in tags] for tags in all_tags],
        'ner_tags_str': all_tags,
    })


In [None]:
# BIO-tag every split with the same category list / label scheme.
train_data, dev_data, test_data = (
    tokenize_and_tag(split_df, categories) for split_df in (train, dev, test)
)

In [None]:
# Persist the tagged splits so the model-building section can start from here.
# Create the output directory first: to_csv does not create missing directories.
os.makedirs("./data/roberta", exist_ok=True)
train_data.to_csv("./data/roberta/train.csv", index=False)
dev_data.to_csv("./data/roberta/dev.csv", index=False)
test_data.to_csv("./data/roberta/test.csv", index=False)

In [None]:
# Inspect the tagged training split: tokens, ner_tags (label ids) and ner_tags_str (BIO strings).
train_data

### Model Building

In [None]:
def transform_columns_to_list(df):
    """Decode every column of ``df`` from its CSV string form back into Python objects.

    ``DataFrame.to_csv`` stores list cells as their repr (e.g. ``"['a', 'b']"``).
    Each string cell is decoded with ``ast.literal_eval``. Non-string cells, such
    as cells already decoded by an earlier call, are left untouched, so re-running
    the calling cell is safe.

    Args:
        df: DataFrame whose string cells are Python literal reprs.

    Returns:
        A new DataFrame with the decoded values. ``df`` itself is not modified.

    Raises:
        ValueError, SyntaxError: if a string cell is not a valid Python literal.
    """
    def _decode(cell):
        return ast.literal_eval(cell) if isinstance(cell, str) else cell

    decoded = df.copy()
    for column in decoded.columns:
        decoded[column] = decoded[column].map(_decode)
    return decoded

In [None]:
# Reload the BIO-tagged splits written by the data-creation section.
train_data, dev_data, test_data = map(
    pd.read_csv,
    ("./data/roberta/train.csv", "./data/roberta/dev.csv", "./data/roberta/test.csv"),
)

In [None]:
# to_csv stored the list-valued columns as strings; turn them back into lists.
train_data, dev_data, test_data = (
    transform_columns_to_list(frame) for frame in (train_data, dev_data, test_data)
)

In [None]:
# Wrap each pandas split in a Hugging Face Dataset and bundle them under the
# split names used by the Trainer cell below ('train' and 'validation').
train_dataset, dev_dataset, test_dataset = (
    Dataset.from_pandas(frame) for frame in (train_data, dev_data, test_data)
)
data = DatasetDict({
    'train': train_dataset,
    'validation': dev_dataset,
    'test': test_dataset,
})


In [None]:
# DatasetDict with 'train', 'validation' and 'test' splits.
data

#### Tokenization

In [None]:
from transformers import AutoTokenizer

# Pretrained checkpoint whose subword tokenizer is used for every split.
# NOTE(review): output/checkpoint directories below are named "distilbert-..." although
# this is an XLM-R checkpoint; confirm the saved checkpoint was trained from this model.
model_checkpoint = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [None]:
# word_ids(), which the label alignment below relies on, only exists on fast tokenizers.
assert tokenizer.is_fast, "label alignment needs a fast tokenizer (word_ids())"
tokenizer.is_fast

In [None]:
# Sanity-check the subword tokenization of the first (pre-split) training sentence.
first_words = data['train'][0]['tokens']
inputs = tokenizer(first_words, is_split_into_words=True)
print(inputs.tokens())

In [None]:
# Index of the source word for each subword token (None for tokens the tokenizer added itself).
print(inputs.word_ids())

In [None]:
def align_labels_with_tokens(labels, word_ids):
  """Expand word-level label ids to one label per subword token.

  Args:
    labels: label id for each word of the sentence.
    word_ids: for each subword token, the index of its word, or None for tokens
      the tokenizer added (these get -100 so they are ignored).

  Returns:
    A list with one label per entry of ``word_ids``. The first subword of a word
    keeps the word's label. Later subwords of the same word turn a B- id into the
    matching I- id. This relies on the label2id layout built at the top of the
    notebook: odd ids are B- tags and the I- id is the B- id + 1.
  """
  aligned = []
  previous_word = None
  for word_id in word_ids:
    if word_id is None:
      aligned.append(-100)
    elif word_id != previous_word:
      aligned.append(labels[word_id])
    else:
      tag_id = labels[word_id]
      aligned.append(tag_id + 1 if tag_id % 2 == 1 else tag_id)
    previous_word = word_id
  return aligned

In [None]:
# Word-level label ids of the same sample, shown next to its subword-to-word mapping.
word_ids = inputs.word_ids()
labels = data['train'][0]['ner_tags']
print(labels, word_ids)

In [None]:
# Subword-level labels for the sample: -100 where word_id is None, I- ids on continuation subwords.
print(align_labels_with_tokens(labels, word_ids))

In [None]:
def tokenize_and_align_labels(examples):
  """Batched ``Dataset.map`` function: subword-tokenize pre-split words and align labels.

  Args:
    examples: batch dict with ``tokens`` (one list of words per example) and
      ``ner_tags`` (one list of word-level label ids per example).

  Returns:
    The tokenizer's batch encoding (truncated) with an extra ``labels`` entry:
    per example, the word labels expanded to subword tokens by
    align_labels_with_tokens.
  """
  encoded = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)
  encoded['labels'] = [
      align_labels_with_tokens(word_labels, encoded.word_ids(batch_index))
      for batch_index, word_labels in enumerate(examples['ner_tags'])
  ]
  return encoded

In [None]:
# Tokenize every split in batches; drop the original columns so only model inputs and labels remain.
raw_columns = data['train'].column_names
tokenized_datasets = data.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=raw_columns,
)

In [None]:
# Only tokenizer outputs plus 'labels' remain; the raw columns were removed by map().
tokenized_datasets

#### Data collation

In [None]:
from transformers import DataCollatorForTokenClassification

# Pads each batch (inputs and the per-token `labels` lists) to a common length.
# NOTE(review): label padding presumably uses the library default -100, so padded
# positions are skipped by the loss and by compute_metrics; confirm for the installed version.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [None]:
# Collate the first two tokenized training examples to inspect the padded batch.
train_split = tokenized_datasets['train']
batch = data_collator([train_split[0], train_split[1]])
print(batch)

#### Metrics

In [None]:
import evaluate
from seqeval.scheme import IOB2

# Entity-level sequence-labelling metric; compute_metrics calls it with scheme="IOB2", mode="strict".
metric = evaluate.load('seqeval')

In [None]:
import numpy as np

def compute_metrics(eval_preds):
  """Trainer metric hook: entity-level seqeval scores for token-classification outputs.

  Args:
    eval_preds: unpacks to ``(logits, labels)``. Predicted ids are the argmax
      over the last axis of ``logits``.

  Returns:
    dict with overall ``precision``, ``recall``, ``f1`` and ``accuracy`` from
    seqeval (strict IOB2 matching, zero_division=0). Positions labelled -100
    are dropped from both predictions and references before ids are mapped
    to tag strings via ``id2label``.
  """
  logits, labels = eval_preds
  predicted_ids = np.argmax(logits, axis=-1)

  true_labels = []
  for label_row in labels:
    true_labels.append([id2label[t] for t in label_row if t != -100])

  true_predictions = []
  for pred_row, label_row in zip(predicted_ids, labels):
    true_predictions.append([id2label[p] for p, t in zip(pred_row, label_row) if t != -100])

  scores = metric.compute(predictions=true_predictions, references=true_labels,
                          scheme="IOB2", mode="strict", zero_division=0)

  summary = {
      "precision": scores['overall_precision'],
      "recall": scores['overall_recall'],
      "f1": scores['overall_f1'],
      "accuracy": scores['overall_accuracy'],
  }
  return summary

### Model training

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

In [None]:
# Resume from a checkpoint of an earlier run when one exists. Otherwise start from the
# pretrained base checkpoint, so Restart & Run All also works on a fresh checkout.
RESUME_CHECKPOINT = "./distilbert-finetuned-ner/checkpoint-3711"
model_source = RESUME_CHECKPOINT if os.path.isdir(RESUME_CHECKPOINT) else model_checkpoint
print(f"Loading token-classification model from {model_source}")
model = AutoModelForTokenClassification.from_pretrained(
    model_source,
    id2label=id2label,
    label2id=label2id,
)

In [None]:
# The classification head must have exactly one output per BIO label in id2label.
assert model.config.num_labels == len(id2label), (model.config.num_labels, len(id2label))
model.config.num_labels

In [None]:
# Evaluate and save a checkpoint at the end of every epoch. Checkpoints land under
# distilbert-finetuned-ner/, the directory the model-loading cell resumes from.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version.
args = TrainingArguments(
    output_dir="distilbert-finetuned-ner",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
)

In [None]:
# Wire model, data, collation and metrics together; evaluation runs on the validation split.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune; evaluation and checkpointing happen once per epoch (see the TrainingArguments cell).
trainer.train()