# BERT Model Testing and Text Modification Experiments

This notebook loads the pre-trained BERT model and saved test data from `bert_finetuning_pipeline.ipynb` to conduct text modification experiments and accuracy comparisons.

## Workflow:
1. Load saved model and test data
2. Establish baseline accuracy
3. Apply text modifications to test data
4. Compare modified test accuracies with baseline
5. Analyze results and impacts

**Prerequisites**: Run `bert_finetuning_pipeline.ipynb` first to train and save the model.

## 1. Setup and Imports

In [1]:
# !pip install nltk
# Cell for installing new dependencies
# !pip install tensorflow tensorflow-hub nltk

In [6]:
import os

# ====================================
# CONFIGURATION - Path Variables
# ====================================
# Every file location used by the notebook is defined in this cell.

# Base directory for the evaluation CSVs. Set the EVAL_DATASET_DIR environment
# variable to point the notebook at another copy of the data without editing
# it; the original location is kept as the default.
DATASET_BASE_PATH = os.environ.get(
    "EVAL_DATASET_DIR",
    "/home/jivnesh/Harshit_Surge/dataset/eval_dataset",
)

# Individual dataset paths
RAID_EVAL_PATH = os.path.join(DATASET_BASE_PATH, "raid_eval.csv")
M4_EVAL_PATH = os.path.join(DATASET_BASE_PATH, "m4_eval.csv")
CHEAT_EVAL_PATH = os.path.join(DATASET_BASE_PATH, "cheat_eval.csv")
HC3_EVAL_PATH = os.path.join(DATASET_BASE_PATH, "hc3_eval.csv")
MAGE_EVAL_PATH = os.path.join(DATASET_BASE_PATH, "mage_eval.csv")

# Model checkpoint paths (relative to the notebook's working directory)
BEST_MODEL_CHECKPOINT = "best_model.pt"
BEST_SIMPLEBERT_CHECKPOINT = "best_simplebert.pt"

# Optional: dictionary for easy dataset access by name
DATASET_PATHS = {
    'raid_eval': RAID_EVAL_PATH,
    'm4_eval': M4_EVAL_PATH,
    'cheat_eval': CHEAT_EVAL_PATH,
    'hc3_eval': HC3_EVAL_PATH,
    'mage_eval': MAGE_EVAL_PATH,
}

print("Configuration loaded:")
print(f"  Dataset base path: {DATASET_BASE_PATH}")
print(f"  Model checkpoint: {BEST_MODEL_CHECKPOINT}")
print(f"  Simple BERT checkpoint: {BEST_SIMPLEBERT_CHECKPOINT}")


Configuration loaded:
  Dataset base path: /home/jivnesh/Harshit_Surge/dataset/eval_dataset
  Model checkpoint: best_model.pt
  Simple BERT checkpoint: best_simplebert.pt


In [2]:
# Cell 1 - Imports and setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import time
import datetime
import random
import os
import re
import nltk
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoConfig
)
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Set random seeds for reproducibility (Python, NumPy and every torch device)
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_val)

# Keep the second GPU ("cuda:1") as the preferred device, but fall back to
# "cuda:0" on single-GPU machines where index 1 does not exist, and to the
# CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(f"Using PyTorch {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Report the GPU that was actually selected rather than always device 0.
    print(f"GPU: {torch.cuda.get_device_name(device)}")

print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using PyTorch 2.8.0+cu128
CUDA available: True
GPU: NVIDIA A100 80GB PCIe
Using device: cuda:1


In [3]:
#--------------------------------
#  Transformer parameters
#--------------------------------
# Maximum number of tokens per example and the default batch size.
max_seq_length = 64
batch_size = 128

#--------------------------------
#  GAN-BERT specific parameters
#--------------------------------
# Number of hidden layers in the generator; each one is as wide as the
# output space.
num_hidden_layers_g = 2
# Number of hidden layers in the discriminator; each one is as wide as the
# input space.
num_hidden_layers_d = 2
# Size of the noise vectors fed to the generator.
noise_size = 150
# Dropout applied to the discriminator's input vectors.
out_dropout_rate = 0.2

# Replicate labeled data to balance poorly represented datasets,
# e.g., less than 1% of labeled material.
apply_balance = True

#--------------------------------
#  Optimization parameters
#--------------------------------
learning_rate_discriminator = learning_rate_generator = 5e-5
epsilon = 1e-8
num_train_epochs = 10
multi_gpu = True
# Learning-rate scheduler
apply_scheduler = False
warmup_proportion = 0.1
# Logging frequency (in steps)
print_each_n_step = 10
# Class names, indexed by label id
label_list = ["human", "ai"]

#--------------------------------
#  Adopted Transformer model
#--------------------------------
# Any Hugging Face transformer compatible with GAN-BERT can be used; the
# candidates below are kept for reference.

# model_name = "bert-base-cased"
# model_name = "bert-base-uncased"
# model_name = "roberta-base"
# model_name = "albert-base-v2"
# model_name = "xlm-roberta-base"
# model_name = "amazon/bort"

#--------------------------------
#  Reference GAN-BERT implementation
#--------------------------------
# ! git clone https://github.com/crux82/ganbert

In [4]:
import gc

# Free as much GPU memory as possible before the models are loaded.

# Hand cached allocator blocks back to the driver and wait for pending work.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

# Run a full Python garbage-collection pass.
gc.collect()

print("GPU memory cleared successfully!")
if torch.cuda.is_available():
    # Both figures are reported in MiB for the notebook-level `device`.
    print(
        f"GPU memory allocated: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB\n"
        f"GPU memory cached: {torch.cuda.memory_reserved(device) / 1024**2:.2f} MB"
    )

GPU memory cleared successfully!
GPU memory allocated: 0.00 MB
GPU memory cached: 0.00 MB


## Load Trained GANBERT From Checkpoint (Evaluation Only)
The following cells reconstruct the trained transformer + generator + discriminator from a saved checkpoint (e.g. `best_model.pt`) and provide a helper to run evaluation using existing `evaluate_on_dataloader` or `evaluate` functions. Adjust `checkpoint_path` as needed.

In [5]:
#------------------------------
#   The Generator as in 
#   https://www.aclweb.org/anthology/2020.acl-main.191/
#   https://github.com/crux82/ganbert
#------------------------------
class Generator(nn.Module):
    """GAN-BERT generator: an MLP turning noise vectors into fake representations.

    Args:
        noise_size: width of the input noise vectors.
        output_size: width of the produced representation.
        hidden_sizes: list with the width of each hidden layer.
        dropout_rate: dropout probability applied after every hidden layer.
    """

    def __init__(self, noise_size=100, output_size=512, hidden_sizes=[512], dropout_rate=0.1):
        super().__init__()
        widths = [noise_size] + hidden_sizes
        stack = []
        # One Linear -> LeakyReLU -> Dropout group per consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            stack.append(nn.Dropout(dropout_rate))
        # Final projection onto the output space.
        stack.append(nn.Linear(widths[-1], output_size))
        self.layers = nn.Sequential(*stack)

    def forward(self, noise):
        """Return the generated representation for a batch of noise vectors."""
        return self.layers(noise)

