In [None]:
import os
import pandas as pd
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

# --- Run configuration -------------------------------------------------------
CHECKPOINT_DIR = "./checkpoint-2200"  # source of both the tokenizer and the model
PKL_FILE = "test_unlabelled.pkl"  # pickled input data
OUTPUT_CSV = "inference_output.csv"  # where the predictions are written
BATCH_SIZE = 256  # rows per DataLoader batch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Tokenizer and model -----------------------------------------------------
print(f"Loading model from checkpoint: {CHECKPOINT_DIR}")
tokenizer = RobertaTokenizer.from_pretrained(CHECKPOINT_DIR)
model = RobertaForSequenceClassification.from_pretrained(
    CHECKPOINT_DIR,
    num_labels=4,
)
model = model.to(device)
model.eval()  # inference only

# --- Input data --------------------------------------------------------------
# Fail early with an explicit message when the input file is missing.
if not os.path.exists(PKL_FILE):
    raise FileNotFoundError(f"The specified .pkl file does not exist: {PKL_FILE}")

print(f"Loading dataset from: {PKL_FILE}")
# NOTE: unpickling can execute arbitrary code; only load files you trust.
unlabelled_data = pd.read_pickle(PKL_FILE)

# The pickle may hold a ready-made Dataset or a pandas DataFrame; anything
# else is rejected.
if isinstance(unlabelled_data, Dataset):
    print("Loaded Dataset directly from .pkl file.")
    dataset = unlabelled_data
elif isinstance(unlabelled_data, pd.DataFrame):
    print("Converting DataFrame to Dataset...")
    dataset = Dataset.from_pandas(unlabelled_data)
else:
    raise ValueError(f"Unsupported data type in {PKL_FILE}: {type(unlabelled_data)}")

def preprocess(examples):
    """Tokenize the ``text`` field of *examples* with the module-level tokenizer.

    Truncation and padding are both enabled. The tokenizer's output is
    returned unchanged so it can be used directly as a ``Dataset.map``
    callback.
    """
    texts = examples['text']
    encoded = tokenizer(texts, truncation=True, padding=True)
    return encoded

# --- Tokenization ------------------------------------------------------------
print("Tokenizing dataset...")
# Drop every original column, not only "text". Only the tokenizer outputs are
# used below, and any leftover column would be handed to tokenizer.pad() by the
# collate function, where non-numeric values (e.g. string metadata) cannot be
# converted to tensors.
dataset = dataset.map(
    preprocess,
    batched=True,
    remove_columns=dataset.column_names,
)

# --- Batching ----------------------------------------------------------------
# Each batch is padded to a common length and returned as PyTorch tensors.
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    collate_fn=lambda x: tokenizer.pad(x, return_tensors="pt"),
)

# --- Inference ---------------------------------------------------------------
print("Running inference...")
all_predictions = []  # one predicted class index per input row, in input order
for batch in tqdm(data_loader, desc="Inference Progress", leave=True):
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    # No gradients are needed for prediction; keep the argmax inside the same
    # context so nothing in the loop tracks autograd state.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = torch.argmax(outputs.logits, dim=-1)
    # Collect plain Python ints rather than a list of numpy scalars.
    all_predictions.extend(predictions.cpu().tolist())

# --- Output ------------------------------------------------------------------
print(f"Saving predictions to {OUTPUT_CSV}")
# IDs are the 0-based positions of the rows in the processed dataset.
output_df = pd.DataFrame({
    "ID": range(len(all_predictions)),
    "Label": all_predictions,
})
output_df.to_csv(OUTPUT_CSV, index=False)
print(f"Inference complete. Predictions saved to: {OUTPUT_CSV}")