In [None]:
# Mount Google Drive: the dataset CSVs and the checkpoint directory used by
# this notebook both live under drive/MyDrive.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    get_scheduler
)
from torch.optim import AdamW
from tqdm import tqdm
import os
import shutil
import numpy as np

In [None]:
# Select the targeted dataset.
#
# Keys are the dataset identifiers that prefix the CSV file names. The paths
# are relative, so they resolve against the current working directory.
#
# Second-level datasets (active):  "<id>_xydata.csv"
# First-level datasets:            "<id>_xydata_l1_l2.csv" in the same folder;
#                                  swap the file-name suffix below to use them.
DATASET_MAP = {
    dataset_id: f"drive/MyDrive/WoS/original_dataset/{dataset_id}_xydata.csv"
    for dataset_id in ("46985", "11967", "5736")
}

In [None]:
# Key into DATASET_MAP; valid choices are "46985", "11967" and "5736".
selected_dataset = "46985"

In [None]:
# Read the CSV registered for the selected dataset key.
df = pd.read_csv(DATASET_MAP[selected_dataset])
print(f"Loaded dataset {selected_dataset} with shape {df.shape}")

# The classifier heads get one output per distinct value of the second-level
# label column "Y". For first-level training use df["YL1"].nunique() instead.
num_labels = df["Y"].nunique()
print("Number of labels:", num_labels)

In [None]:
# First level: pass df["YL1"] instead of df["Y"] as both the target and the
# `stratify` column of the first split below.

# Second level
# Stratified 80/20 split into the training set and a hold-out pool.
train_texts, holdout_texts, train_labels, holdout_labels = train_test_split(
    df["X"], df["Y"],
    test_size=0.2,
    random_state=42,
    stratify=df["Y"]
)

# Split the hold-out pool into DISJOINT test (80%) and validation (20%) parts.
# The validation set used to be sampled from the test set while the test set
# kept those rows, so early stopping / alpha selection / learning-rate
# selection were all tuned on data that was later reported as "test".
test_texts, val_texts, test_labels, val_labels = train_test_split(
    holdout_texts, holdout_labels,
    test_size=0.2,
    random_state=42,
    stratify=holdout_labels
)

# Guard against the leak reappearing: no row may sit in both splits.
assert set(val_texts.index).isdisjoint(test_texts.index), "validation and test overlap"

In [None]:
# Report how many samples ended up in each split (one line per split).
print(
    f"Train size: {len(train_texts)}",
    f"Validation size: {len(val_texts)}",
    f"Test size: {len(test_texts)}",
    sep="\n",
)


In [None]:
# Model names handed to `from_pretrained` for the two co-trained encoders.
MODEL_MAP = dict(
    scibert="allenai/scibert_scivocab_uncased",
    bert="bert-base-uncased",
)
scibert_name, bert_name = MODEL_MAP["scibert"], MODEL_MAP["bert"]

In [None]:
# One tokenizer per encoder, each loaded from the same model name as the
# encoder it will feed.
tokenizer_sci, tokenizer_bert = (
    AutoTokenizer.from_pretrained(model_name)
    for model_name in (scibert_name, bert_name)
)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
class TextDataset(Dataset):
    """Dataset that tokenises the same texts with two different tokenizers.

    All texts are tokenised once, eagerly, in ``__init__`` (padded to the
    longest sequence and truncated to ``max_length`` tokens), so indexing
    afterwards is a plain tensor lookup.

    Args:
        texts: Iterable of strings (e.g. a pandas Series or a list).
        labels: Integer class ids aligned with ``texts``; any sequence that
            ``numpy.asarray`` accepts (pandas Series, list, ndarray).
        tokenizer_sci: Callable tokenizer whose output feeds the SciBERT model.
        tokenizer_bert: Callable tokenizer whose output feeds the BERT model.
        max_length: Truncation length passed to both tokenizers.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer_sci, tokenizer_bert, max_length=256):
        # Materialise once: both tokenizers need a list, and `texts` may be a
        # one-shot iterable.
        texts = list(texts)
        # np.asarray handles Series, lists and arrays alike (the previous
        # `labels.values` only worked for pandas objects).
        labels = np.asarray(labels)
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, got {len(texts)} and {len(labels)}"
            )

        tokenizer_kwargs = dict(padding=True, truncation=True, max_length=max_length, return_tensors="pt")
        self.encodings_sci = tokenizer_sci(texts, **tokenizer_kwargs)
        self.encodings_bert = tokenizer_bert(texts, **tokenizer_kwargs)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return both tokenisations and the label of sample ``idx`` as a dict of tensors."""
        return {
            "input_ids_sci": self.encodings_sci["input_ids"][idx],
            "attention_mask_sci": self.encodings_sci["attention_mask"][idx],
            "input_ids_bert": self.encodings_bert["input_ids"][idx],
            "attention_mask_bert": self.encodings_bert["attention_mask"][idx],
            "labels": self.labels[idx]
        }

