In [None]:
# Setup
! pip install seqeval evaluate
! pip install kaleido



In [None]:
# Library imports
from transformers import AutoTokenizer, AutoModel, pipeline, AutoConfig, DistilBertForSequenceClassification, DistilBertModel, DistilBertConfig, DistilBertPreTrainedModel, AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding, TrainingArguments, Trainer
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.tokenization_utils_base import BatchEncoding
from datasets import Dataset, DatasetDict
import torch
import torch.nn as nn
from google.colab import drive, userdata
import pickle
import random
import re
import time
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import plotly.express as px
import evaluate
import pprint
import kaleido
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
import re
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [None]:
# Mount drive
# Mounts Google Drive at /content/drive, then changes the working directory to the
# pickle folder so that later cells can open files by bare name.
drive.mount("/content/drive")
%cd '/content/drive/MyDrive/Colab Notebooks/Math_Graph/pickle_files'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Colab Notebooks/Math_Graph/pickle_files


In [None]:
# Remove pandas' display limits so that every column and every row is rendered
for display_option in ('display.max_columns', 'display.max_rows'):
  pd.set_option(display_option, None)

In [None]:
# Define file read function
def read_pickle(dict_file):
  """Load and return the object stored in the pickle file at `dict_file`.

  The file is opened in binary mode and closed again before returning.
  Unpickling can execute arbitrary code, so only use this on trusted files.
  """
  with open(dict_file, mode='rb') as handle:
    payload = pickle.load(handle)
  return payload

In [None]:
# Load custom trained model

# Identifier of the fine-tuned DistilBERT checkpoint; the tokenizer and the encoder below are both loaded from it.
checkpoint = "Heather-Driver/distilbert-NER-LinearAlg-finetuned"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, do_lower_case=False)
# Encoder without a task head; the custom classifier class defined later reuses this object as its body.
distilbert_model = DistilBertModel.from_pretrained(checkpoint)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def extract_window(sentence, predicate, window_size):
  """Return the words surrounding `predicate` inside `sentence`.

  The sentence is split on whitespace. The result contains the predicate's own
  words plus up to `window_size` words on each side, joined by single spaces.
  The window is clipped at the sentence boundaries; a side that is cut short is
  not compensated for on the other side.

  Args:
    sentence: Text to search.
    predicate: Phrase to locate. Matching is literal and case-sensitive.
    window_size: Maximum number of words kept on each side of the predicate.

  Returns:
    The window as a single string, or the message
    "Predicate not found in the sentence." when the predicate does not occur.
  """
  tokens = sentence.split()
  literal = re.escape(predicate)  # the predicate is plain text, not a regex

  # Prefer an occurrence delimited by whitespace (or the string edges) so that a
  # short predicate such as "is" does not latch onto the inside of "Analysis".
  # Fall back to a plain substring search so predicates that touch punctuation,
  # or that are only part of a word, are still found.
  match = None
  if predicate.strip():
    match = re.search(r'(?<!\S)' + literal + r'(?!\S)', sentence)
  if match is None:
    match = re.search(literal, sentence)

  if not match:
    return "Predicate not found in the sentence."

  # Locate the token in which the predicate's first visible character sits.
  matched_text = match.group()
  first_char = match.start() + len(matched_text) - len(matched_text.lstrip())
  prefix = sentence[:first_char]
  start_index = len(prefix.split())
  if prefix and not prefix[-1].isspace():
    # The match begins in the middle of a token; that token was already counted
    # in the prefix, so step back onto it.
    start_index -= 1

  # Clip the window to the sentence boundaries.
  predicate_length = len(predicate.split())
  start_window = max(0, start_index - window_size)
  end_window = min(len(tokens), start_index + predicate_length + window_size)

  return ' '.join(tokens[start_window:end_window])

def adds_context_window(window_size, df):
  """Fill `df['context_window']` with the window around each row's predicate.

  For every row, `extract_window` is called with the row's 'sentence' and
  'predicate' values. Rows are paired by position, so the frame does not need a
  default 0..n-1 index.

  Args:
    window_size: Number of words kept on each side of the predicate.
    df: DataFrame with 'sentence' and 'predicate' columns. It is modified in
      place: the 'context_window' column is created or overwritten.

  Returns:
    The same DataFrame object, for convenience.
  """
  df['context_window'] = [
      extract_window(sentence, predicate, window_size=window_size)
      for sentence, predicate in zip(df['sentence'], df['predicate'])
  ]
  return df

