In [2]:
!pip install hf_xet torch torchvision transformers tqdm requests datasets accelerate bitsandbytes tensorboard torch-tb-profiler openai anthropic google-generativeai


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [3]:
# --- Safe flags for Apple-silicon ---
# PyTorch reads the MPS watermark variables when its MPS allocator is created,
# so this cell must run before torch is imported (the next cell imports it).
import os, platform
os.environ["TOKENIZERS_PARALLELISM"]           = "false"   # avoid fork-after-tokenizer bug
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.8"     # cap MPS at 0.8x the recommended working set -> ~20 % headroom, prevents sudden kills
os.environ["PYTORCH_MPS_LOW_WATERMARK_RATIO"]  = "0.5"     # keep <= the high-watermark ratio
os.environ["FLASH_ATTENTION_FORCE_DISABLE"]    = "1"       # disable Flash-Attn v2 path; NOTE(review): confirm a library in this env reads this variable

In [4]:
import os
import json
import torch
import platform
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, logging

from torch.utils.tensorboard import SummaryWriter
tb_writer = SummaryWriter("runs/halueval_llama")  # TensorBoard event files for this experiment

# Only surface errors from transformers (hide info/warning chatter).
logging.set_verbosity_error()

# Prefer Apple's MPS backend, then CUDA, and fall back to CPU.
# Change this to the foundry GPU if needed.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: mps


In [5]:
import torch, os, platform, psutil, time
# Quick environment sanity check: torch version, free system RAM (GB), MPS support.
print(f"Torch: {torch.__version__}   Free RAM: {psutil.virtual_memory().available/1e9} GB")
print(f"MPS cap: {torch.backends.mps.is_available()}")

Torch: 2.7.0   Free RAM: 4.542480384 GB
MPS cap: True


In [6]:
# OS / architecture string followed by the torch version, then MPS availability.
print(f"{platform.platform()} {torch.__version__}")
print(f"MPS available: {torch.backends.mps.is_available()}")

macOS-13.4-arm64-arm-64bit 2.7.0
MPS available: True


In [7]:
# Define the base model using Llama from Hugging Face
class LlamaBaseNet(nn.Module):
    """Frozen Hugging Face backbone plus a trainable linear classification head.

    The backbone runs under ``torch.no_grad()`` in ``forward``, so only
    ``self.classifier`` receives gradients from this module's outputs.
    """

    def __init__(self, model_name, num_classes=2):
        """Load tokenizer and backbone for ``model_name`` and build the head.

        Args:
            model_name: Hugging Face model id or local path passed to ``from_pretrained``.
            num_classes: Number of logits produced by the classification head.
        """
        super().__init__()
        # Load Llama model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.backbone = AutoModel.from_pretrained(model_name)
        # self.backbone = self.backbone.half()           # fp16
        # NOTE(review): forward() runs the backbone under torch.no_grad(), so no
        # activations are kept for backward and checkpointing saves no memory
        # here; it only matters if the backbone is ever unfrozen.
        self.backbone.gradient_checkpointing_enable()

        # If the tokenizer has no padding token, reuse EOS so batches can be padded.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Get hidden size from config
        self.hidden_size = self.backbone.config.hidden_size

        # Classification head
        self.classifier = nn.Linear(self.hidden_size, num_classes)

    def forward(self, texts):
        """Classify a batch given as raw strings or pre-tokenized ids.

        Args:
            texts: Either raw text accepted by the tokenizer (padded and truncated
                to 512 tokens here) or an integer tensor of token ids. For a
                tensor, positions equal to ``tokenizer.pad_token_id`` are padding.

        Returns:
            ``(logits, features)``. ``features`` is the backbone's last hidden
            state at each row's right-most non-padding token. ``logits`` is
            ``self.classifier(features)``.
        """
        if isinstance(texts, torch.Tensor):
            # Pre-tokenized input: derive the mask from the pad id so the
            # backbone ignores padding (essential for left-padded batches).
            inputs = {
                'input_ids': texts,
                'attention_mask': (texts != self.tokenizer.pad_token_id).long(),
            }
        else:
            inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)

        target_device = self.classifier.weight.device
        inputs = {k: v.to(target_device) for k, v in inputs.items()}

        # The backbone is frozen: no autograd graph is built through it.
        with torch.no_grad():
            outputs = self.backbone(**inputs)

        last_hidden_states = outputs.last_hidden_state
        hidden_device = last_hidden_states.device
        row_idx = torch.arange(last_hidden_states.shape[0], device=hidden_device)
        token_idx = self._last_token_index(inputs).to(hidden_device)
        features = last_hidden_states[row_idx, token_idx]

        logits = self.classifier(features)
        return logits, features

    def _last_token_index(self, inputs):
        """Return, per row, the position of the right-most non-padding token."""
        mask = inputs.get('attention_mask')
        if mask is None:
            mask = inputs['input_ids'] != self.tokenizer.pad_token_id
        # The right-most masked-in position is correct for both left and right
        # padding; "number of non-pad tokens - 1" is only correct for right padding.
        mask = mask.to(torch.int32)
        positions = torch.arange(mask.shape[-1], device=mask.device, dtype=torch.int32)
        return (positions * mask).argmax(-1)