#------------------------------
#   The Discriminator
#   https://www.aclweb.org/anthology/2020.acl-main.191/
#   https://github.com/crux82/ganbert
#------------------------------
class Discriminator(nn.Module):
    """GAN-BERT discriminator: an MLP scoring representations over the real
    classes plus one extra "fake" class.

    Args:
        input_size: width of the incoming representation.
        hidden_sizes: list with the width of each hidden layer.
        num_labels: number of real classes; the output layer has num_labels + 1 units.
        dropout_rate: dropout probability for the input and every hidden layer.
    """

    def __init__(self, input_size=512, hidden_sizes=[512], num_labels=2, dropout_rate=0.1):
        super().__init__()
        self.input_dropout = nn.Dropout(p=dropout_rate)
        widths = [input_size] + hidden_sizes
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            stack.append(nn.Dropout(dropout_rate))

        self.layers = nn.Sequential(*stack)  # hidden feature extractor
        # +1 output unit for the probability of the sample being fake/real.
        self.logit = nn.Linear(widths[-1], num_labels + 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        """Return (last hidden features, logits, softmax probabilities)."""
        last_rep = self.layers(self.input_dropout(input_rep))
        logits = self.logit(last_rep)
        return last_rep, logits, self.softmax(logits)

In [7]:
# Reusable evaluation function for any dataloader
from typing import Dict, Any

def evaluate_on_dataloader(eval_name: str, eval_dataloader) -> Dict[str, Any]:
    """Evaluate the restored GAN-BERT (transformer + discriminator) on a dataloader.

    Relies on the notebook-level globals `transformer`, `discriminator`,
    `generator`, `device` and `label_list`.

    Args:
        eval_name: label used in the printed report and in the returned dict.
        eval_dataloader: iterable of batches whose first three items are
            (input_ids, attention_mask, label_ids).

    Returns:
        Dict with keys 'name', 'accuracy', 'classification_report' (dict),
        'confusion_matrix' (nested list) and 'false_positive_rate'.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    transformer.eval()
    discriminator.eval()
    generator.eval()

    pred_batches = []
    label_batches = []

    with torch.no_grad():
        for batch in eval_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)

            outputs = transformer(
                b_input_ids,
                attention_mask=b_input_mask,
                return_dict=True
            )
            last_hidden_state = outputs.last_hidden_state
            # Sentence representation = [CLS] embedding ++ masked mean of all tokens.
            cls_embed = last_hidden_state[:, 0, :]
            mask_exp = b_input_mask.unsqueeze(-1).float()
            # Clamp the token count so a fully masked row cannot divide by zero.
            token_counts = torch.clamp(mask_exp.sum(1), min=1e-9)
            mean_embed = (last_hidden_state * mask_exp).sum(1) / token_counts
            sent_rep = torch.cat([cls_embed, mean_embed], dim=1)

            _, logits, _ = discriminator(sent_rep)
            # Drop the last logit (the extra "fake" class) before predicting.
            filtered_logits = logits[:, 0:-1]
            _, preds = torch.max(filtered_logits, 1)

            # Collect whole batches instead of one 0-d tensor per example.
            pred_batches.append(preds.detach().cpu())
            label_batches.append(b_labels.detach().cpu())

    if not pred_batches:
        raise ValueError(f"[{eval_name}] evaluation dataloader produced no batches.")

    all_preds = torch.cat(pred_batches).numpy()
    all_labels_ids = torch.cat(label_batches).numpy()

    # Passing the full label set keeps the report and the confusion matrix
    # well-defined even when only one class occurs in this dataset; without it
    # classification_report rejects a target_names list of a different length.
    class_ids = list(range(len(label_list)))

    acc = np.sum(all_preds == all_labels_ids) / len(all_preds)
    report = classification_report(
        all_labels_ids,
        all_preds,
        labels=class_ids,
        target_names=label_list,
        zero_division=0,
        output_dict=True,
    )
    cm = confusion_matrix(all_labels_ids, all_preds, labels=class_ids)

    print(f"\n[{eval_name}] Accuracy: {acc:.3f}")
    print(f"[{eval_name}] Classification Report:")
    from pprint import pprint
    pprint(report)
    print(f"[{eval_name}] Confusion Matrix:\n{cm}")

    # Also print FPR explicitly: share of label-0 examples predicted as label 1.
    fpr = np.sum((all_preds == 1) & (all_labels_ids == 0)) / max(1, np.sum(all_labels_ids == 0))
    print(f"[{eval_name}] False Positive Rate: {fpr:.3f}")

    return {
        'name': eval_name,
        'accuracy': float(acc),
        'classification_report': report,
        'confusion_matrix': cm.tolist(),
        'false_positive_rate': float(fpr),
    }


## Dataloaders

In [8]:
def generate_data_loader(df,batch_size=32, do_shuffle=False, balance_label_examples=False):
    """
    Build a DataLoader from a DataFrame with columns:
      - text (str)
      - label (int: 0=human, 1=ai)

    Uses the notebook-level `tokenizer`, `max_seq_length` and (only when
    balancing) `seed_val`.

    Args:
        df: DataFrame with `text` and `label` columns.
        batch_size: number of examples per batch.
        do_shuffle: use a RandomSampler instead of a SequentialSampler.
        balance_label_examples: oversample the minority class (with
            replacement) until both classes have the same size; only applied
            when exactly two classes are present.

    Returns batches of:
      (input_ids, attention_mask, label_ids, label_mask)

    label_mask is all ones (fully labeled dataset).
    """
    # Optional class balancing (simple minority oversampling)
    if balance_label_examples:
        class_counts = df['label'].value_counts()
        if len(class_counts) == 2:
            max_count = class_counts.max()
            dfs = []
            for lbl, cnt in class_counts.items():
                sub = df[df.label == lbl]
                if cnt < max_count:
                    reps = max_count - cnt
                    # oversample with replacement
                    extra = sub.sample(reps, replace=True, random_state=seed_val)
                    sub = pd.concat([sub, extra], ignore_index=True)
                dfs.append(sub)
            df = pd.concat(dfs, ignore_index=True).sample(frac=1, random_state=seed_val).reset_index(drop=True)

    texts = df['text'].tolist()
    labels = df['label'].tolist()

    # Padding positions are identified by the tokenizer's own pad id. BERT uses
    # 0, but e.g. RoBERTa pads with 1 and uses 0 for <s>, so a hard-coded
    # "id > 0" test would mask the first token and attend to the padding.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0

    input_ids = []
    attention_masks = []
    for t in texts:
        encoded = tokenizer.encode(
            t,
            add_special_tokens=True,
            max_length=max_seq_length,
            padding="max_length",
            truncation=True
        )
        input_ids.append(encoded)
        attention_masks.append([int(tok_id != pad_id) for tok_id in encoded])

    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)
    label_ids = torch.tensor(labels, dtype=torch.long)
    label_mask_array = torch.ones(len(labels), dtype=torch.bool)

    dataset = TensorDataset(input_ids, attention_masks, label_ids, label_mask_array)

    sampler_cls = RandomSampler if do_shuffle else SequentialSampler
    return DataLoader(dataset, sampler=sampler_cls(dataset), batch_size=batch_size)


In [9]:
import pandas as pd
# Load the five evaluation sets from the paths defined in the configuration cell.
raid_eval = pd.read_csv(RAID_EVAL_PATH)
# RAID keeps the class as a string in its "models" column; derive the integer
# `label` column the dataloader expects. Series.map is used instead of
# Series.replace, which emits a FutureWarning for this str -> int conversion
# and would silently keep any unmapped string.
raid_eval["label"] = raid_eval["models"].map({"human": 0, "ai": 1})
if raid_eval["label"].isna().any():
    raise ValueError("raid_eval['models'] contains values other than 'human'/'ai'; cannot derive labels.")
raid_eval["label"] = raid_eval["label"].astype("int64")
m4_eval = pd.read_csv(M4_EVAL_PATH)
cheat_eval = pd.read_csv(CHEAT_EVAL_PATH)
hc3_eval = pd.read_csv(HC3_EVAL_PATH)
mage_eval = pd.read_csv(MAGE_EVAL_PATH)

  raid_eval["label"]=raid_eval["models"].replace({"human":0,"ai":1})


In [10]:
def evaluate_on_df(eval_name: str, eval_df) -> Dict[str, Any]:
    """Evaluate the GAN-BERT model on a DataFrame.

    Builds an unshuffled, unbalanced dataloader (batch size 32) from `eval_df`
    via `generate_data_loader` and returns whatever `evaluate_on_dataloader`
    returns for it.
    """
    print(f"Evaluating on {eval_name} dataset with {len(eval_df)} samples.")
    return evaluate_on_dataloader(
        eval_name,
        generate_data_loader(
            eval_df,
            batch_size=32,
            do_shuffle=False,
            balance_label_examples=False,
        ),
    )

In [12]:
# Load-and-evaluate snippet for GANBERT checkpoint
import torch, os
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig

# --- CONFIGURE ---
# GAN-BERT checkpoint to restore, taken from the configuration cell.
checkpoint_path = BEST_MODEL_CHECKPOINT  # UPDATE to your actual saved file
# NOTE(review): this rebinds the notebook-level `device` chosen in the setup
# cell to the default CUDA device (or the CPU) — confirm that is intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Passed to torch.load below: None keeps torch.load's default tensor placement
# when CUDA is available; otherwise every tensor is remapped to the CPU.
map_location = None if torch.cuda.is_available() else torch.device('cpu')

# --- Helper: DataParallel agnostic loader ---
def load_state_into_model(model, state_dict):
    """Load `state_dict` into `model`, agnostic to DataParallel wrapping.

    nn.DataParallel prefixes every parameter name with ``module.``. When the
    checkpoint and the model disagree about that prefix, the keys are adjusted
    before loading. Only the leading prefix is touched, so names that merely
    contain ``module.`` (e.g. ``submodule.weight``) stay intact.

    Args:
        model: the nn.Module to populate.
        state_dict: mapping of parameter names to tensors; None or an empty
            mapping is treated as "nothing to load".

    Returns:
        The same `model` instance.

    Raises:
        RuntimeError: propagated from `model.load_state_dict` when the keys
            still do not match after prefix reconciliation.
    """
    if not state_dict:  # None or empty dict: nothing to load
        return model

    prefix = "module."
    ckpt_wrapped = all(key.startswith(prefix) for key in state_dict)
    model_keys = list(model.state_dict().keys())
    # A parameterless model is treated as not wrapped.
    model_wrapped = bool(model_keys) and all(key.startswith(prefix) for key in model_keys)

    if ckpt_wrapped and not model_wrapped:
        # Strip only the leading prefix added by DataParallel.
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    elif model_wrapped and not ckpt_wrapped:
        state_dict = {prefix + key: value for key, value in state_dict.items()}

    model.load_state_dict(state_dict)
    return model

# --- Read the checkpoint ---
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
print(f"Loading checkpoint: {checkpoint_path} -> device={device}")
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load(checkpoint_path, map_location=map_location)

saved_epoch = ckpt.get('epoch')
saved_test_acc = ckpt.get('test_accuracy')
label_map = ckpt.get('label_map')
# Tolerate checkpoints that store config=None as well as a missing key.
saved_config = ckpt.get('config') or {}
model_name = saved_config.get('model_name')
print({"epoch": saved_epoch, "test_acc": saved_test_acc, "config_keys": list(saved_config.keys())})

# --- Reconstruct tokenizer + transformer ---
# Both depend on model_name, so they are only built when the checkpoint
# recorded one; otherwise the message below is printed instead of crashing
# inside AutoTokenizer.from_pretrained(None).
transformer = None
if model_name:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    try:
        print(f"Loading base transformer: {model_name}")
        transformer = AutoModel.from_pretrained(model_name)
    except Exception as e:
        print("HF load failed; manually instantiate transformer.", e)
else:
    print("No model_name in checkpoint config; provide transformer manually.")

# --- Reconstruct Generator / Discriminator (expect classes already defined in notebook) ---
# These override the values from the hyperparameter cell with what was saved.
noise_size = saved_config.get('noise_size')
hidden_size = saved_config.get('hidden_size')
hidden_levels_g = saved_config.get('hidden_levels_g')
hidden_levels_d = saved_config.get('hidden_levels_d')
num_hidden_layers_g = saved_config.get('num_hidden_layers_g')
num_hidden_layers_d = saved_config.get('num_hidden_layers_d')
print("Loaded architecture config:", {
    'noise_size': noise_size,
    'hidden_size': hidden_size,
    'hidden_levels_g': hidden_levels_g,
    'hidden_levels_d': hidden_levels_d,
    'num_hidden_layers_g': num_hidden_layers_g,
    'num_hidden_layers_d': num_hidden_layers_d,
})

try:
    Generator; Discriminator
except NameError:
    raise RuntimeError("Define Generator and Discriminator classes (from training notebook) before running this cell.")

# Build kwargs adaptively (covers common naming patterns); anything the
# checkpoint did not record falls back to the class defaults.
gen_kwargs = {}
disc_kwargs = {}
if noise_size is not None: gen_kwargs['noise_size'] = noise_size
if hidden_size is not None:
    gen_kwargs['output_size'] = hidden_size  # training used output_size for generator final size
    disc_kwargs['input_size'] = hidden_size
if hidden_levels_g is not None: gen_kwargs['hidden_sizes'] = hidden_levels_g
if hidden_levels_d is not None: disc_kwargs['hidden_sizes'] = hidden_levels_d
if 'dropout_rate' in saved_config:  # pass through if recorded
    gen_kwargs['dropout_rate'] = saved_config['dropout_rate']
    disc_kwargs['dropout_rate'] = saved_config['dropout_rate']

print("Generator kwargs:", gen_kwargs)
print("Discriminator kwargs:", disc_kwargs)

generator = Generator(**gen_kwargs)
discriminator = Discriminator(num_labels=len(label_map) if label_map else 2, **disc_kwargs)

# --- Load state dicts ---
transformer_state = ckpt.get('transformer_state_dict') or ckpt.get('transformer')
generator_state = ckpt.get('generator_state_dict')
discriminator_state = ckpt.get('discriminator_state_dict') or ckpt.get('discriminator')

if transformer is not None and transformer_state is not None:
    load_state_into_model(transformer, transformer_state)
    print('Transformer weights loaded.')
else:
    print('Transformer weights missing or transformer not instantiated.')

if generator_state is not None:
    load_state_into_model(generator, generator_state)
    print('Generator weights loaded.')
else:
    print('No generator weights found.')

if discriminator_state is not None:
    load_state_into_model(discriminator, discriminator_state)
    print('Discriminator weights loaded.')
else:
    print('No discriminator weights found.')

# Move to device and switch to inference mode
generator.to(device).eval()
discriminator.to(device).eval()
if transformer is not None: transformer.to(device).eval()

print('Checkpoint restoration complete.')
if label_map: print('Label map:', label_map)
else: print('No label_map stored; using default order (0..N-1).')

# --- Evaluation helper using existing evaluate_on_dataloader / evaluate ---
def _call_evaluate(split_name, dataloader, **kwargs):
    """Dispatch to whichever evaluation function the notebook defines.

    `evaluate_on_dataloader` is preferred; `evaluate` is the fallback. Extra
    keyword arguments are forwarded unchanged.

    Raises:
        RuntimeError: if neither function exists at module level.
    """
    namespace = globals()
    for candidate in ('evaluate_on_dataloader', 'evaluate'):
        if candidate in namespace:
            return namespace[candidate](split_name, dataloader, **kwargs)
    raise RuntimeError('Define evaluate_on_dataloader or evaluate before calling this.')

# Usage, once a DataLoader named new_test_dataloader exists:
# results = _call_evaluate('New Dataset', new_test_dataloader)
# print(results)

print('Ready to evaluate: supply a DataLoader and call _call_evaluate().')

Loading checkpoint: best_model.pt -> device=cuda


{'epoch': 5, 'test_acc': 0.9595833333333333, 'config_keys': ['model_name', 'noise_size', 'hidden_size', 'hidden_levels_g', 'hidden_levels_d', 'num_hidden_layers_g', 'num_hidden_layers_d']}
Loading base transformer: bert-base-cased


2026-01-31 15:18:02.780155: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-31 15:18:02.799719: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769840282.819923 1767299 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769840282.826529 1767299 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769840282.843774 1767299 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Loaded architecture config: {'noise_size': 150, 'hidden_size': 1536, 'hidden_levels_g': [1536, 1536], 'hidden_levels_d': [1536, 1536], 'num_hidden_layers_g': 2, 'num_hidden_layers_d': 2}
Generator kwargs: {'noise_size': 150, 'output_size': 1536, 'hidden_sizes': [1536, 1536]}
Discriminator kwargs: {'input_size': 1536, 'hidden_sizes': [1536, 1536]}
Transformer weights loaded.
Generator weights loaded.
Discriminator weights loaded.
Checkpoint restoration complete.
Label map: {'human': 0, 'ai': 1}
Ready to evaluate: supply a DataLoader and call _call_evaluate().


In [13]:
# GAN-BERT results on every evaluation set, keyed by dataset name and
# evaluated in the listed order.
results_pre = {
    name: evaluate_on_df(name, frame)
    for name, frame in (
        ('hc3', hc3_eval),
        ('mage_eval', mage_eval),
        ('m4_eval', m4_eval),
        ('cheat_eval', cheat_eval),
        ('raid_eval', raid_eval),
    )
}

Evaluating on hc3 dataset with 10000 samples.



[hc3] Accuracy: 0.832
[hc3] Classification Report:
{'accuracy': 0.8324,
 'ai': {'f1-score': 0.8217779668226287,
        'precision': 0.8773841961852861,
        'recall': 0.7728,
        'support': 5000.0},
 'human': {'f1-score': 0.8418271045677614,
           'precision': 0.7969978556111508,
           'recall': 0.892,
           'support': 5000.0},
 'macro avg': {'f1-score': 0.831802535695195,
               'precision': 0.8371910258982185,
               'recall': 0.8324,
               'support': 10000.0},
 'weighted avg': {'f1-score': 0.8318025356951951,
                  'precision': 0.8371910258982185,
                  'recall': 0.8324,
                  'support': 10000.0}}
[hc3] Confusion Matrix:
[[4460  540]
 [1136 3864]]
[hc3] False Positive Rate: 0.108
Evaluating on mage_eval dataset with 10000 samples.

[mage_eval] Accuracy: 0.563
[mage_eval] Classification Report:
{'accuracy': 0.5628,
 'ai': {'f1-score': 0.3895559899469422,
        'precision': 0.6452358926919519,
     

In [14]:
# One accuracy line per evaluated dataset.
for key, metrics in results_pre.items():
    print(f"acc for {key} is {metrics['accuracy']}")

acc for hc3 is 0.8324
acc for mage_eval is 0.5628
acc for m4_eval is 0.6037
acc for cheat_eval is 0.8584
acc for raid_eval is 0.9595833333333333


In [15]:
# Safe loader + evaluator for your "best_simplebert.pt" checkpoints
import os
import torch
import torch.nn as nn
from transformers import AutoModel
import inspect

class CheckpointAdapter:
    """
    Loads checkpoints saved with the baseline training loop and provides
    an isolated TransformerClassifier instance ready for evaluation on a single device.

    Args:
        device: anything accepted by ``torch.device``. When None, "cuda:1" is
            used if at least two GPUs are visible, otherwise "cuda:0" if CUDA
            is available, otherwise the CPU.
    """

    # Prefix nn.DataParallel adds to every parameter name.
    _DP_PREFIX = "module."

    def __init__(self, device=None):
        if device is None:
            # Prefer the second GPU as before, but never pick an index that
            # does not exist on this machine.
            if torch.cuda.device_count() > 1:
                self.device = torch.device("cuda:1")
            elif torch.cuda.is_available():
                self.device = torch.device("cuda:0")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)
        print(f"[Adapter] device -> {self.device}")

    @staticmethod
    def _fix_state_dict(state_dict):
        """Return `state_dict` without the leading DataParallel ``module.`` prefix.

        Only a leading prefix shared by every key is removed, so names that
        merely contain ``module.`` (e.g. ``submodule.weight``) are preserved.
        None and empty mappings are returned unchanged.
        """
        if not state_dict:
            return state_dict
        prefix = CheckpointAdapter._DP_PREFIX
        if all(key.startswith(prefix) for key in state_dict):
            return {key[len(prefix):]: value for key, value in state_dict.items()}
        return state_dict

    class TransformerClassifier(nn.Module):
        """Transformer encoder + masked mean pooling + dropout + linear head."""

        def __init__(self, transformer_model, hidden_size, num_labels, dropout_rate=0.1):
            super().__init__()
            self.transformer = transformer_model
            self.dropout = nn.Dropout(dropout_rate)
            self.classifier = nn.Linear(hidden_size, num_labels)

        def forward(self, input_ids, attention_mask=None):
            """Return classification logits of shape (batch, num_labels)."""
            outputs = self.transformer(input_ids, attention_mask=attention_mask)
            hidden_states = outputs.last_hidden_state
            if attention_mask is not None:
                # Mean over real tokens only; the clamp guards fully masked rows.
                mask_expanded = attention_mask.unsqueeze(-1).expand_as(hidden_states).float()
                sum_embeddings = torch.sum(hidden_states * mask_expanded, dim=1)
                sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
                pooled = sum_embeddings / sum_mask
            else:
                pooled = hidden_states.mean(dim=1)
            pooled = self.dropout(pooled)
            return self.classifier(pooled)

    def load_checkpoint(self, checkpoint_path, transformer_factory=None, strict_load=True):
        """Rebuild a TransformerClassifier from a saved checkpoint.

        Args:
            checkpoint_path: path of the checkpoint file.
            transformer_factory: optional callable building the encoder. It is
                called without arguments if it takes none, otherwise with the
                checkpoint's model name. When omitted, the encoder comes from
                ``AutoModel.from_pretrained(model_name)``.
            strict_load: first try a strict load of the full model state; on
                failure the load is retried with ``strict=False``.

        Returns:
            (model, meta): the classifier in eval mode on `self.device`, and a
            dict with 'epoch', 'accuracy', 'seed', 'model_name', 'label_map'
            and the raw checkpoint under 'raw_ckpt'.

        Raises:
            FileNotFoundError: if `checkpoint_path` does not exist.
            RuntimeError: if the encoder cannot be built, has no hidden_size,
                or the number of labels cannot be inferred.
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        meta = {
            'epoch': ckpt.get('epoch'),
            'accuracy': ckpt.get('accuracy', ckpt.get('test_accuracy')),
            'seed': ckpt.get('seed'),
            'model_name': ckpt.get('model_name') or (ckpt.get('config', {}) or {}).get('model_name'),
            'label_map': ckpt.get('label_map'),
            'raw_ckpt': ckpt
        }
        print(f"[Adapter] Checkpoint loaded. epoch={meta['epoch']}, acc={meta['accuracy']}, model_name={meta['model_name']}")

        # Figure out what was saved
        full_classifier_state = ckpt.get('state_dict')
        transformer_only_state = ckpt.get('transformer_state_dict')

        # Instantiate transformer
        if transformer_factory is not None:
            try:
                sig = inspect.signature(transformer_factory)
                transformer = transformer_factory() if len(sig.parameters) == 0 else transformer_factory(meta['model_name'])
            except Exception as e:
                raise RuntimeError(f"transformer_factory failed: {e}") from e
        else:
            if meta['model_name'] is None:
                raise RuntimeError("No model_name in checkpoint and no transformer_factory provided.")
            print(f"[Adapter] instantiating AutoModel.from_pretrained('{meta['model_name']}')")
            transformer = AutoModel.from_pretrained(meta['model_name'])

        # Load transformer weights. The DataParallel prefix is stripped first:
        # with strict=False a prefixed state dict would otherwise match no
        # parameter and the load would silently do nothing.
        if transformer_only_state is not None:
            missing, unexpected = transformer.load_state_dict(
                self._fix_state_dict(transformer_only_state), strict=False
            )
            if missing or unexpected:
                print(f"[Adapter] (info) transformer load missing={len(missing)} unexpected={len(unexpected)}")

        hidden_size = getattr(transformer.config, "hidden_size", None)
        if hidden_size is None:
            raise RuntimeError("Transformer config has no hidden_size.")

        # Infer num_labels: prefer the stored label map, else the classifier weights.
        num_labels = None
        if meta['label_map']:
            try:
                num_labels = len(meta['label_map'])
            except Exception:
                pass
        if num_labels is None and full_classifier_state:
            for k in ['classifier.weight', 'module.classifier.weight']:
                if k in full_classifier_state:
                    num_labels = full_classifier_state[k].shape[0]
                    break
        if num_labels is None:
            raise RuntimeError("Could not infer num_labels (need label_map or classifier weights).")

        model = self.TransformerClassifier(transformer, hidden_size, num_labels).to(self.device)

        # If a full model (transformer + classifier) state dict was saved, load it
        if full_classifier_state:
            fixed = self._fix_state_dict(full_classifier_state)
            try:
                model.load_state_dict(fixed, strict=strict_load)
                print("[Adapter] full model state loaded.")
            except RuntimeError as e:
                print(f"[Adapter] strict load failed: {e}")
                if strict_load:
                    print("[Adapter] retrying with strict=False")
                    missing, unexpected = model.load_state_dict(fixed, strict=False)
                    print(f"[Adapter] (info) non-strict load missing={len(missing)} unexpected={len(unexpected)}")
        else:
            # Nothing restored the classification head in this case.
            print("[Adapter] WARNING: checkpoint has no 'state_dict'; the classifier head keeps its random initialisation.")

        model.eval()
        return model, meta

    def evaluate_model(self, model, dataloader, eval_name="eval", label_list=None):
        """Run `model` over `dataloader` and compute classification metrics.

        Args:
            model: module called as ``model(input_ids, attention_mask=...)``
                and returning logits.
            dataloader: iterable of batches (input_ids, attention_mask, labels, ...).
            eval_name: label used in the progress message and the result.
            label_list: class names indexed by label id; defaults to the
                notebook-level `label_list` when one exists.

        Returns:
            Dict with 'name', 'accuracy', 'classification_report',
            'confusion_matrix' and 'false_positive_rate'. The metrics are
            None / empty when the batches carry no labels.
        """
        import time, numpy as np
        from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
        if label_list is None and 'label_list' in globals():
            label_list = globals()['label_list']
        model.to(self.device).eval()
        all_preds, all_labels = [], []
        t0 = time.time()
        with torch.no_grad():
            for batch in dataloader:
                # Expect tuple: (input_ids, attention_mask, labels, *optional)
                input_ids = batch[0].to(self.device)
                attention_mask = batch[1].to(self.device) if len(batch) > 1 else None
                labels = batch[2].to(self.device) if len(batch) > 2 else None
                logits = model(input_ids, attention_mask=attention_mask)
                preds = torch.argmax(logits, dim=-1)
                if labels is not None:
                    all_labels.extend(labels.cpu().tolist())
                all_preds.extend(preds.cpu().tolist())
        if all_labels:
            acc = accuracy_score(all_labels, all_preds)
            if label_list:
                # Passing the full label set keeps the report and the matrix
                # well-defined when only one class occurs in the data; without
                # it a target_names list of a different length is rejected.
                class_ids = list(range(len(label_list)))
                report = classification_report(all_labels, all_preds, labels=class_ids, target_names=label_list, zero_division=0, output_dict=True)
                cm = confusion_matrix(all_labels, all_preds, labels=class_ids).tolist()
            else:
                report = classification_report(all_labels, all_preds, zero_division=0, output_dict=True)
                cm = confusion_matrix(all_labels, all_preds).tolist()
            # Assume positive class = 1 if binary
            if len(set(all_labels)) == 2 and 1 in set(all_labels):
                denom = max(1, sum(1 for x in all_labels if x == 0))
                fpr = sum(1 for p, y in zip(all_preds, all_labels) if p == 1 and y == 0) / denom
            else:
                fpr = None
        else:
            acc, report, cm, fpr = None, {}, [], None
        print(f"[{eval_name}] done in {time.time()-t0:.1f}s acc={acc}")
        return {
            "name": eval_name,
            "accuracy": acc,
            "classification_report": report,
            "confusion_matrix": cm,
            "false_positive_rate": fpr
        }

    def evaluate_on_df(self, model, eval_df, eval_name="df_eval", batch_size=32):
        """Build a dataloader from `eval_df` with the notebook-level
        `generate_data_loader` and evaluate `model` on it.

        Raises:
            RuntimeError: if `generate_data_loader` is not defined.
        """
        if 'generate_data_loader' not in globals():
            raise RuntimeError("generate_data_loader not defined in this notebook.")
        dataloader = generate_data_loader(eval_df, batch_size=batch_size, do_shuffle=False, balance_label_examples=False)
        return self.evaluate_model(model, dataloader, eval_name=eval_name)

# --------------------------
# Load the baseline checkpoint
# --------------------------
# Point checkpoint_path at the saved best_simplebert.pt (set in the configuration cell).
checkpoint_path = BEST_SIMPLEBERT_CHECKPOINT  # <-- update if necessary
adapter = CheckpointAdapter()  # no explicit device: the adapter picks a GPU if available, else the CPU

# The encoder is instantiated from the model_name stored in the checkpoint.
model, metadata = adapter.load_checkpoint(checkpoint_path=checkpoint_path)

# To control how the encoder is built (e.g. local cache or different init), pass a factory instead:
# model, metadata = adapter.load_checkpoint(checkpoint_path, transformer_factory=lambda name: AutoModel.from_pretrained(name, local_files_only=False))

# Evaluating:
# adapter.evaluate_model expects batches shaped as
# (input_ids, attention_mask, labels, label_mask)
# and computes the metrics itself.
#
# Example (uncomment once a dataloader named new_test_dataloader exists):
# results = adapter.evaluate_model(model, new_test_dataloader)
# print("Evaluation results:", results)


[Adapter] device -> cuda:1


[Adapter] Checkpoint loaded. epoch=0, acc=0.0, model_name=bert-base-cased
[Adapter] instantiating AutoModel.from_pretrained('bert-base-cased')
[Adapter] full model state loaded.


In [16]:
# Re-evaluate datasets using the baseline adapter model instead of GANBERT discriminator
# Baseline (simple BERT) results on every evaluation set, keyed by dataset
# name and evaluated in the listed order.
results_bert = {
    name: adapter.evaluate_on_df(model, frame, eval_name=name)
    for name, frame in (
        ('hc3', hc3_eval),
        ('mage_eval', mage_eval),
        ('m4_eval', m4_eval),
        ('cheat_eval', cheat_eval),
        ('raid_eval', raid_eval),
    )
}

[hc3] done in 8.4s acc=0.8579
[mage_eval] done in 8.3s acc=0.5725
[m4_eval] done in 8.3s acc=0.6444
[cheat_eval] done in 8.3s acc=0.8276
[raid_eval] done in 10.0s acc=0.9414166666666667


In [17]:
import os
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModel

# --- Generator / Discriminator (same as training definition) ---
class Generator(nn.Module):
    """
    Multi-layer perceptron that maps a noise vector to a fake sentence representation.

    Args:
        noise_size: width of the input noise vector.
        output_size: width of the produced representation.
        hidden_sizes: widths of the hidden layers; any sequence (list or tuple) is accepted.
        dropout_rate: dropout probability applied after every hidden activation.

    The layers live in `self.layers` (an nn.Sequential) in the order
    Linear -> LeakyReLU -> Dropout per hidden layer, followed by one output Linear,
    so state-dict keys are `layers.<index>.weight` / `layers.<index>.bias`.
    """

    def __init__(self, noise_size=100, output_size=512, hidden_sizes=(512,), dropout_rate=0.1):
        super().__init__()
        # Copy into a fresh list: accepts tuples as well as lists and never
        # touches the caller's (or the default) sequence.
        widths = [noise_size, *hidden_sizes]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Linear(fan_in, fan_out))
            blocks.append(nn.LeakyReLU(0.2, inplace=True))
            blocks.append(nn.Dropout(dropout_rate))
        blocks.append(nn.Linear(widths[-1], output_size))
        self.layers = nn.Sequential(*blocks)

    def forward(self, noise):
        """Return the generated representation for a batch of noise vectors."""
        return self.layers(noise)

