In [None]:
import dataclasses
from pathlib import Path

import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, split
from pyspark.ml.feature import StringIndexer
# from pyspark.sql.types import StructType, StructField, StringType, FloatType

In [None]:
# Reuse an already-running Spark session if there is one; otherwise start one.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

In [None]:
import torch
import json
import pandas as pd

In [None]:
from datasets import Dataset

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, EarlyStoppingCallback

In [None]:
# Prefer Apple's MPS backend when it is available, otherwise fall back to CPU.
# NOTE(review): in the cells shown, `device` is only printed; it is not passed
# to the model or the Trainer.
device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"
print(device)

In [None]:
# Pre-split training and validation data, read in that order.
train_df, val_df = (
    spark.read.parquet(parquet_path)
    for parquet_path in ("data/df_train.parquet.gzip", "data/df_val.parquet.gzip")
)

In [None]:
# Preview the first rows of the raw training split (explicit show() defaults).
train_df.show(n=20, truncate=True)

In [None]:
# Encode the string `group` column as a numeric `label` index. With the default
# stringOrderType ("frequencyDesc"), the most frequent group gets index 0.
# handleInvalid="error": the classifier only has one output per group seen in
# the training split, so an unseen group in another split must fail loudly when
# it is encoded. With "keep" it would silently become the extra index
# len(labels), which the classification head cannot represent; that would
# break the loss and accuracy later instead.
indexer = StringIndexer(inputCol="group", outputCol="label", handleInvalid="error")
indexer_model = indexer.fit(train_df)
train_df = indexer_model.transform(train_df).drop(col("group"))

train_df.show()

In [None]:
# Encode the validation split with the indexer fitted on train, so both splits
# share the same group -> label index mapping.
val_df = indexer_model.transform(val_df)
val_df = val_df.drop(col("group"))

In [None]:
# Group names ordered by their numeric index: labels_list[i] is the group
# encoded as label i. (The mapping is written to disk as id2label.json further down.)
labels_list = list(indexer_model.labels)

In [None]:
# Show the group names in index order (labels_list[i] is the group encoded as label i).
labels_list

In [None]:
# Index <-> group-name mappings; both are passed to the model config when the
# classifier is loaded.
id2label = dict(enumerate(labels_list))
label2id = {name: idx for idx, name in id2label.items()}

id2label

In [None]:
# Persist the index -> group-name mapping next to the fine-tuned model.
# Create the directory first: this cell runs before save_pretrained (which
# would otherwise create it), so on a fresh checkout open() would raise
# FileNotFoundError. JSON object keys are always strings, so convert them back
# with int() when loading this file.
model_dir = Path("models/finetuned_scibert_scivocab_uncased_8cats")
model_dir.mkdir(parents=True, exist_ok=True)
with open(model_dir / "id2label.json", "w") as f:
    json.dump(id2label, f)

In [None]:
# Collect only the model inputs (text + numeric label) to the driver as pandas
# frames (toPandas materialises the whole split in driver memory).
train_pdf, val_pdf = (
    spark_split.select("text", "label").toPandas()
    for spark_split in (train_df, val_df)
)

In [None]:
# StringIndexer emits the label index as a double; the classifier needs integer class ids.
train_pdf = train_pdf.astype({"label": int})
val_pdf = val_pdf.astype({"label": int})

In [None]:
# Wrap the pandas frames as Hugging Face datasets so they can be tokenized with .map().
train_hds, val_hds = map(Dataset.from_pandas, (train_pdf, val_pdf))

In [None]:
# SciBERT checkpoint (scientific-text vocabulary, uncased) and its matching tokenizer.
model_name = "allenai/scibert_scivocab_uncased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
# Maximum input length recorded in the tokenizer config.
# NOTE(review): can be a huge sentinel value if the checkpoint does not set it.
print(f"{tokenizer.model_max_length}")

In [None]:
# Distribution of SciBERT WordPiece token counts (not words) per training text;
# inspect it (mean, quartiles, max) before choosing the truncation length max_len.
token_lengths = train_pdf["text"].map(lambda text: len(tokenizer.tokenize(text)))
print(token_lengths.describe())

In [None]:
max_len = 256

def tokenize_function(line):
    return tokenizer(line["text"], padding= "max_length", truncation=True, max_length = max_len)

# Tokenize both splits batch-wise (train first, then validation); map() adds the
# encoding columns next to the existing text/label columns.
tokenized_train, tokenized_val = (
    split_ds.map(tokenize_function, batched=True) for split_ds in (train_hds, val_hds)
)

In [None]:
# Load pretrained SciBERT with a sequence-classification head. The number of
# outputs is derived from the label mapping instead of a hard-coded count, so it
# always matches the groups StringIndexer found in the training split.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

In [None]:
# Training configuration. Evaluation and checkpointing both run once per epoch
# (load_best_model_at_end requires matching strategies). At the end of training
# the checkpoint with the best validation accuracy is restored.
training_config = {
    "output_dir": "./results",
    "eval_strategy": "epoch",
    "save_strategy": "epoch",
    "num_train_epochs": 10,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "logging_dir": "./logs",
    "logging_steps": 10,
    "load_best_model_at_end": True,
    "metric_for_best_model": "accuracy",
}
training_args = TrainingArguments(**training_config)

In [None]:
# Accuracy metric from the Hugging Face `evaluate` library, used by compute_metrics.
import evaluate

accuracy = evaluate.load(path="accuracy")

In [None]:
def compute_metrics(eval_pred):
    """Validation accuracy for the Hugging Face Trainer.

    Args:
        eval_pred: ``EvalPrediction`` (or a ``(logits, labels)`` pair) from the
            Trainer. ``logits`` holds one row of class scores per example, or is
            a tuple whose first element does.

    Returns:
        The ``evaluate`` accuracy result, a dict with an ``"accuracy"`` key.
    """
    logits, labels = eval_pred
    # If the model returns extra outputs, the Trainer passes a tuple with the
    # logits first.
    if isinstance(logits, tuple):
        logits = logits[0]
    # The Trainer hands over NumPy arrays: take the argmax directly instead of
    # round-tripping through a torch tensor.
    predictions = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)

In [None]:
# Trainer for the unweighted baseline. Early stopping ends training once the
# validation accuracy (metric_for_best_model) has not improved for 2 evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=2)
trainer = Trainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

In [None]:
# Fine-tune the unweighted baseline and display the returned TrainOutput.
train_output = trainer.train()
train_output

In [None]:
# Save the fine-tuned classifier (weights + config with id2label/label2id) and
# its tokenizer into one directory so both reload with from_pretrained.
# NOTE(review): relies on the Trainer having restored the best checkpoint into
# this same `model` object (load_best_model_at_end=True); trainer.save_model(...)
# would make that explicit.
finetuned_dir = "models/finetuned_scibert_scivocab_uncased_8cats"
model.save_pretrained(finetuned_dir)
tokenizer.save_pretrained(finetuned_dir)

# With class weights in loss function

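To counter class imbalance (in addition to the stratified train/val/test split), the model is trained again from the pretrained checkpoint with a class-weighted cross-entropy loss. Each class is weighted with scikit-learn's `"balanced"` scheme, i.e. inversely to its frequency in the training set.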

In [None]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

In [None]:
# scikit-learn "balanced" weights: n_samples / (n_classes * count(class)), so
# rarer classes contribute more to the loss. The weights follow the order of
# np.unique, i.e. ascending label index.
observed_labels = np.unique(train_pdf["label"])
class_weights = compute_class_weight(
    class_weight="balanced", classes=observed_labels, y=train_pdf["label"]
)

class_weights

In [None]:
# Label indices present in the training split (the order class_weights follows).
np.unique(train_pdf["label"].to_numpy())

In [None]:
# Re-display the index -> group-name mapping (same as shown earlier) so the
# class weights above can be read by group name.
id2label 

In [None]:
# float32 copy of the class weights for torch's cross_entropy.
class_weights_tensor = torch.as_tensor(class_weights, dtype=torch.float32)

In [None]:
def compute_weighted_loss(outputs, labels, num_items_in_batch):
    """Class-weighted cross-entropy for ``Trainer(compute_loss_func=...)``.

    Args:
        outputs: model output; only its ``"logits"`` entry is used.
        labels: target class indices for the batch (cast to int64 here).
        num_items_in_batch: supplied by the Trainer; unused. The result is
            ``cross_entropy``'s weight-normalised mean regardless.

    Returns:
        Scalar loss tensor weighted by the global ``class_weights_tensor``.
    """
    scores = outputs["logits"]
    # Move the global class weights onto whatever device the logits live on.
    per_class_weight = class_weights_tensor.to(scores.device)
    targets = labels.long()
    return torch.nn.functional.cross_entropy(
        input=scores, target=targets, weight=per_class_weight
    )

In [None]:
# Start again from the pretrained checkpoint (not the model fine-tuned above)
# so the weighted-loss run is comparable with the baseline. The head size is
# derived from the label mapping instead of a hard-coded count.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

In [None]:

# Same setup as the baseline run, but with the class-weighted cross-entropy
# (compute_weighted_loss). Checkpoints and logs go to separate directories;
# otherwise this run's checkpoint-<step> folders would overwrite, or mix with,
# the baseline's folders in ./results.
weighted_training_args = dataclasses.replace(
    training_args,
    output_dir="./results_weighted",
    logging_dir="./logs_weighted",
)

trainer = Trainer(
    model=model,
    args=weighted_training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    compute_metrics=compute_metrics,
    processing_class=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],  # stop after 2 bad evals
    compute_loss_func=compute_weighted_loss,
)


In [None]:
# help(Trainer)

In [None]:
# Fine-tune with the class-weighted loss and display the returned TrainOutput.
weighted_train_output = trainer.train()
weighted_train_output

In [None]:
# Save the class-weighted classifier and its tokenizer side by side, separate
# from the baseline model directory.
# NOTE(review): relies on the Trainer having restored the best checkpoint into
# this same `model` object (load_best_model_at_end=True).
weighted_dir = "models/finetuned_scibert_scivocab_uncased_weighted_8cats"
model.save_pretrained(weighted_dir)
tokenizer.save_pretrained(weighted_dir)