In [8]:
# hugging face auth

from huggingface_hub import login
from dotenv import load_dotenv
import os

load_dotenv()  # pull variables from a local .env file into os.environ
hf_token = os.environ.get("HUGGING_FACE_KEY")  # None when the variable is unset
login(token=hf_token)

In [9]:
# Load HaluEval dataset from Hugging Face
from datasets import load_dataset

# Field names of the paired HaluEval configs: (prompt, faithful, hallucinated).
# NOTE(review): taken from the pminervini/HaluEval dataset card; re-check if the
# dataset revision changes.
_HALUEVAL_PAIRED_FIELDS = (
    ("question", "right_answer", "hallucinated_answer"),              # qa
    ("dialogue_history", "right_response", "hallucinated_response"),  # dialogue
    ("document", "right_summary", "hallucinated_summary"),            # summarization
)


def _format_halueval_item(item):
    """Turn one HaluEval row into a list of {question, response, is_hallucination} records."""
    for prompt_key, right_key, hallucinated_key in _HALUEVAL_PAIRED_FIELDS:
        if right_key in item and hallucinated_key in item:
            prompt = item.get(prompt_key, '')
            # Each paired row yields one faithful (0) and one hallucinated (1) example.
            return [
                {'question': prompt, 'response': item[right_key], 'is_hallucination': 0},
                {'question': prompt, 'response': item[hallucinated_key], 'is_hallucination': 1},
            ]
    if 'chatgpt_response' in item:
        # "general" config: one response per row with a yes/no hallucination flag.
        flag = str(item.get('hallucination', '')).strip().lower()
        return [{
            'question': item.get('user_query', ''),
            'response': item['chatgpt_response'],
            'is_hallucination': 1 if flag == 'yes' else 0,
        }]
    # Fallback: the instruction/output/label schema used by earlier versions of this cell.
    return [{
        'question': item.get('instruction', ''),
        'response': item.get('output', ''),
        'is_hallucination': 1 if item.get('label') == 'hallucinated' else 0,
    }]


def _write_halueval_jsonl(path, rows):
    """Write the formatted records of every row in ``rows`` to ``path`` as JSON lines."""
    with open(path, 'w', encoding='utf-8') as f:
        for item in rows:
            for record in _format_halueval_item(item):
                f.write(json.dumps(record) + '\n')


def prepare_halueval_data_from_hf(output_dir="data/halueval",
                                  categories=("qa", "dialogue", "summarization", "general"),
                                  loader=None):
    """Download HaluEval from Hugging Face and write train/test JSONL files.

    Writes ``{category}_train.jsonl`` / ``{category}_test.jsonl`` per category and
    the merged ``train.jsonl`` / ``test.jsonl`` under ``output_dir``. Each line is
    ``{"question": str, "response": str, "is_hallucination": 0 | 1}``.

    Args:
        output_dir: Directory for the JSONL files (created if missing).
        categories: HaluEval config names to load and merge, in order.
        loader: Callable ``(repo_id, config) -> dataset dict`` with a ``'data'``
            split; defaults to ``datasets.load_dataset``.

    Returns:
        ``(train_path, test_path)`` of the merged files.
    """
    if loader is None:
        loader = load_dataset
    print("Loading HaluEval dataset from Hugging Face...")

    os.makedirs(output_dir, exist_ok=True)

    for category in categories:
        print(f"Loading {category} dataset...")
        # The dataset has a single 'data' split containing all examples.
        data = loader("pminervini/HaluEval", category)['data']

        # 80/20 split of *rows*, done before expanding faithful/hallucinated
        # pairs so both variants of a prompt land in the same split (no leakage).
        splits = data.train_test_split(test_size=0.2, seed=42)
        _write_halueval_jsonl(f"{output_dir}/{category}_train.jsonl", splits['train'])
        _write_halueval_jsonl(f"{output_dir}/{category}_test.jsonl", splits['test'])

    for split, label in (("train", "training"), ("test", "test")):
        print(f"Merging all {label} data...")
        with open(f"{output_dir}/{split}.jsonl", 'w', encoding='utf-8') as outfile:
            for category in categories:
                with open(f"{output_dir}/{category}_{split}.jsonl", 'r', encoding='utf-8') as infile:
                    outfile.write(infile.read())

    print("HaluEval dataset preparation complete!")
    print(f"Train data: {output_dir}/train.jsonl")
    print(f"Test data: {output_dir}/test.jsonl")

    return f"{output_dir}/train.jsonl", f"{output_dir}/test.jsonl"