class Discriminator(nn.Module):
    """
    Multi-layer perceptron that classifies a sentence representation into
    `num_labels` real classes plus one extra "fake" class (the last logit).

    Args:
        input_size: width of the incoming representation.
        hidden_sizes: widths of the hidden layers; any sequence (list or tuple) is accepted.
        num_labels: number of real classes; the output layer has num_labels + 1 units.
        dropout_rate: dropout probability for the input and after every hidden activation.
    """

    def __init__(self, input_size=512, hidden_sizes=(512,), num_labels=2, dropout_rate=0.1):
        super().__init__()
        self.input_dropout = nn.Dropout(p=dropout_rate)
        # Copy into a fresh list: accepts tuples as well as lists and never
        # touches the caller's (or the default) sequence.
        widths = [input_size, *hidden_sizes]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Linear(fan_in, fan_out))
            blocks.append(nn.LeakyReLU(0.2, inplace=True))
            blocks.append(nn.Dropout(dropout_rate))
        self.layers = nn.Sequential(*blocks)
        self.logit = nn.Linear(widths[-1], num_labels + 1)  # +1 for fake class
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        """
        Returns:
            (last_rep, logits, probs): the last hidden representation, the raw
            logits over num_labels + 1 classes, and their softmax.
        """
        hidden = self.layers(self.input_dropout(input_rep))
        logits = self.logit(hidden)
        return hidden, logits, self.softmax(logits)