# Tokenise each split up front (train, then validation, then test).
train_dataset, val_dataset, test_dataset = (
    TextDataset(split_texts, split_labels, tokenizer_sci, tokenizer_bert)
    for split_texts, split_labels in (
        (train_texts, train_labels),
        (val_texts, val_labels),
        (test_texts, test_labels),
    )
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=16)
    for eval_dataset in (val_dataset, test_dataset)
)

In [None]:
# ---- Hyperparameters for the co-training sweep ----
learning_rates = [2e-5, 5e-6, 2e-6, 1e-6]  # one full training run per value
patience = 7                               # epochs without val improvement before early stopping
epochs = 25                                # maximum epochs per learning rate

# Number of optimizer steps of linear warm-up. `get_scheduler` expects an
# integer step count, not a ratio: the previous float value 1e-4 gave no
# warm-up and only forced the learning rate of the very first step to zero.
# NOTE(review): set a positive integer here if real warm-up is wanted.
num_warmup_steps = 0

# curriculum pseudo-label thresholds (confidence cut-off moves from the
# initial to the final value over the epochs)
initial_threshold = 0.95
final_threshold = 0.80

# adaptive ensemble candidates (weight given to the SciBERT logits; BERT gets 1 - alpha)
ensemble_candidates = [0.3, 0.5, 0.7]

# ---- Sweep-level bookkeeping: best result across all learning rates ----
best_overall_f1 = 0.0
best_model_path = "/content/drive/MyDrive/WoS/best_models_co_training_2"
os.makedirs(best_model_path, exist_ok=True)

# Defined before the sweep so the summary at the end never reads an undefined
# name (or silently reuses a value left over from an earlier run of this cell).
best_lr = None
best_model_final_path = None
best_alpha_final = None