# Materialise the HaluEval JSONL splits on disk and keep their locations.
(train_data_path,
 test_data_path) = prepare_halueval_data_from_hf()

Loading HaluEval dataset from Hugging Face...
Loading qa dataset...
Loading dialogue dataset...
Loading summarization dataset...
Loading general dataset...
Merging all training data...
Merging all test data...
HaluEval dataset preparation complete!
Train data: data/halueval/train.jsonl
Test data: data/halueval/test.jsonl


In [10]:
# 3. Define the Epinet
class EpiNet(nn.Module):
    """Learnable MLP on ``[stop_grad(features), z]`` that outputs a logit correction.

    Args:
        feature_dim: Width of the base-network features.
        z_dim: Width of the epistemic index ``z``.
        hidden_dims: Any sequence of hidden-layer widths (list, tuple, ...; may be empty).
        num_classes: Number of output logits.
    """

    def __init__(self, feature_dim, z_dim, hidden_dims, num_classes):
        super().__init__()
        # Unpacking accepts any iterable; `[...] + hidden_dims` required a list.
        dims = [feature_dim + z_dim, *hidden_dims, num_classes]
        layers = []
        for in_d, out_d in zip(dims, dims[1:]):
            layers += [nn.Linear(in_d, out_d), nn.ReLU()]
        self.mlp = nn.Sequential(*layers[:-1])  # drop final ReLU so the output is raw logits

    def forward(self, features, z):
        """Return the correction for ``features`` (gradient-stopped) concatenated with ``z``."""
        # stop-gradient on features
        features = features.detach()
        x = torch.cat([features, z], dim=1)
        return self.mlp(x)

