In [None]:
import os
import pandas as pd
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

# Configuration parameters

# Fine-tuned checkpoint directories whose logits are averaged at inference.
# The suffixes presumably record each run's validation accuracy — TODO confirm.
MODEL_DIRS = [
    "./results/s1-85.27%",
    "./results/s1-85.25%",
    "./results/s2-85.225%",
    "./results/s5-85.27%",
]

# Input pickle with the unlabelled examples, and the CSV the predictions go to.
PKL_FILE = "test_unlabelled.pkl"
OUTPUT_CSV = "ensembling_output.csv"

# Batch size for the inference DataLoader and size of the classification head.
BATCH_SIZE = 256
NUM_LABELS = 4

# Prefer the GPU whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print("Loading models and tokenizers...")
# One tokenizer and one classifier per checkpoint directory; every model is
# moved to `device` and switched to eval mode (disables dropout).
tokenizers = [RobertaTokenizer.from_pretrained(model_dir) for model_dir in MODEL_DIRS]
models = [
    RobertaForSequenceClassification.from_pretrained(model_dir, num_labels=NUM_LABELS).to(device).eval()
    for model_dir in MODEL_DIRS
]

# The data is tokenized once with tokenizers[0] and the same input_ids are fed
# to every model, so all checkpoints must share an identical vocabulary.
# Fail fast here instead of letting the ensemble score mis-encoded inputs.
reference_vocab = tokenizers[0].get_vocab()
for model_dir, tokenizer in zip(MODEL_DIRS[1:], tokenizers[1:]):
    if tokenizer.get_vocab() != reference_vocab:
        raise ValueError(
            f"Tokenizer in {model_dir} does not match the one in {MODEL_DIRS[0]}; "
            "inputs are tokenized once with the first tokenizer, so every "
            "ensembled model must share the same vocabulary."
        )

# Fail early with a clear message if the input pickle is missing.
if not os.path.exists(PKL_FILE):
    raise FileNotFoundError(f"The specified .pkl file does not exist: {PKL_FILE}")

print(f"Loading dataset from: {PKL_FILE}")
# SECURITY NOTE: unpickling can execute arbitrary code; only load .pkl files
# from a trusted source.
unlabelled_data = pd.read_pickle(PKL_FILE)

# Accept either a pandas DataFrame or an already-built Hugging Face Dataset.
if isinstance(unlabelled_data, pd.DataFrame):
    print("Converting DataFrame to Dataset...")
    dataset = Dataset.from_pandas(unlabelled_data)
elif isinstance(unlabelled_data, Dataset):
    print("Loaded Dataset directly from .pkl file.")
    dataset = unlabelled_data
else:
    raise ValueError(f"Unsupported data type in {PKL_FILE}: {type(unlabelled_data)}")

# Tokenization reads examples['text']; validate the column up front rather
# than letting a KeyError surface from deep inside Dataset.map.
if "text" not in dataset.column_names:
    raise ValueError(
        f"{PKL_FILE} must contain a 'text' column; found columns: {dataset.column_names}"
    )

print("Tokenizing dataset...")


def preprocess(examples):
    """Tokenize a batch of raw texts with the first checkpoint's tokenizer.

    Sequences are truncated to the tokenizer's maximum length. Padding is
    intentionally left to the DataLoader collate function, so each inference
    batch is padded once, to its own longest sequence.
    """
    return tokenizers[0](examples['text'], truncation=True)


# Drop every original column (not only "text"), so that only tokenizer outputs
# reach the collate function; leftover columns (e.g. a pandas index column or
# string IDs) would otherwise be passed to tokenizer.pad.
dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)


def _collate_batch(features):
    """Pad a list of tokenized examples into a batch of PyTorch tensors."""
    return tokenizers[0].pad(features, return_tensors="pt")


# No shuffling (DataLoader default), so predictions come out in row order.
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=_collate_batch)

print("Running inference with ensembling...")
# One predicted class index per input row, in DataLoader order.
all_predictions = []
for batch in tqdm(data_loader, desc="Inference Progress", leave=True):
    batch_ids = batch["input_ids"].to(device)
    batch_mask = batch["attention_mask"].to(device)

    # Accumulate every model's logits in list order; gradients are not needed.
    logit_sum = None
    with torch.no_grad():
        for model in models:
            member_logits = model(input_ids=batch_ids, attention_mask=batch_mask).logits
            logit_sum = member_logits if logit_sum is None else logit_sum + member_logits

    # Ensemble by averaging raw logits, then pick the highest-scoring class.
    avg_logits = logit_sum / len(models)
    predictions = avg_logits.argmax(dim=-1)
    all_predictions.extend(predictions.cpu().numpy())

print(f"Saving predictions to {OUTPUT_CSV}")
# IDs are simply the 0-based row positions of the predictions.
output_df = pd.DataFrame(dict(ID=range(len(all_predictions)), Label=all_predictions))
# Write without the pandas index; the ID column already carries row identity.
output_df.to_csv(OUTPUT_CSV, index=False)
print(f"Inference complete. Predictions saved to: {OUTPUT_CSV}")