In [None]:
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
from datasets import Dataset
import pandas as pd
import torch
from tqdm import tqdm

In [None]:
import json

In [None]:
from sklearn.metrics import classification_report

# 01 Predict on test data with a zero-shot model

Model: [`facebook/bart-large-mnli`](https://huggingface.co/facebook/bart-large-mnli)

## 8 categories

In [None]:
# Check whether PyTorch can see a CUDA GPU before loading the zero-shot model.
torch.cuda.is_available()

In [None]:
# Load the held-out test split for the 8-category task.
test_df = pd.read_parquet("data/df_test.parquet.gzip")

# Fail fast if the columns used below are missing: "text" is fed to the
# zero-shot model and "group" is the gold label read in the evaluation section.
assert {"text", "group"} <= set(test_df.columns), test_df.columns.tolist()

In [None]:
# Inspect the loaded 8-category test split.
display(test_df)

In [None]:
# Reuse the label mapping saved with the fine-tuned 8-category model, so the
# zero-shot candidate labels use the same label names.
# JSON object keys are always strings; convert them back to integer class ids.
# JSON text is UTF-8 by spec, so pin the encoding rather than use the platform default.
with open("models/finetuned_scibert_scivocab_uncased_8cats/id2label.json", "r", encoding="utf-8") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}

id2label

In [None]:
# Order the label names by class id. Indexing 0..n-1 assumes the ids are
# contiguous, so check that explicitly instead of failing with a bare KeyError.
assert sorted(id2label) == list(range(len(id2label))), sorted(id2label)
candidate_labels = [id2label[i] for i in range(len(id2label))]
candidate_labels

In [None]:
# Load the zero-shot classification pipeline. Use the GPU when PyTorch can see
# one, and fall back to CPU instead of crashing when it cannot.
zero_shot = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [None]:
# Predict one label per test text: the first entry of the pipeline's "labels"
# list, which the pipeline returns ordered by score.
zero_shot_preds = []
batch_size = 32  # Adjust based on memory; also forwarded to the pipeline below
texts = test_df["text"].tolist()

for i in tqdm(range(0, len(texts), batch_size), desc="zero-shot (8 cats)"):
    batch = texts[i:i+batch_size]
    # Without `batch_size` the pipeline runs the model on one input at a time,
    # so the chunking above would only drive the progress bar.
    res_batch = zero_shot(batch, candidate_labels, batch_size=batch_size)

    # Some transformers versions unwrap a single-element list input into a bare
    # dict, e.g. when the final chunk holds exactly one text.
    if isinstance(res_batch, dict):
        zero_shot_preds.append(res_batch["labels"][0])
    else:
        zero_shot_preds.extend([res["labels"][0] for res in res_batch])

# Exactly one prediction per input text.
assert len(zero_shot_preds) == len(texts)

In [None]:
# Leftover exploration: uncomment to print the transformers `pipeline` docs.
# NOTE(review): this cell has no effect and can be deleted.
# help(pipeline)

In [None]:
# Attach the predictions as a new "pred" column on a copy. A plain
# `test_df_w_preds = test_df` only aliases the frame, so assigning the column
# would silently add "pred" to test_df as well.
test_df_w_preds = test_df.assign(pred=zero_shot_preds)

In [None]:
# Inspect the 8-category test split with the zero-shot predictions attached.
display(test_df_w_preds)

In [None]:
# The ".gzip" suffix promises gzip, but to_parquet defaults to snappy
# compression. Pass it explicitly so the file matches its name; reading the
# file back with pd.read_parquet is unaffected.
test_df_w_preds.to_parquet("data/df_test_pred_zeroshot_8cats.parquet.gzip", compression="gzip")

## 17 categories

In [None]:
# Load the held-out test split for the 17-category task.
test_df = pd.read_parquet("data/df_test_17cats.parquet.gzip")

# Fail fast if the columns used below are missing: "text" is fed to the
# zero-shot model and "subgroup" is the gold label read in the evaluation section.
assert {"text", "subgroup"} <= set(test_df.columns), test_df.columns.tolist()

In [None]:
# Reuse the label mapping saved with the weighted fine-tuned 17-category model,
# so the zero-shot candidate labels use the same label names.
# JSON object keys are always strings; convert them back to integer class ids.
# JSON text is UTF-8 by spec, so pin the encoding rather than use the platform default.
with open("models/finetuned_scibert_scivocab_uncased_weighted_17cats/id2label.json", "r", encoding="utf-8") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}

id2label

In [None]:
# Order the label names by class id. Indexing 0..n-1 assumes the ids are
# contiguous, so check that explicitly instead of failing with a bare KeyError.
assert sorted(id2label) == list(range(len(id2label))), sorted(id2label)
candidate_labels = [id2label[i] for i in range(len(id2label))]
candidate_labels

In [None]:
# (Re)load the same zero-shot checkpoint so this 17-category section can run on
# its own. Use the GPU when PyTorch can see one, and fall back to CPU otherwise.
zero_shot = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [None]:
# Predict one label per test text: the first entry of the pipeline's "labels"
# list, which the pipeline returns ordered by score.
zero_shot_preds = []
batch_size = 32  # Adjust based on memory; also forwarded to the pipeline below
texts = test_df["text"].tolist()

for i in tqdm(range(0, len(texts), batch_size), desc="zero-shot (17 cats)"):
    batch = texts[i:i+batch_size]
    # Without `batch_size` the pipeline runs the model on one input at a time,
    # so the chunking above would only drive the progress bar.
    res_batch = zero_shot(batch, candidate_labels, batch_size=batch_size)

    # Some transformers versions unwrap a single-element list input into a bare
    # dict, e.g. when the final chunk holds exactly one text.
    if isinstance(res_batch, dict):
        zero_shot_preds.append(res_batch["labels"][0])
    else:
        zero_shot_preds.extend([res["labels"][0] for res in res_batch])