# --- Helper to load state dict agnostic to DataParallel/`module.` prefixes ---
def load_state_into_model(model, state_dict):
    """
    Load `state_dict` into `model`, reconciling the "module." key prefix that
    nn.DataParallel adds to parameter names.

    Whether either side is wrapped is decided from its first key. Only a
    *leading* "module." is stripped, so parameter paths that merely contain
    that text (e.g. "submodule.weight") are left intact.

    Args:
        model: the nn.Module to populate.
        state_dict: mapping of parameter name -> tensor; None or empty is a no-op.

    Returns:
        The same `model` instance, for chaining.

    Raises:
        RuntimeError: propagated from model.load_state_dict when keys or shapes do not match.
    """
    if not state_dict:
        return model

    prefix = "module."
    model_keys = list(model.state_dict().keys())
    state_is_wrapped = next(iter(state_dict)).startswith(prefix)
    # A model without parameters has no first key to inspect; treat it as unwrapped.
    model_is_wrapped = bool(model_keys) and model_keys[0].startswith(prefix)

    if state_is_wrapped and not model_is_wrapped:
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    elif model_is_wrapped and not state_is_wrapped:
        state_dict = {prefix + key: value for key, value in state_dict.items()}

    model.load_state_dict(state_dict)
    return model

# --- Main wrapper class ---
class GANBERTWrapper:
    """
    Inference-only wrapper around a GAN-BERT checkpoint.

    Rebuilds the transformer, generator and discriminator described by the
    checkpoint and scores single texts with the discriminator.

    Usage:
        w = GANBERTWrapper("best_model.pt", device="cuda" or "cpu", max_seq_length=64)
        p = w.predict_proba("some text here")  # float in [0,1] probability that text is 'ai'

    Checkpoint keys that are read:
        'config'      dict; must contain 'model_name', may contain 'noise_size',
                      'hidden_size', 'hidden_levels_g', 'hidden_levels_d', 'dropout_rate'
        'label_map'   optional dict or list naming the real classes
        'transformer_state_dict' (or 'transformer')
        'generator_state_dict'
        'discriminator_state_dict' (or 'discriminator')
    """

    def __init__(self, checkpoint_path: str, device: str = None, max_seq_length: int = 64):
        """
        Args:
            checkpoint_path: path of the checkpoint written by torch.save.
            device: torch device string; defaults to "cuda" when available, else "cpu".
            max_seq_length: texts are truncated / padded to this many tokens.

        Raises:
            FileNotFoundError: the checkpoint file does not exist.
            RuntimeError: the checkpoint has no model name, or the transformer cannot be built.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.max_seq_length = max_seq_length

        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # Deserialize onto the CPU regardless of the target device. Without a
        # map_location, tensors are restored onto the device they were saved
        # from, which fails when that GPU does not exist on this machine. The
        # modules are moved to self.device explicitly further down.
        # NOTE(security): torch.load unpickles the file - only load trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location="cpu")

        self.ckpt = ckpt  # keep for inspection if needed
        saved_config = ckpt.get('config') or {}  # tolerate a missing or None config
        self.label_map = ckpt.get('label_map', None)  # may be dict or list
        model_name = saved_config.get('model_name', None)

        # tokenizer + transformer
        if model_name is None:
            raise RuntimeError("Checkpoint doesn't contain 'model_name' in config. "
                               "Either add model_name to checkpoint config or modify wrapper to provide it.")
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # instantiate transformer
        try:
            transformer = AutoModel.from_pretrained(self.model_name)
        except Exception as e:
            raise RuntimeError(f"Failed to instantiate transformer from '{self.model_name}': {e}") from e
        # load transformer weights if present
        transformer_state = ckpt.get('transformer_state_dict') or ckpt.get('transformer')
        if transformer_state is not None:
            load_state_into_model(transformer, transformer_state)
        transformer.to(self.device).eval()
        self.transformer = transformer

        # Only forward the hyper-parameters the checkpoint actually stores;
        # everything else falls back to the Generator / Discriminator defaults.
        gen_kwargs = {}
        disc_kwargs = {}
        noise_size = saved_config.get('noise_size')
        hidden_size = saved_config.get('hidden_size')
        hidden_levels_g = saved_config.get('hidden_levels_g')
        hidden_levels_d = saved_config.get('hidden_levels_d')
        if noise_size is not None:
            gen_kwargs['noise_size'] = noise_size
        if hidden_size is not None:
            gen_kwargs['output_size'] = hidden_size
            disc_kwargs['input_size'] = hidden_size
        if hidden_levels_g is not None:
            gen_kwargs['hidden_sizes'] = hidden_levels_g
        if hidden_levels_d is not None:
            disc_kwargs['hidden_sizes'] = hidden_levels_d
        if 'dropout_rate' in saved_config:
            gen_kwargs['dropout_rate'] = saved_config['dropout_rate']
            disc_kwargs['dropout_rate'] = saved_config['dropout_rate']

        # instantiate models
        self.generator = Generator(**gen_kwargs)
        self.discriminator = Discriminator(num_labels=(len(self.label_map) if self.label_map else 2), **disc_kwargs)

        # load generator/discriminator weights if present
        generator_state = ckpt.get('generator_state_dict')
        discriminator_state = ckpt.get('discriminator_state_dict') or ckpt.get('discriminator')
        if generator_state is not None:
            load_state_into_model(self.generator, generator_state)
        if discriminator_state is not None:
            load_state_into_model(self.discriminator, discriminator_state)

        # move to device and set eval mode
        self.generator.to(self.device).eval()
        self.discriminator.to(self.device).eval()

        # figure ai label index
        self.ai_label_index = self._resolve_ai_index()

    def _resolve_ai_index(self):
        """
        Resolve the index of the 'ai' label from self.label_map, defaulting to 1.

        Supported shapes: name->index dict ({'human': 0, 'ai': 1}),
        index->name dict ({0: 'human', 1: 'ai'}) and list/tuple (['human', 'ai']).
        The label name is matched case-insensitively in every shape.
        """
        label_map = self.label_map
        if label_map is None:
            return 1
        if isinstance(label_map, dict):
            # Exact name->index hit.
            if 'ai' in label_map:
                return int(label_map['ai'])
            # name->index with different casing, e.g. {'AI': 1}.
            for key, value in label_map.items():
                if isinstance(key, str) and key.lower() == 'ai':
                    try:
                        return int(value)
                    except (TypeError, ValueError):
                        continue
            # index->name, e.g. {1: 'ai'} or {'1': 'AI'}.
            for key, value in label_map.items():
                if isinstance(key, (int, str)) and str(value).lower() == 'ai':
                    try:
                        return int(key)
                    except (TypeError, ValueError):
                        continue
        if isinstance(label_map, (list, tuple)):
            for position, value in enumerate(label_map):
                if str(value).lower() == 'ai':
                    return position
        # fallback: assume the [human, ai] ordering
        return 1

    def predict_proba(self, text: str) -> float:
        """
        Score `text` with the discriminator and return the 'ai' class probability.

        The value is taken from the discriminator's softmax over
        num_labels + 1 classes with the trailing "fake" column dropped; the
        remaining columns are not renormalised.

        Returns:
            float in [0,1].
        """
        self.transformer.eval()
        self.discriminator.eval()

        encoded = self.tokenizer(
            text,
            max_length=self.max_seq_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)

        with torch.no_grad():
            outputs = self.transformer(
                input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
            last_hidden_state = outputs.last_hidden_state  # (1, seq_len, hidden)
            cls_embed = last_hidden_state[:, 0, :]
            mask_exp = attention_mask.unsqueeze(-1).float()
            # Masked mean over tokens; the clamp only guards against an all-zero mask.
            token_count = mask_exp.sum(1).clamp(min=1e-9)
            mean_embed = (last_hidden_state * mask_exp).sum(1) / token_count
            sent_rep = torch.cat([cls_embed, mean_embed], dim=1)  # same concat used at training

            _, logits, probs = self.discriminator(sent_rep)  # probs shape (1, num_labels+1)
            # drop the last column, which the discriminator reserves for the fake class
            probs_real = probs[:, :-1]  # shape (1, num_labels)
            # ensure ai index within range
            idx = int(self.ai_label_index)
            if idx < 0 or idx >= probs_real.shape[1]:
                # fallback: ordering assumption human=0 ai=1
                idx = 1 if probs_real.shape[1] > 1 else 0
            ai_prob = probs_real[0, idx].item()
        return float(ai_prob)

# --- Convenience function to evaluate many models for the same text ---
def predict_from_many_models(checkpoint_paths, text, device=None, max_seq_length=64):
    """
    Score one text with several GAN-BERT checkpoints.

    Args:
        checkpoint_paths: iterable of checkpoint file paths.
        text: string to evaluate.
        device: forwarded to GANBERTWrapper.
        max_seq_length: forwarded to GANBERTWrapper.

    Returns:
        dict {checkpoint_path: probability}; a wrapper is built per path, in order.
    """
    return {
        path: GANBERTWrapper(path, device=device, max_seq_length=max_seq_length).predict_proba(text)
        for path in checkpoint_paths
    }

# -------------------------
# Example usage:
# -------------------------
if __name__ == "__main__":
    # single model
    cp = BEST_MODEL_CHECKPOINT  # path is defined once in the configuration cell
    # Prefer the second GPU only when it exists: "cuda:1" is an invalid device
    # on a single-GPU machine even though torch.cuda.is_available() is True.
    w = GANBERTWrapper(
        cp,
        device=("cuda:1" if torch.cuda.device_count() > 1
                else "cuda" if torch.cuda.is_available() else "cpu"),
        max_seq_length=64,
    )
    text = "This is a sample text to test whether it's ai generated."
    print("AI probability:", w.predict_proba(text))

    # multiple models
    # cps = ["model_a.pt", "model_b.pt", ...]
    # probs = predict_from_many_models(cps, text)
    # print(probs)


AI probability: 2.3100389512364927e-07


In [18]:
# Safe loader + evaluator for your "best_simplebert.pt" checkpoints
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
import inspect
from typing import List, Optional, Tuple

class CheckpointAdapter:
    """
    Loader and evaluator for checkpoints written by the baseline (plain
    transformer classifier) training loop.

    - load_checkpoint rebuilds a TransformerClassifier on a single device.
    - evaluate_model / evaluate_on_df score it on a dataloader / DataFrame.
    - build_inference_wrapper returns an InferenceWrapper that also owns a
      tokenizer and offers predict_proba / predict_proba_batch.
    """

    def __init__(self, device: Optional[str] = None):
        """
        Args:
            device: torch device string. When None, "cuda:1" is used if at least
                two CUDA devices are visible, otherwise the default CUDA device
                if one is available, otherwise the CPU.
        """
        if device is None:
            # "cuda:1" is an invalid ordinal on single-GPU machines, so it is
            # only selected when a second GPU actually exists.
            if torch.cuda.device_count() > 1:
                device = "cuda:1"
            elif torch.cuda.is_available():
                device = "cuda"
            else:
                device = "cpu"
        self.device = torch.device(device)
        print(f"[Adapter] device -> {self.device}")

    @staticmethod
    def _fix_state_dict(state_dict):
        """
        Strip the leading "module." that nn.DataParallel adds to parameter names.

        The decision is taken from the first key; only a leading prefix is
        removed, so names that merely contain "module." (e.g. "submodule.weight")
        are preserved. None / empty input is returned unchanged.
        """
        if not state_dict:
            return state_dict
        prefix = "module."
        if not next(iter(state_dict)).startswith(prefix):
            return state_dict
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    class TransformerClassifier(nn.Module):
        """Transformer encoder + masked mean pooling + dropout + linear classification head."""

        def __init__(self, transformer_model, hidden_size, num_labels, dropout_rate=0.1):
            super().__init__()
            self.transformer = transformer_model
            self.dropout = nn.Dropout(dropout_rate)
            self.classifier = nn.Linear(hidden_size, num_labels)

        def forward(self, input_ids, attention_mask=None):
            """Return classification logits of shape (batch, num_labels)."""
            outputs = self.transformer(input_ids, attention_mask=attention_mask, return_dict=True)
            hidden_states = outputs.last_hidden_state
            if attention_mask is not None:
                # Mean over unmasked tokens only; the clamp avoids dividing by zero.
                mask_expanded = attention_mask.unsqueeze(-1).expand_as(hidden_states).float()
                sum_embeddings = torch.sum(hidden_states * mask_expanded, dim=1)
                sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
                pooled = sum_embeddings / sum_mask
            else:
                pooled = hidden_states.mean(dim=1)
            pooled = self.dropout(pooled)
            return self.classifier(pooled)

    def load_checkpoint(self, checkpoint_path: str, transformer_factory=None, strict_load: bool = True) -> Tuple[nn.Module, dict]:
        """
        Rebuild the classifier stored in `checkpoint_path`.

        Args:
            checkpoint_path: file written by torch.save. Keys read: 'epoch',
                'accuracy' (or 'test_accuracy'), 'seed', 'model_name'
                (or config['model_name']), 'label_map', 'state_dict',
                'transformer_state_dict'.
            transformer_factory: optional callable that builds the encoder. It is
                called without arguments if it declares none, otherwise with the
                model name. When omitted, AutoModel.from_pretrained(model_name) is used.
            strict_load: forwarded to load_state_dict for the full model. If a
                strict load fails it is retried with strict=False.

        Returns:
            (model, meta): the model in eval mode on self.device, and a metadata
            dict that also carries the raw checkpoint under 'raw_ckpt'.

        Raises:
            FileNotFoundError: the checkpoint file does not exist.
            RuntimeError: the encoder cannot be built, has no hidden_size, or the
                number of labels cannot be inferred.
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        # NOTE(security): torch.load unpickles the file - only load trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        meta = {
            'epoch': ckpt.get('epoch'),
            'accuracy': ckpt.get('accuracy', ckpt.get('test_accuracy')),
            'seed': ckpt.get('seed'),
            'model_name': ckpt.get('model_name') or (ckpt.get('config', {}) or {}).get('model_name'),
            'label_map': ckpt.get('label_map'),
            'raw_ckpt': ckpt
        }
        print(f"[Adapter] Checkpoint loaded. epoch={meta['epoch']}, acc={meta['accuracy']}, model_name={meta['model_name']}")

        # Figure out what was saved
        full_classifier_state = ckpt.get('state_dict')
        transformer_only_state = ckpt.get('transformer_state_dict')

        # Instantiate transformer
        if transformer_factory is not None:
            try:
                sig = inspect.signature(transformer_factory)
                transformer = transformer_factory() if len(sig.parameters) == 0 else transformer_factory(meta['model_name'])
            except Exception as e:
                raise RuntimeError(f"transformer_factory failed: {e}") from e
        else:
            if meta['model_name'] is None:
                raise RuntimeError("No model_name in checkpoint and no transformer_factory provided.")
            print(f"[Adapter] instantiating AutoModel.from_pretrained('{meta['model_name']}')")
            transformer = AutoModel.from_pretrained(meta['model_name'])

        # Load transformer weights if present
        if transformer_only_state is not None:
            fixed_t = self._fix_state_dict(transformer_only_state)
            missing, unexpected = transformer.load_state_dict(fixed_t, strict=False)
            if missing or unexpected:
                print(f"[Adapter] (info) transformer load missing={len(missing)} unexpected={len(unexpected)}")

        hidden_size = getattr(transformer.config, "hidden_size", None)
        if hidden_size is None:
            raise RuntimeError("Transformer config has no hidden_size.")

        # Infer num_labels: prefer the label map, fall back to the classifier head's shape.
        num_labels = None
        if meta['label_map']:
            try:
                num_labels = len(meta['label_map'])
            except Exception:
                pass
        if num_labels is None and full_classifier_state:
            for k in ['classifier.weight', 'module.classifier.weight']:
                if k in full_classifier_state:
                    num_labels = full_classifier_state[k].shape[0]
                    break
        if num_labels is None:
            raise RuntimeError("Could not infer num_labels (need label_map or classifier weights).")

        model = self.TransformerClassifier(transformer, hidden_size, num_labels).to(self.device)

        # If a full model (transformer + classifier) state dict was saved, load it
        if full_classifier_state:
            fixed = self._fix_state_dict(full_classifier_state)
            try:
                model.load_state_dict(fixed, strict=strict_load)
                print("[Adapter] full model state loaded.")
            except RuntimeError as e:
                print(f"[Adapter] strict load failed: {e}")
                if strict_load:
                    print("[Adapter] retrying with strict=False")
                    model.load_state_dict(fixed, strict=False)

        model.eval()
        return model, meta

    def evaluate_model(self, model, dataloader, eval_name="eval", label_list=None):
        """
        Run `model` over `dataloader` and compute classification metrics.

        Each batch is read positionally: batch[0] input ids, batch[1] attention
        mask (optional), batch[2] labels (optional); further elements are ignored.

        Args:
            model: callable as model(input_ids, attention_mask=...) returning logits.
            dataloader: iterable of batches as described above.
            eval_name: tag used in the progress line and the result.
            label_list: class names for the report; when None, a notebook-level
                `label_list` global is used if one exists.

        Returns:
            dict with 'name', 'accuracy', 'classification_report',
            'confusion_matrix' and 'false_positive_rate'. Without labels the
            metrics are None / empty. The false-positive rate treats 1 as the
            positive and 0 as the negative class and is only computed when the
            labels hold exactly two distinct values including 1.
        """
        import time
        from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
        if label_list is None and 'label_list' in globals():
            label_list = globals()['label_list']
        model.to(self.device).eval()
        all_preds, all_labels = [], []
        t0 = time.time()
        with torch.no_grad():
            for batch in dataloader:
                # Expect tuple: (input_ids, attention_mask, labels, *optional)
                input_ids = batch[0].to(self.device)
                attention_mask = batch[1].to(self.device) if len(batch) > 1 else None
                labels = batch[2].to(self.device) if len(batch) > 2 else None
                logits = model(input_ids, attention_mask=attention_mask)
                preds = torch.argmax(logits, dim=-1)
                if labels is not None:
                    all_labels.extend(labels.cpu().tolist())
                all_preds.extend(preds.cpu().tolist())
        if all_labels:
            acc = accuracy_score(all_labels, all_preds)
            if label_list:
                report = classification_report(all_labels, all_preds, target_names=label_list, zero_division=0, output_dict=True)
            else:
                report = classification_report(all_labels, all_preds, zero_division=0, output_dict=True)
            cm = confusion_matrix(all_labels, all_preds).tolist()
            # Assume positive class = 1 if binary
            distinct_labels = set(all_labels)
            if len(distinct_labels) == 2 and 1 in distinct_labels:
                denom = max(1, sum(1 for x in all_labels if x == 0))
                fpr = sum(1 for p, y in zip(all_preds, all_labels) if p == 1 and y == 0) / denom
            else:
                fpr = None
        else:
            acc, report, cm, fpr = None, {}, [], None
        print(f"[{eval_name}] done in {time.time()-t0:.1f}s acc={acc}")
        return {
            "name": eval_name,
            "accuracy": acc,
            "classification_report": report,
            "confusion_matrix": cm,
            "false_positive_rate": fpr
        }

    def evaluate_on_df(self, model, eval_df, eval_name="df_eval", batch_size=32):
        """
        Build an unshuffled, unbalanced dataloader from `eval_df` with the
        notebook-level `generate_data_loader` helper and evaluate `model` on it.

        Raises:
            RuntimeError: `generate_data_loader` has not been defined in the notebook.
        """
        if 'generate_data_loader' not in globals():
            raise RuntimeError("generate_data_loader not defined in this notebook.")
        dataloader = generate_data_loader(eval_df, batch_size=batch_size, do_shuffle=False, balance_label_examples=False)
        return self.evaluate_model(model, dataloader, eval_name=eval_name)

    # ---- inference wrapper ----
    class InferenceWrapper:
        """
        Lightweight wrapper returned by adapter.build_inference_wrapper(...)
        provides predict_proba & predict_proba_batch.
        """
        def __init__(self, model: nn.Module, meta: dict, tokenizer: "AutoTokenizer", device: torch.device, max_seq_length: int = 128):
            self.model = model.to(device)
            self.meta = meta
            self.tokenizer = tokenizer
            self.device = device
            self.max_seq_length = max_seq_length
            self.model.eval()
            self.ai_label_index = self._resolve_ai_index(meta.get('label_map'))

        def _resolve_ai_index(self, label_map):
            """
            Resolve the index of the 'ai' label, defaulting to 1 (assuming [human, ai]).

            Supported shapes: name->index dict, index->name dict and list/tuple.
            The label name is matched case-insensitively in every shape.
            """
            if not label_map:
                return 1
            if isinstance(label_map, dict):
                # Exact name->index hit.
                if 'ai' in label_map:
                    return int(label_map['ai'])
                # name->index with different casing, e.g. {'AI': 1}.
                for key, value in label_map.items():
                    if isinstance(key, str) and key.lower() == 'ai':
                        try:
                            return int(value)
                        except (TypeError, ValueError):
                            continue
                # index->name, e.g. {1: 'ai'} or {'1': 'AI'}.
                for key, value in label_map.items():
                    if str(value).lower() == 'ai':
                        try:
                            return int(key)
                        except (TypeError, ValueError):
                            continue
            if isinstance(label_map, (list, tuple)):
                for position, value in enumerate(label_map):
                    if str(value).lower() == 'ai':
                        return position
            # fallback
            return 1

        def _ai_probs(self, texts) -> List[float]:
            """
            Tokenize `texts` (one string or a list of strings), run the model
            once and return the softmax probability of the 'ai' class per row.
            """
            self.model.eval()
            encoded = self.tokenizer(
                texts,
                truncation=True,
                padding='max_length',
                max_length=self.max_seq_length,
                return_tensors='pt'
            )
            input_ids = encoded['input_ids'].to(self.device)
            attention_mask = encoded['attention_mask'].to(self.device) if 'attention_mask' in encoded else None

            with torch.no_grad():
                logits = self.model(input_ids, attention_mask=attention_mask)  # (B, num_labels)
                probs = F.softmax(logits, dim=-1)  # (B, num_labels)
            idx = int(self.ai_label_index)
            if idx < 0 or idx >= probs.shape[1]:
                # fallback to index 1 if available
                idx = 1 if probs.shape[1] > 1 else 0
            return probs[:, idx].detach().cpu().tolist()

        def predict_proba(self, text: str) -> float:
            """
            Single-string inference. Returns probability of 'ai' label in [0,1].
            """
            return float(self._ai_probs(text)[0])

        def predict_proba_batch(self, texts: List[str], batch_size: int = 32) -> List[float]:
            """
            Batch inference, returns list of probabilities in same order as texts.
            """
            results = []
            for start in range(0, len(texts), batch_size):
                results.extend(self._ai_probs(texts[start:start + batch_size]))
            return [float(x) for x in results]

    def build_inference_wrapper(self,
                                checkpoint_path: str,
                                transformer_factory=None,
                                strict_load: bool = True,
                                tokenizer: Optional["AutoTokenizer"] = None,
                                max_seq_length: int = 128) -> 'CheckpointAdapter.InferenceWrapper':
        """
        Convenience method:
        - loads checkpoint (model + meta)
        - instantiates tokenizer if not provided (AutoTokenizer.from_pretrained(meta['model_name']))
        - returns InferenceWrapper with predict_proba / predict_proba_batch

        Raises:
            RuntimeError: no tokenizer was given and the checkpoint has no model_name.
        """
        model, meta = self.load_checkpoint(checkpoint_path, transformer_factory=transformer_factory, strict_load=strict_load)

        if tokenizer is None:
            model_name = meta.get('model_name')
            if not model_name:
                raise RuntimeError("Checkpoint lacks model_name; please supply `tokenizer` or `transformer_factory`.")
            tokenizer = AutoTokenizer.from_pretrained(model_name)

        wrapper = CheckpointAdapter.InferenceWrapper(model=model, meta=meta, tokenizer=tokenizer, device=self.device, max_seq_length=max_seq_length)
        return wrapper