# 4. Define the PriorNet
class PriorNet(nn.Module):
    """Randomly initialised linear map on ``[stop_grad(features), z]`` whose weights never train.

    Every parameter is frozen (``requires_grad=False``) at construction time.
    """

    def __init__(self, feature_dim, z_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(feature_dim + z_dim, num_classes)
        # `fc` is the only sub-layer, so freezing it freezes the whole module.
        self.fc.requires_grad_(False)

    def forward(self, features, z):
        """Apply the fixed linear map to the gradient-stopped features joined with ``z``."""
        joint = torch.cat((features.detach(), z), dim=1)
        return self.fc(joint)

In [11]:
# 5. Wrap into an Epistemic Neural Network
class EpistemicNN(nn.Module):
    """Sum of base logits, a learnable epinet correction and an optional fixed prior.

    ``forward(x, z)`` returns ``base_logits + epinet(features, z)`` plus
    ``prior(features, z)`` when a prior module was supplied.
    """

    # String annotations: the referenced classes are defined in other cells.
    def __init__(self, base: "LlamaBaseNet", epinet: "EpiNet", prior: "PriorNet | None" = None):
        super().__init__()
        self.base = base
        self.epinet = epinet
        self.prior = prior

    def forward(self, x, z):
        """Return combined logits for input ``x`` under epistemic index ``z``."""
        logits, features = self.base(x)          # base logits & features
        correction = self.epinet(features, z)    # learnable correction
        # `is not None` rather than truthiness: container modules with no
        # children (len() == 0) are falsy and would otherwise be skipped.
        prior_logits = self.prior(features, z) if self.prior is not None else 0
        return logits + correction + prior_logits

# 6. Sampling epistemic index z
def sample_z(batch_size, z_dim, device):
    """Draw one standard-normal epistemic index per example.

    Returns a ``(batch_size, z_dim)`` tensor allocated on ``device``.
    """
    shape = (batch_size, z_dim)
    return torch.randn(shape, device=device)

In [12]:
def train_enn(model, dataloader, epochs, lr, λ, z_dim, device, writer):
    """Train an epistemic network, sampling a fresh index ``z`` per example each step.

    Args:
        model: Module called as ``model(x_batch, z)`` that returns class logits.
        dataloader: Iterable of ``(x_batch, y_batch)`` pairs that supports ``len()``.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
        λ: Adam ``weight_decay`` (coupled L2 penalty).
        z_dim: Width of the epistemic index passed to ``sample_z``.
        device: Device that batches and ``z`` are moved to.
        writer: TensorBoard-style logger exposing ``add_scalar`` and ``flush``.

    Returns:
        The same ``model`` instance, trained in place.

    Raises:
        ValueError: If ``epochs > 0`` and ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if epochs > 0 and num_batches == 0:
        raise ValueError("train_enn: dataloader is empty")
    device_type = torch.device(device).type

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=λ)
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (x_batch, y_batch) in enumerate(dataloader):
            # Move tensors to device
            x_batch = x_batch.to(device, non_blocking=False)
            y_batch = y_batch.to(device)

            # One epistemic index per example.
            z = sample_z(len(y_batch), z_dim, device)

            logits = model(x_batch, z)
            loss = F.cross_entropy(logits, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running statistics for this epoch.
            batch_loss = loss.item()
            total_loss += batch_loss
            predicted = logits.argmax(dim=1)
            total += y_batch.size(0)
            correct += predicted.eq(y_batch).sum().item()
            acc = 100. * correct / total

            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {batch_loss:.4f}, "
                      f"Acc: {acc:.2f}%")

            global_step = epoch * num_batches + batch_idx
            writer.add_scalar("train/loss", batch_loss, global_step)
            writer.add_scalar("train/acc",  acc,        global_step)
            if global_step % 50 == 0:
                writer.flush()

        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {total_loss/(batch_idx+1):.4f}, "
              f"Accuracy: {100.*correct/total:.2f}%")

        # Release cached allocator memory only for the backend actually in use;
        # torch.mps.empty_cache() is specific to the MPS backend.
        if device_type == "mps":
            torch.mps.empty_cache()
        elif device_type == "cuda":
            torch.cuda.empty_cache()

    return model

# Evaluate accuracy and epistemic uncertainty of a trained ENN.
def evaluate_enn(model, dataloader, z_dim, device, num_samples=10):
    """Evaluate accuracy and mean epistemic uncertainty of an epistemic network.

    For each batch, ``num_samples`` indices ``z`` are drawn. The prediction is the
    argmax of the mean logits, and the per-example uncertainty is the variance of
    the logits across samples, summed over classes.

    Args:
        model: Module called as ``model(x_batch, z)`` that returns class logits.
        dataloader: Iterable of ``(x_batch, y_batch)`` pairs.
        z_dim: Width of the epistemic index passed to ``sample_z``.
        device: Device that batches and ``z`` are moved to.
        num_samples: Number of ``z`` draws per batch (>= 1).

    Returns:
        ``(accuracy_percent, average_uncertainty)``.

    Raises:
        ValueError: If ``num_samples < 1`` or ``dataloader`` yields no examples.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    model.eval()
    total = 0
    correct = 0
    epistemic_uncertainty = []

    with torch.no_grad():
        for x_batch, y_batch in dataloader:
            x_batch = x_batch.to(device, non_blocking=False)
            y_batch = y_batch.to(device)
            batch_size = len(x_batch)

            # NOTE(review): each draw re-runs the whole model on x_batch; if only
            # the epinet/prior depend on z, the base forward could be shared.
            stacked_logits = torch.stack([
                model(x_batch, sample_z(batch_size, z_dim, device))
                for _ in range(num_samples)
            ])  # [num_samples, batch_size, num_classes]

            # Mean prediction
            predicted = stacked_logits.mean(dim=0).argmax(dim=1)
            total += y_batch.size(0)
            correct += predicted.eq(y_batch).sum().item()

            # Spread across z draws. With a single draw there is no spread, so use
            # the population variance (0) instead of the unbiased estimator's NaN.
            uncertainty = stacked_logits.var(dim=0, unbiased=num_samples > 1).sum(dim=1)  # [batch_size]
            epistemic_uncertainty.append(uncertainty)

    if total == 0:
        raise ValueError("evaluate_enn: dataloader yielded no examples")

    accuracy = 100. * correct / total
    avg_uncertainty = torch.cat(epistemic_uncertainty).mean().item()

    print(f"Test Accuracy: {accuracy:.2f}%")
    print(f"Average Epistemic Uncertainty: {avg_uncertainty:.4f}")

    return accuracy, avg_uncertainty