In [None]:
# Load the pickled predicate table, give two columns their working names,
# then lower-case every column name
predicate_data = (
    read_pickle('predicate_data.pkl')
    .rename(columns={'Window_1': 'context_window', 'Label': 'string_label'})
)
predicate_data.columns = predicate_data.columns.str.lower()

In [None]:
# Map every classification tag to an integer id (and back) for the model

unique_tags = predicate_data['string_label'].unique()  # tags in order of first appearance
index2tag = dict(enumerate(unique_tags))  # id -> tag; the numbering is arbitrary and carries no meaning
tag2index = {tag: idx for idx, tag in index2tag.items()}  # tag -> id, the inverse lookup

In [None]:
# Encode each string label as its integer id; a label missing from tag2index would become NaN.
predicate_data['label'] = predicate_data['string_label'].map(tag2index)

In [None]:
# Recompute the 'context_window' column with 4 words on each side of the predicate (the frame is updated in place).
predicate_data = adds_context_window(window_size=4, df=predicate_data)

In [None]:
# Preview the first two rows to check the renamed columns and the new window text.
predicate_data.head(2)

Unnamed: 0,sentence,subject,predicate,object,string_label,context_window,label
0,The Wishart distribution is used in multivaria...,wishart distribution,is used in,multivariate statistics,used in,The Wishart distribution is used in multivaria...,0
1,The Square Root Method is transformed by the a...,Square Root Method,transformed by,the application of inverse operations to deriv...,computation,Square Root Method is transformed by the appli...,1


In [None]:
# Split row index values (not the rows themselves) so the same partition can be applied to the Dataset later.
# First hold out 5% of the rows as the test set, stratified on the integer label.
X_train_indices, X_test_indices, y_train_indices, y_test_indices = train_test_split(predicate_data.index.to_numpy(), predicate_data['label'].to_numpy(),
                                                                                    test_size=0.05, random_state=42, stratify=predicate_data['label'].to_numpy())

# Repeat to get validation sub-sample of Train
# 30% of the remaining rows become the validation set, again stratified; the rest stay in train.
X_train_indices, X_valid_indices, y_train_indices, y_valid_indices = train_test_split(X_train_indices, y_train_indices, test_size=0.3, random_state=42, stratify=y_train_indices)

## Preprocessing

In [None]:
# Keep only the columns needed for tokenization and training when converting to a Dataset.
dataset = Dataset.from_pandas(predicate_data[['sentence', 'label', 'context_window', 'predicate']])

In [None]:
# Select subsets of the dataset for train, test and validation
# NOTE(review): `select` takes row positions while the arrays hold DataFrame index values;
# the two agree only if predicate_data has a default 0..n-1 index — confirm.
train_split = dataset.select(X_train_indices)
test_split = dataset.select(X_test_indices)
valid_split = dataset.select(X_valid_indices)

# Bundle the three subsets under the split names used by the rest of the notebook.
dataset = DatasetDict({
    'train': train_split,
    'test': test_split,
    'validation': valid_split
})

In [None]:
def preprocess_function_predicate(examples):
  """Tokenize the 'predicate' column of a batch.

  No special tokens are added; every sequence is truncated or padded to exactly
  256 ids and the tokenizer output is returned as PyTorch tensors.
  """
  context_inputs = tokenizer(examples["predicate"], return_tensors="pt", add_special_tokens=False, truncation=True, padding="max_length", max_length=256)
  return context_inputs

dataset = dataset.map(preprocess_function_predicate, batched=True)
# Rename the generic tokenizer outputs to 'predicate_*' so that later tokenization passes do not overwrite them
dataset = dataset.rename_columns({"attention_mask": "predicate_attention_mask", "input_ids": "predicate_input_ids"})

Map:   0%|          | 0/292 [00:00<?, ? examples/s]

Map:   0%|          | 0/22 [00:00<?, ? examples/s]

Map:   0%|          | 0/126 [00:00<?, ? examples/s]

In [None]:
def preprocess_function_context(examples):
  """Tokenize the 'context_window' column of a batch.

  No special tokens are added; every sequence is truncated or padded to exactly
  256 ids and the tokenizer output is returned as PyTorch tensors.
  """
  context_inputs = tokenizer(examples["context_window"], return_tensors="pt", add_special_tokens=False, truncation=True, padding="max_length", max_length=256)
  return context_inputs

dataset = dataset.map(preprocess_function_context, batched=True)
# Rename the generic tokenizer outputs to 'context_*' so that the sentence tokenization pass does not overwrite them
dataset = dataset.rename_columns({"attention_mask": "context_attention_mask", "input_ids": "context_input_ids"})