# Exactly one prediction per input text.
assert len(zero_shot_preds) == len(texts)

In [None]:
# Attach the predictions as a new "pred" column on a copy. A plain
# `test_df_w_preds = test_df` only aliases the frame, so assigning the column
# would silently add "pred" to test_df as well.
test_df_w_preds = test_df.assign(pred=zero_shot_preds)

In [None]:
# Inspect the 17-category test split with the zero-shot predictions attached.
display(test_df_w_preds)

In [None]:
# The ".gzip" suffix promises gzip, but to_parquet defaults to snappy
# compression. Pass it explicitly so the file matches its name; reading the
# file back with pd.read_parquet is unaffected.
test_df_w_preds.to_parquet("data/df_test_pred_zeroshot_17cats.parquet.gzip", compression="gzip")

# 02 Evaluate predictions

## 8 categories

### Zero-shot

In [None]:
# Zero-shot predictions for the 8-category task (written in section 01).
test_df_w_preds_zeroshot = pd.read_parquet(
    path="data/df_test_pred_zeroshot_8cats.parquet.gzip",
)
display(test_df_w_preds_zeroshot)

In [None]:
# Gold 8-category labels vs. zero-shot predictions.
actual = test_df_w_preds_zeroshot["group"]
predicted_zeroshot = test_df_w_preds_zeroshot["pred"]

In [None]:
# sklearn classification report as a table: after transposing, each class and
# each summary entry becomes one row.
report_zeroshot_dict = classification_report(
    actual,
    predicted_zeroshot,
    output_dict=True,
)
report_zeroshot_df = pd.DataFrame(report_zeroshot_dict).T
report_zeroshot_df

### Finetuned unweighted

In [None]:
# Unweighted fine-tuned model predictions for the 8-category task.
test_df_w_preds_finetuned = pd.read_parquet(
    path="data/df_test_pred_finetuned_8cats.parquet.gzip",
)
display(test_df_w_preds_finetuned)

In [None]:
# Gold 8-category labels vs. unweighted fine-tuned predictions.
actual = test_df_w_preds_finetuned["group"]
predicted_finetuned = test_df_w_preds_finetuned["pred"]

In [None]:
# sklearn classification report as a table: after transposing, each class and
# each summary entry becomes one row.
report_finetuned_dict = classification_report(
    actual,
    predicted_finetuned,
    output_dict=True,
)
report_finetuned_df = pd.DataFrame(report_finetuned_dict).T
report_finetuned_df

### Finetuned weighted

In [None]:
# Weighted fine-tuned model, 8 categories: load, inspect, and split out gold
# labels vs. predictions.
# NOTE(review): this reuses the name `predicted_finetuned` (also set for the
# unweighted model); the report cell that follows reads it, so run cells in order.
test_df_w_preds_finetuned_w = pd.read_parquet(
    path="data/df_test_pred_finetuned_weighted_8cats.parquet.gzip",
)
display(test_df_w_preds_finetuned_w)

actual = test_df_w_preds_finetuned_w["group"]
predicted_finetuned = test_df_w_preds_finetuned_w["pred"]

In [None]:
# sklearn classification report as a table: after transposing, each class and
# each summary entry becomes one row.
report_finetuned_w_dict = classification_report(
    actual,
    predicted_finetuned,
    output_dict=True,
)
report_finetuned_w_df = pd.DataFrame(report_finetuned_w_dict).T
report_finetuned_w_df

## 17 categories

### Zero-shot

In [None]:
# Zero-shot predictions for the 17-category task (written in section 01).
test_df_w_preds_zeroshot_17cats = pd.read_parquet(
    path="data/df_test_pred_zeroshot_17cats.parquet.gzip",
)
display(test_df_w_preds_zeroshot_17cats)

In [None]:
# Gold 17-category labels ("subgroup") vs. zero-shot predictions.
actual = test_df_w_preds_zeroshot_17cats["subgroup"]
predicted_zeroshot = test_df_w_preds_zeroshot_17cats["pred"]

In [None]:
# sklearn classification report as a table: after transposing, each class and
# each summary entry becomes one row.
report_zeroshot_17cats_dict = classification_report(
    actual,
    predicted_zeroshot,
    output_dict=True,
)
report_zeroshot_17cats_df = pd.DataFrame(report_zeroshot_17cats_dict).T
report_zeroshot_17cats_df

### Finetuned weighted

In [None]:
# Weighted fine-tuned model, 17 categories: load, inspect, and split out gold
# labels ("subgroup") vs. predictions.
# NOTE(review): `predicted_finetuned` is reused from the 8-category section;
# the report cell that follows reads it, so run cells in order.
test_df_w_preds_finetuned_w_17cats = pd.read_parquet(
    path="data/df_test_pred_finetuned_weighted_17cats.parquet.gzip",
)
display(test_df_w_preds_finetuned_w_17cats)

actual = test_df_w_preds_finetuned_w_17cats["subgroup"]
predicted_finetuned = test_df_w_preds_finetuned_w_17cats["pred"]

In [None]:
# sklearn classification report as a table: after transposing, each class and
# each summary entry becomes one row.
report_finetuned_w_17cats_dict = classification_report(
    actual,
    predicted_finetuned,
    output_dict=True,
)
report_finetuned_w_17cats_df = pd.DataFrame(report_finetuned_w_17cats_dict).T
report_finetuned_w_17cats_df