# ================= CO-TRAINING LOOP ====================
for lr in learning_rates:
    print(f"\n=== Training with learning rate {lr} ===")

    # Re-seed torch for every learning rate so each run starts from the same
    # RNG state (classifier-head initialisation, shuffling, dropout); this
    # makes the runs comparable and the sweep repeatable.
    torch.manual_seed(42)

    # Fresh models, optimizers and schedulers for this learning rate.
    scibert = AutoModelForSequenceClassification.from_pretrained(scibert_name, num_labels=num_labels).to(device)
    bert = AutoModelForSequenceClassification.from_pretrained(bert_name, num_labels=num_labels).to(device)

    optim_scibert = AdamW(scibert.parameters(), lr=lr, eps=1e-8)
    optim_bert = AdamW(bert.parameters(), lr=lr, eps=1e-8)

    num_training_steps = epochs * len(train_loader)
    sched_scibert = get_scheduler("linear", optim_scibert, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
    sched_bert = get_scheduler("linear", optim_bert, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

    # Per-learning-rate early-stopping state and checkpoint directory.
    best_val_f1 = 0.0
    best_alpha = 0.7
    patience_counter = 0
    local_best_path = f"{best_model_path}/cotrain_{selected_dataset}_lr{lr}"
    os.makedirs(local_best_path, exist_ok=True)

    for epoch in range(epochs):
        # curriculum pseudo-threshold: linear interpolation from
        # initial_threshold (first epoch) to final_threshold (last epoch)
        if epochs > 1:
            pseudo_threshold = initial_threshold - (epoch / (epochs - 1)) * (initial_threshold - final_threshold)
        else:
            pseudo_threshold = final_threshold
        print(f"Epoch {epoch+1}/{epochs} - Pseudo-threshold: {pseudo_threshold:.3f}")

        scibert.train()
        bert.train()
        loop = tqdm(train_loader, leave=True)

        for batch in loop:
            batch_gpu = {k: v.to(device) for k, v in batch.items()}

            # labeled loss: average of both models' supervised losses
            out_sci = scibert(input_ids=batch_gpu["input_ids_sci"], attention_mask=batch_gpu["attention_mask_sci"], labels=batch_gpu["labels"])
            out_bert = bert(input_ids=batch_gpu["input_ids_bert"], attention_mask=batch_gpu["attention_mask_bert"], labels=batch_gpu["labels"])
            loss_labeled = 0.5 * (out_sci.loss + out_bert.loss)

            # pseudo-label exchange: each model's confident predictions on this
            # same batch become training targets for the other model
            with torch.no_grad():
                probs_sci = torch.softmax(out_sci.logits, dim=-1)
                probs_bert = torch.softmax(out_bert.logits, dim=-1)
                max_probs_sci, pseudo_labels_sci = torch.max(probs_sci, dim=-1)
                max_probs_bert, pseudo_labels_bert = torch.max(probs_bert, dim=-1)
                mask_sci = max_probs_sci > pseudo_threshold
                mask_bert = max_probs_bert > pseudo_threshold

            # BERT learns from the samples SciBERT is confident about
            pseudo_loss_bert = torch.tensor(0.0, device=device)
            if mask_sci.any():
                pseudo_batch_bert = {
                    "input_ids": batch_gpu["input_ids_bert"][mask_sci],
                    "attention_mask": batch_gpu["attention_mask_bert"][mask_sci],
                    "labels": pseudo_labels_sci[mask_sci]
                }
                # mean teacher confidence as a plain float: scales the loss
                # without contributing any gradient itself
                conf_weight = max_probs_sci[mask_sci].mean().item()
                pseudo_loss_bert = conf_weight * bert(**pseudo_batch_bert).loss

            # SciBERT learns from the samples BERT is confident about
            pseudo_loss_sci = torch.tensor(0.0, device=device)
            if mask_bert.any():
                pseudo_batch_sci = {
                    "input_ids": batch_gpu["input_ids_sci"][mask_bert],
                    "attention_mask": batch_gpu["attention_mask_sci"][mask_bert],
                    "labels": pseudo_labels_bert[mask_bert]
                }
                conf_weight = max_probs_bert[mask_bert].mean().item()
                pseudo_loss_sci = conf_weight * scibert(**pseudo_batch_sci).loss

            total_loss = loss_labeled + 0.5 * (pseudo_loss_sci + pseudo_loss_bert)

            # one backward pass over the joint loss updates both models
            optim_scibert.zero_grad()
            optim_bert.zero_grad()
            total_loss.backward()
            optim_scibert.step()
            sched_scibert.step()
            optim_bert.step()
            sched_bert.step()

            loop.set_description(f"Epoch {epoch+1}/{epochs}")
            loop.set_postfix(loss=total_loss.item())

        # ---- Validation with adaptive ensemble ----
        scibert.eval()
        bert.eval()
        all_logits_sci, all_logits_bert, all_labels = [], [], []

        with torch.no_grad():
            for batch in val_loader:
                batch_gpu = {k: v.to(device) for k, v in batch.items()}
                out_sci = scibert(input_ids=batch_gpu["input_ids_sci"], attention_mask=batch_gpu["attention_mask_sci"])
                out_bert = bert(input_ids=batch_gpu["input_ids_bert"], attention_mask=batch_gpu["attention_mask_bert"])
                all_logits_sci.append(out_sci.logits.cpu().numpy())
                all_logits_bert.append(out_bert.logits.cpu().numpy())
                all_labels.append(batch_gpu["labels"].cpu().numpy())

        all_logits_sci = np.vstack(all_logits_sci)
        all_logits_bert = np.vstack(all_logits_bert)
        all_labels = np.concatenate(all_labels)

        # Try every candidate mixing weight and keep the one with the highest
        # micro-F1 on the validation set (ties keep the earlier candidate).
        best_alpha_epoch, best_f1_epoch = 0.7, 0.0
        for alpha in ensemble_candidates:
            beta = 1 - alpha
            logits = alpha * all_logits_sci + beta * all_logits_bert
            preds = logits.argmax(axis=-1)
            f1 = f1_score(all_labels, preds, average="micro")
            if f1 > best_f1_epoch:
                best_f1_epoch = f1
                best_alpha_epoch = alpha

        print(f"Epoch {epoch+1} - Val micro-F1: {best_f1_epoch:.4f} (alpha={best_alpha_epoch:.2f})")

        # Checkpoint on improvement, otherwise count towards early stopping.
        if best_f1_epoch > best_val_f1:
            best_val_f1 = best_f1_epoch
            best_alpha = best_alpha_epoch
            patience_counter = 0
            scibert.save_pretrained(f"{local_best_path}/scibert")
            bert.save_pretrained(f"{local_best_path}/bert")
            print(f"  New best saved (micro-F1 {best_val_f1:.4f}, alpha {best_alpha:.2f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("  Early stopping triggered.")
                break

    # After all epochs for this LR: keep only the checkpoints of the best run.
    if best_val_f1 > best_overall_f1:
        # The previous winner is now superseded; remove its checkpoints so the
        # output folder does not accumulate stale models.
        if best_model_final_path is not None and best_model_final_path != local_best_path:
            shutil.rmtree(best_model_final_path, ignore_errors=True)
        best_overall_f1 = best_val_f1
        best_lr = lr
        best_model_final_path = local_best_path
        best_alpha_final = best_alpha
    else:
        shutil.rmtree(local_best_path)

if best_model_final_path is None:
    # No run ever beat the initial best_overall_f1 of 0.0, so nothing was kept.
    print("\n=== No learning rate improved on a validation micro-F1 of 0; no model was kept ===")
else:
    print(f"\n=== Best overall learning rate: {best_lr} ===")
    print(f"=== Best co-trained models in {best_model_final_path} with val micro-F1 {best_overall_f1:.4f}, alpha={best_alpha_final:.2f} ===")


In [None]:
# Reload the checkpoints kept for the best learning rate, in inference mode.
scibert = AutoModelForSequenceClassification.from_pretrained(f"{best_model_final_path}/scibert").to(device).eval()
bert = AutoModelForSequenceClassification.from_pretrained(f"{best_model_final_path}/bert").to(device).eval()

# Predict the test set with the weighted logit ensemble
# (best_alpha_final for SciBERT, the remainder for BERT).
preds, true_labels = [], []
with torch.no_grad():
    for batch in test_loader:
        batch_gpu = {name: tensor.to(device) for name, tensor in batch.items()}
        out_sci = scibert(
            input_ids=batch_gpu["input_ids_sci"],
            attention_mask=batch_gpu["attention_mask_sci"],
        )
        out_bert = bert(
            input_ids=batch_gpu["input_ids_bert"],
            attention_mask=batch_gpu["attention_mask_bert"],
        )
        logits = best_alpha_final * out_sci.logits + (1 - best_alpha_final) * out_bert.logits
        preds.extend(logits.argmax(dim=-1).cpu().numpy())
        true_labels.extend(batch_gpu["labels"].cpu().numpy())

# Per-class precision/recall/F1 as a table (one row per class or average),
# written to Drive as CSV.
report_path = f"/content/drive/MyDrive/WoS/best_models_co_training_2/cotrain_{selected_dataset}_classification_report.csv"
report_dict = classification_report(true_labels, preds, output_dict=True)
report_df = pd.DataFrame(report_dict).T
report_df.to_csv(report_path)

print("\n=== Test Set Report (Co-trained Adaptive Ensemble) ===")
print(report_df)
print(f"\nClassification report saved to: {report_path}")