Map:   0%|          | 0/292 [00:00<?, ? examples/s]

Map:   0%|          | 0/22 [00:00<?, ? examples/s]

Map:   0%|          | 0/126 [00:00<?, ? examples/s]

In [None]:
def preprocess_function(examples):
  """Tokenize the full 'sentence' column of a batch.

  Special tokens are added here (unlike the predicate and context passes); every
  sequence is truncated or padded to exactly 256 ids and returned as PyTorch tensors.
  """
  inputs = tokenizer(examples["sentence"], return_tensors="pt", add_special_tokens=True, truncation=True, padding="max_length", max_length=256)
  return inputs

# The sentence encoding keeps the default 'input_ids' / 'attention_mask' column names.
dataset = dataset.map(preprocess_function, batched=True)

# Return only these columns, as torch tensors, when the dataset is indexed; the raw text columns are left out.
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label', 'context_attention_mask', 'context_input_ids', 'predicate_attention_mask', 'predicate_input_ids'])

Map:   0%|          | 0/292 [00:00<?, ? examples/s]

Map:   0%|          | 0/22 [00:00<?, ? examples/s]

Map:   0%|          | 0/126 [00:00<?, ? examples/s]

## Developing the Model Parameters

In [None]:
class DistilBertForSentenceClassificationSpan(DistilBertPreTrainedModel):
  """DistilBERT sentence classifier built on four concatenated feature vectors.

  For every example the model concatenates, in this order:
    1. an attention-pooled encoding of the context window,
    2. the first-token ([CLS] position) encoding of the full sentence,
    3. a learned embedding of the context window's length,
    4. an attention-pooled encoding of the predicate,
  and feeds the result to a single linear classification layer.
  """
  config_class = DistilBertConfig

  def __init__(self, config):
    super().__init__(config)
    self.num_labels = config.num_labels
    # Model body: the module-level `distilbert_model` object is reused rather than
    # building a new encoder, so it must already exist when the class is instantiated.
    self.distilbert = distilbert_model
    # Attention pooling parameters, shared by the context and predicate branches
    self.attention_w = nn.Parameter(torch.randn(config.hidden_size))  # Trainable scoring vector
    self.attention_bias = nn.Parameter(torch.zeros(1))  # Scalar added to every score
    # Lookup table for context width embeddings
    self.width_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    # Classification head over the 4 concatenated vectors (context, CLS, width, predicate)
    self.classifier = nn.Linear(config.hidden_size * 4, self.num_labels)

  def _attention_pool(self, input_ids, attention_mask):
    """Encode a token sequence and reduce it to one vector per example.

    Each token's hidden state is scored against `attention_w`; the scores are
    normalised with a softmax over the sequence and used as weights for a sum
    of the hidden states. Positions whose `attention_mask` is 0 (padding) get
    the lowest representable score so they receive no weight.

    Returns:
      Tensor of shape [batch_size, hidden_size].
    """
    outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask, output_attentions=False)
    hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
    # One score per token. The scalar bias shifts all scores equally and therefore
    # does not change the softmax result.
    attention_scores = torch.matmul(hidden_states, self.attention_w) + self.attention_bias  # [batch_size, seq_len]
    if attention_mask is not None:
      # finfo.min instead of -inf keeps a fully masked row finite (no NaN).
      lowest_score = torch.finfo(attention_scores.dtype).min
      attention_scores = attention_scores.masked_fill(attention_mask == 0, lowest_score)
    attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)  # [batch_size, seq_len]
    attention_weights = attention_weights.unsqueeze(-1)  # [batch_size, seq_len, 1]
    return torch.sum(attention_weights * hidden_states, dim=1)  # [batch_size, hidden_size]

  def _predicate_weight_embedding(self, predicate_input_ids, predicate_attention_mask):
    """Attention-pooled representation of the predicate tokens, [batch_size, hidden_size]."""
    return self._attention_pool(predicate_input_ids, predicate_attention_mask)

  def _attention_weight_embedding(self, context_input_ids, context_attention_mask):
    """Attention-pooled representation of the context window tokens, [batch_size, hidden_size]."""
    return self._attention_pool(context_input_ids, context_attention_mask)

  def _cls_embeddings(self, input_ids, attention_mask):
    """Hidden state of the first token of the full sentence, [batch_size, hidden_size]."""
    outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
    embeddings = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
    cls_embedding = embeddings[:, 0, :]
    return cls_embedding

  def _width_embeddings(self, context_input_ids):
    """Embedding of the context window length, [batch_size, hidden_size].

    The length is the number of ids different from 0, minus 2, clamped to at least 1.
    """
    # NOTE(review): presumably 0 is the pad id and the 2 accounts for [CLS]/[SEP];
    # the context windows in this notebook are tokenized with add_special_tokens=False,
    # so confirm that the offset is intended.
    span_length = context_input_ids.ne(0).sum(dim=1) - 2
    # Ensure no negative or zero indices (minimum span length should be 1)
    span_length = torch.clamp(span_length, min=1)
    width_embedding = self.width_embedding(span_length)
    return width_embedding

  def forward(self, context_input_ids=None, context_attention_mask=None, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, predicate_input_ids=None, predicate_attention_mask=None, **kwargs):
    """Compute class logits and, when `labels` is given, the cross-entropy loss.

    `token_type_ids` and extra keyword arguments are accepted but not used.

    Returns:
      SequenceClassifierOutput with `logits` of shape [batch_size, num_labels]
      and `loss` (None when no labels are passed).
    """
    # Attention-pooled context window
    attention_weight_embedding = self._attention_weight_embedding(context_input_ids=context_input_ids, context_attention_mask=context_attention_mask)
    # Attention-pooled predicate
    attention_weight_predicate = self._predicate_weight_embedding(predicate_input_ids=predicate_input_ids, predicate_attention_mask=predicate_attention_mask)
    # Width embedding of the context window
    width_embedding = self._width_embeddings(context_input_ids=context_input_ids)
    # First-token embedding of the full sentence
    cls_embedding = self._cls_embeddings(input_ids=input_ids, attention_mask=attention_mask)
    # Concatenate context, [CLS], width and predicate vectors
    final_representation = torch.cat((attention_weight_embedding, cls_embedding, width_embedding, attention_weight_predicate), dim=-1)  # [batch_size, hidden_size * 4]
    # Classifier on the concatenation
    logits = self.classifier(final_representation)
    # Loss calculation
    loss = None
    if labels is not None:
      loss_fct = nn.CrossEntropyLoss()
      loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
    return SequenceClassifierOutput(loss=loss, logits=logits)