# --------------------------
# Example usage:
# --------------------------
if __name__ == "__main__":
    # Use the path from the configuration cell so it is defined in one place only.
    checkpoint_path = BEST_SIMPLEBERT_CHECKPOINT
    adapter = CheckpointAdapter(device=None)  # device=None: the adapter chooses and prints its device

    # Build inference wrapper (loads tokenizer automatically using model_name from checkpoint)
    inf = adapter.build_inference_wrapper(checkpoint_path, max_seq_length=128)

    # Single text:
    sample_text = "This sample sentence tests whether the model thinks it's AI-written."
    print("AI prob:", inf.predict_proba(sample_text))

    # Batch
    texts = [
        "This looks like an AI generated sentence.",
        "I went to the store and bought some apples.",
        "In conclusion, this paragraph was likely produced by a transformer."
    ]
    probs = inf.predict_proba_batch(texts, batch_size=2)
    for t, p in zip(texts, probs):
        print(f"{p:.4f}  -  {t}")


[Adapter] device -> cuda:1
[Adapter] Checkpoint loaded. epoch=0, acc=0.0, model_name=bert-base-cased
[Adapter] instantiating AutoModel.from_pretrained('bert-base-cased')


[Adapter] full model state loaded.
AI prob: 0.19196173548698425
0.2858  -  This looks like an AI generated sentence.
0.7272  -  I went to the store and bought some apples.
0.2458  -  In conclusion, this paragraph was likely produced by a transformer.