In [None]:
# Start from the fine-tuned checkpoint's configuration and register this task's label set on it.
config = DistilBertConfig.from_pretrained("Heather-Driver/distilbert-NER-LinearAlg-finetuned")
config.label2id = tag2index
config.id2label = index2tag
config.num_labels = len(index2tag)

# Build the custom classifier straight from the config, move it to the selected device
# and turn on gradient checkpointing.
model = DistilBertForSentenceClassificationSpan(config)
model.to(device)
model.gradient_checkpointing_enable()

In [None]:
def compute_metrics(pred):
  """Compute accuracy and weighted F1 / precision / recall for an evaluation pass.

  Args:
    pred: Object exposing `label_ids` (true class ids) and `predictions`
      (per-class scores; the predicted class is the argmax over the last axis).

  Returns:
    Dict with the keys "accuracy", "f1", "precision" and "recall".
  """
  labels = pred.label_ids
  preds = pred.predictions.argmax(-1)
  # zero_division=0 yields the same 0 score scikit-learn falls back to for a class
  # with no predicted (or no true) samples, without emitting a warning on every
  # evaluation step.
  weighted = {"average": "weighted", "zero_division": 0}
  return {
      "accuracy": accuracy_score(labels, preds),
      "f1": f1_score(labels, preds, **weighted),  # y_true, y_pred
      "precision": precision_score(labels, preds, **weighted),
      "recall": recall_score(labels, preds, **weighted),
  }

In [None]:
num_epochs = 20
batch_size = 8  # Batch size of the stand-alone DataLoader below (the Trainer uses its own per-device sizes)

learning_rate = 1e-5
model_name = "distilbert-classn-LinearAlg-finetuned-pred-span-width-4"

training_arguments = TrainingArguments(
    output_dir=model_name,
    log_level="error",
    logging_strategy="steps",
    logging_steps=50,
    weight_decay=0.001,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=learning_rate,
    eval_strategy="steps",  # No eval_steps given, so evaluation runs every `logging_steps` steps
    save_strategy="epoch",
    disable_tqdm=False,
    save_steps=1000000,  # Not used while save_strategy is "epoch"
    remove_unused_columns=False,
    push_to_hub=True,
    no_cuda=False,
    gradient_checkpointing=True,
    gradient_accumulation_steps=2,  # Effective train batch per device: 2 x 2 = 4
    lr_scheduler_type="linear",  # Linear decay after warm-up
    warmup_steps=500,  # Learning-rate warm-up
    # Mixed precision has to be requested before the Trainer is built, and only
    # when a GPU is available.
    fp16=torch.cuda.is_available(),
)

data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    return_tensors="pt",
)

# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; switch once the minimum supported version is confirmed.
trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Stand-alone DataLoader over the training set, used for the single-batch check
# that follows; the Trainer builds its own loaders.
train_dataloader = DataLoader(
    dataset['train'],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
    num_workers=4
)

  trainer = Trainer(


In [None]:
# Release cached, currently unused GPU memory held by PyTorch before running the model.
torch.cuda.empty_cache()

In [None]:
# Run on a single batch to check

for batch in train_dataloader:
  # Move batch to device
  context_input_ids = batch['context_input_ids'].to(device)
  context_attention_mask = batch['context_attention_mask'].to(device)
  labels = batch['labels'].to(device)
  input_ids = batch['input_ids'].to(device)
  attention_mask = batch['attention_mask'].to(device)
  predicate_input_ids = batch['predicate_input_ids'].to(device)
  predicate_attention_mask = batch['predicate_attention_mask'].to(device)

  # Get model outputs. This is only a forward-pass check, so no autograd graph is
  # needed; labels are not passed, hence no loss is computed.
  with torch.no_grad():
    outputs = model(context_input_ids=context_input_ids, context_attention_mask=context_attention_mask, input_ids=input_ids, attention_mask=attention_mask, predicate_input_ids=predicate_input_ids, predicate_attention_mask=predicate_attention_mask)

  # Convert logits and labels to NumPy on the CPU for inspection
  logits_np = outputs.logits.detach().cpu().numpy()
  labels_np = labels.detach().cpu().numpy()

  print(f"logits_np shape: {logits_np.shape}")  # Expected: (batch_size, num_labels)
  print(f"labels_np shape: {labels_np.shape}")  # Expected: (batch_size,)

  preds = logits_np.argmax(-1)
  print(f"preds: {preds}")
  print(f"labels_np: {labels_np}")
  f1 = f1_score(labels_np, preds, average="weighted")  # y_true, y_pred
  print(f"f1_score: {f1}")

  # One batch is enough for the check
  break


logits_np shape: (8, 11)
labels_np shape: (8,)
preds: [9 9 0 0 9 0 0 0]
labels_np: [ 4 10  0  6  2  5  5  3]
f1_score: 0.041666666666666664


In [None]:
# Release cached GPU memory, run the fine-tuning loop, then upload the result to the Hub
# with the given commit message.
torch.cuda.empty_cache()
trainer.train()
trainer.push_to_hub(commit_message="Classification Training")

[34m[1mwandb[0m: Currently logged in as: [33mheather-rink[0m ([33mh-driver[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
50,5.0401,2.53896,0.111111,0.043101,0.036617,0.111111
100,4.9232,2.493735,0.087302,0.04645,0.041629,0.087302
150,4.8797,2.421084,0.079365,0.048745,0.038697,0.079365
200,4.7269,2.342468,0.18254,0.186092,0.205357,0.18254
250,4.6528,2.263988,0.190476,0.204575,0.278151,0.190476
300,4.4653,2.11028,0.436508,0.409938,0.448673,0.436508
350,4.126,1.93496,0.515873,0.513696,0.572336,0.515873
400,3.7215,1.739146,0.611111,0.591064,0.614745,0.611111
450,3.1218,1.513358,0.642857,0.623743,0.645748,0.642857
500,2.6488,1.302732,0.674603,0.666708,0.705763,0.674603


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


CommitInfo(commit_url='https://huggingface.co/Heather-Driver/distilbert-classn-LinearAlg-finetuned-pred-span-width-4/commit/abd7820b4763a8499f10645bf954c2f8053466f7', commit_message='Classification Training', commit_description='', oid='abd7820b4763a8499f10645bf954c2f8053466f7', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Heather-Driver/distilbert-classn-LinearAlg-finetuned-pred-span-width-4', endpoint='https://huggingface.co', repo_type='model', repo_id='Heather-Driver/distilbert-classn-LinearAlg-finetuned-pred-span-width-4'), pr_revision=None, pr_num=None)

In [None]:
# %cd '/content/drive/MyDrive/Colab Notebooks/Math_Graph'
# model.save_pretrained("linalg_ner_span_model")
# tokenizer.save_pretrained("linalg_ner_span_model")