In [None]:
# Upgrade pip
!pip install --upgrade pip

# Install/upgrade required libraries
!pip install torch --quiet
!pip install torchvision --quiet
!pip install torchaudio --quiet
!pip install transformers==4.33.1 --quiet
!pip install datasets --quiet
!pip install scikit-learn --quiet
!pip install codecarbon --quiet
!pip install numpy==1.26.4 --quiet
!pip install pandas --quiet
!pip install tqdm --quiet
!pip install  spikingjelly --quiet

In [2]:
!pip install  spikingjelly --quiet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, BertTokenizer, BertForSequenceClassification, get_polynomial_decay_schedule_with_warmup
from datasets import load_dataset
from codecarbon import OfflineEmissionsTracker
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
import pandas as pd
import time
import json
import os
import random
import logging
import warnings
from tqdm import tqdm
from torch.optim import AdamW

# Configuration
warnings.filterwarnings("ignore")
logging.getLogger("codecarbon").setLevel(logging.INFO)
DEVICE_CONFIG = {
    'optimize_for_gpu': True,
    'mixed_precision': False  # Disabled for SNN layer stability
}
DEVICE = "cuda" if torch.cuda.is_available() and DEVICE_CONFIG['optimize_for_gpu'] else "cpu"
print(f"Using device: {DEVICE}")
DATASETS = {
    'glue_sst2': {'name': 'glue', 'config': 'sst2', 'split_train': 'train', 'split_val': 'validation'},
    'glue_mrpc': {'name': 'glue', 'config': 'mrpc', 'split_train': 'train', 'split_val': 'validation'},
    'glue_rte': {'name': 'glue', 'config': 'rte', 'split_train': 'train', 'split_val': 'validation'}
}
MAX_SAMPLES = 10000
WATER_USAGE_FACTORS = {"average_l_per_kwh": 1.8}
CARBON_INTENSITY = 250  # gCO2e/kWh
BATCH_SIZE = 16
GRADIENT_ACCUMULATION_STEPS = 2  # New: Gradient accumulation
SEQ_LENGTH = 128
NUM_LAYERS = 1
NUM_EPOCHS = {'sst2': 7, 'mrpc': 5, 'rte': 3}  # Increased epochs
LEARNING_RATE = 3e-5  # Increased learning rate
WEIGHT_DECAY = 0.01  # New: Weight decay for regularization
DROPOUT_RATE = 0.1  # New: Dropout rate for NSH layers
PATIENCE = 3  # New: Early stopping patience

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_device():
    device = torch.device(DEVICE)
    if DEVICE == "cuda":
        try:
            print(f"✓ Using GPU: {torch.cuda.get_device_name(0)} (CUDA)")
        except Exception:
            print("✓ Using GPU (name unknown)")
    else:
        print(f"✓ Using CPU")
    return device

# Advanced Neuromorphic & Sparse Layers
def gumbel_top_k_select(scores, k, temperature=1.0):
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-10) + 1e-10)
    perturbed_scores = (scores + gumbel_noise) / temperature
    _, top_k_indices = torch.topk(perturbed_scores, k=k, dim=-1)
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask.scatter_(-1, top_k_indices, 1)
    return mask

class SurrogateSpike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mem, thresh=1.0):
        ctx.save_for_backward(mem - thresh)
        return (mem > thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        (mem_minus_thresh,) = ctx.saved_tensors
        grad = grad_output * (1 / (1 + torch.exp(-mem_minus_thresh * 5)))**2 * 5
        return grad, None

class LIFLayer(nn.Module):
    def __init__(self, input_dim, output_dim, threshold=1.0, decay=0.9):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.threshold = threshold
        self.decay = decay
        self.mem = None

    def forward(self, x):
        if self.mem is None or self.mem.shape != x.shape:
            self.mem = torch.zeros_like(x, device=x.device)
        current = self.linear(x)
        self.mem = self.decay * self.mem + current
        spikes = SurrogateSpike.apply(self.mem, self.threshold)
        self.mem = self.mem.detach()
        self.mem[spikes > 0] = 0
        return spikes

    def reset(self):
        self.mem = None

class BigBirdStyleAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads, window_size=32, random_k=8, global_tokens=1):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads, self.head_dim = num_heads, hidden_dim // num_heads
        self.window_size, self.random_k, self.global_tokens = window_size, random_k, global_tokens
        self.qkv_fused_linear = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(DROPOUT_RATE)  # New: Dropout for regularization

    def forward(self, x, attn_mask=None, output_attentions=False):
        batch_size, seq_len, _ = x.shape
        qkv = self.qkv_fused_linear(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        attention_scores = torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(self.head_dim)
        window_mask = torch.ones(seq_len, seq_len, device=x.device).triu(self.window_size+1) + torch.ones(seq_len, seq_len, device=x.device).tril(-self.window_size-1)
        global_mask = torch.ones(seq_len, seq_len, device=x.device)
        global_mask[:self.global_tokens, :] = 0
        global_mask[:, :self.global_tokens] = 0
        random_mask_bool = gumbel_top_k_select(torch.randn_like(attention_scores), k=self.random_k)
        sparse_mask = (window_mask + global_mask) > 0
        sparse_mask = sparse_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, self.num_heads, -1, -1)
        sparse_mask = sparse_mask | ~random_mask_bool
        attention_scores = attention_scores.masked_fill(sparse_mask, -1e9)
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, seq_len, -1)
            elif attn_mask.dim() == 4:
                attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
            attention_scores = attention_scores + attn_mask
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)  # New: Apply dropout
        attention_output = torch.matmul(attention_weights, v)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_dim)
        output = self.output_layer(attention_output)
        return (output,) if not output_attentions else (output, attention_weights)

class NSH_DistilBertLayer(nn.Module):
    def __init__(self, original_layer):
        super().__init__()
        attn_module = original_layer.attention
        hidden_dim = attn_module.q_lin.in_features
        self.sparse_attention = BigBirdStyleAttention(hidden_dim, attn_module.n_heads)
        self.ffn = original_layer.ffn
        self.sa_layer_norm = original_layer.sa_layer_norm
        self.output_layer_norm = original_layer.output_layer_norm
        self.snn_layer = LIFLayer(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(DROPOUT_RATE)  # New: Dropout for regularization

    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
        normed_x = self.sa_layer_norm(x)
        attn_output = self.sparse_attention(normed_x, attn_mask=attn_mask, output_attentions=output_attentions)
        spiking_output = self.snn_layer(attn_output[0])
        x = x + self.dropout(spiking_output)  # New: Apply dropout
        ffn_output = self.ffn(self.output_layer_norm(x))
        x = x + self.dropout(ffn_output)  # New: Apply dropout
        return (x,) if not output_attentions else (x, attn_output[1])

class NSH_BertLayer(nn.Module):
    def __init__(self, original_layer):
        super().__init__()
        attn_module = original_layer.attention.self
        hidden_dim = attn_module.query.in_features
        self.sparse_attention = BigBirdStyleAttention(hidden_dim, attn_module.num_attention_heads)
        self.attention_output_dense = original_layer.attention.output
        self.intermediate = original_layer.intermediate
        self.output = original_layer.output
        self.snn_layer = LIFLayer(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(DROPOUT_RATE)  # New: Dropout for regularization

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, **kwargs):
        attn_output_tuple = self.sparse_attention(hidden_states, attn_mask=attention_mask, output_attentions=output_attentions)
        spiking_output = self.snn_layer(attn_output_tuple[0])
        attention_output = self.attention_output_dense(self.dropout(spiking_output), hidden_states)  # New: Apply dropout
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return (layer_output,) if not output_attentions else (layer_output, attn_output_tuple[1])

# Model Building
def build_baseline_model(device, model_type="distilbert"):
    print(f"\n🏗️ Building Baseline {model_type.capitalize()} Model...")
    model_class = DistilBertForSequenceClassification if model_type == "distilbert" else BertForSequenceClassification
    model_name = "distilbert-base-uncased" if model_type == "distilbert" else "bert-base-uncased"
    model = model_class.from_pretrained(model_name, num_labels=2).to(device)
    print(f"✓ Successfully created baseline {model_type.capitalize()} model.")
    return model

def build_nsh_model(device, model_type="distilbert", num_layers=1):
    print(f"\n🏗️ Building NSH {model_type.capitalize()} Model with {num_layers} layer(s)...")
    model_class = DistilBertForSequenceClassification if model_type == "distilbert" else BertForSequenceClassification
    model_name = "distilbert-base-uncased" if model_type == "distilbert" else "bert-base-uncased"
    model = model_class.from_pretrained(model_name, num_labels=2)
    if model_type == "distilbert":
        for i in range(min(num_layers, len(model.distilbert.transformer.layer))):
            model.distilbert.transformer.layer[i] = NSH_DistilBertLayer(model.distilbert.transformer.layer[i])
    else:
        for i in range(min(num_layers, len(model.bert.encoder.layer))):
            model.bert.encoder.layer[i] = NSH_BertLayer(model.bert.encoder.layer[i])
    model = model.to(device)
    print(f"✓ Successfully created NSH {model_type.capitalize()} model.")
    return model

# Data Loading
def get_dataloaders(model_type, seq_len, batch_size):
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') if model_type == 'distilbert' else BertTokenizer.from_pretrained('bert-base-uncased')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def preprocess_glue_sst2(examples):
        enc = tokenizer(examples["sentence"], padding="max_length", max_length=seq_len, truncation=True)
        return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"], "labels": examples["label"]}

    def preprocess_glue_mrpc(examples):
        enc = tokenizer(examples["sentence1"], examples["sentence2"], padding="max_length", max_length=seq_len, truncation=True)
        return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"], "labels": examples["label"]}

    def preprocess_glue_rte(examples):
        enc = tokenizer(examples["sentence1"], examples["sentence2"], padding="max_length", max_length=seq_len, truncation=True)
        return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"], "labels": examples["label"]}

    dataloaders = {'train': {}, 'validation': {}}
    for dataset_name, dataset_info in DATASETS.items():
        train_max_samples = 2490 if dataset_name == 'glue_rte' else (10000 if dataset_name == 'glue_sst2' else MAX_SAMPLES)
        val_max_samples = 277 if dataset_name == 'glue_rte' else MAX_SAMPLES
        train_dataset = load_dataset(dataset_info['name'], dataset_info['config'], split=dataset_info['split_train']).select(range(min(len(load_dataset(dataset_info['name'], dataset_info['config'], split=dataset_info['split_train'])), train_max_samples)))
        val_dataset = load_dataset(dataset_info['name'], dataset_info['config'], split=dataset_info['split_val']).select(range(min(len(load_dataset(dataset_info['name'], dataset_info['config'], split=dataset_info['split_val'])), val_max_samples)))
        preprocess_fn = {'glue_sst2': preprocess_glue_sst2, 'glue_mrpc': preprocess_glue_mrpc, 'glue_rte': preprocess_glue_rte}[dataset_name]
        remove_cols = {'glue_sst2': ['sentence', 'idx'], 'glue_mrpc': ['sentence1', 'sentence2', 'idx'], 'glue_rte': ['sentence1', 'sentence2', 'idx']}[dataset_name]
        train_preprocessed = train_dataset.map(preprocess_fn, batched=True, remove_columns=remove_cols)
        val_preprocessed = val_dataset.map(preprocess_fn, batched=True, remove_columns=remove_cols)
        train_preprocessed.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        val_preprocessed.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        dataloaders['train'][dataset_name] = DataLoader(train_preprocessed, batch_size=batch_size, shuffle=True)
        dataloaders['validation'][dataset_name] = DataLoader(val_preprocessed, batch_size=batch_size, shuffle=False)
    print(f"Dataloaders created for {model_type}: {', '.join(dataloaders['validation'].keys())}")
    return dataloaders

# Evaluation Metrics
def evaluate_classification(model, dataloader, device):
    model.eval()
    predictions, labels = [], []
    total_loss = 0
    for batch in tqdm(dataloader, desc="Evaluating"):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'])
            preds = torch.argmax(outputs.logits, dim=1).cpu().numpy()
            total_loss += outputs.loss.item()
        predictions.extend(preds)
        labels.extend(batch['labels'].cpu().numpy())
    avg_loss = total_loss / len(dataloader)
    return {"accuracy": accuracy_score(labels, predictions), "f1": f1_score(labels, predictions, average='weighted'), "val_loss": avg_loss}

# Fine-Tuning with Early Stopping
def fine_tune(model, train_dataloader, val_dataloader, device, is_bert, task_name, epochs=NUM_EPOCHS, task_type="classification", scheduler_callback=None):
    model.train()
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    total_steps = len(train_dataloader) * epochs[task_name]
    num_warmup_steps = int(total_steps * 0.1)
    scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, power=1.0)

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(epochs[task_name]):
        print(f"Epoch {epoch+1}/{epochs[task_name]}")
        total_loss = 0
        for module in model.modules():
            if isinstance(module, LIFLayer):
                module.reset()
        optimizer.zero_grad()
        for i, batch in enumerate(tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}")):
            inputs = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask', 'token_type_ids', 'labels'] and isinstance(v, torch.Tensor)}
            inputs["labels"] = batch["labels"].to(device)
            if not is_bert:
                inputs.pop('token_type_ids', None)
            
            outputs = model(**inputs)
            loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS
            loss.backward()
            total_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS

            if (i + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                if scheduler_callback is not None:
                    scheduler_callback()

        avg_loss = total_loss / len(train_dataloader)
        print(f"Average training loss: {avg_loss:.4f}")

        # Validation step for early stopping
        val_metrics = evaluate_classification(model, val_dataloader, device)
        val_loss = val_metrics['val_loss']
        print(f"Validation loss: {val_loss:.4f}, Accuracy: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1']:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model

# Experiment Runner
def run_experiment(model, model_name, dataloaders, device, is_bert, batch_size, seq_len, task_type, run_sst2=True, run_mrpc=True, run_rte=True):
    results = {
        'model_name': model_name,
        'batch_size': batch_size,
        'seq_length': seq_len,
        'task_type': task_type,
        'accuracy_metrics': {},
        'performance_metrics': {},
        'scheduler_metrics': {
            'random_k_distribution': {'k_8': 1.0},
            'avg_random_k': 8.0
        }
    }

    def update_scheduler_metrics(model, num_forward_passes):
        random_k_counts = {'k_8': 0}
        total_layers = 0
        total_random_k = 0
        for layer in model.modules():
            if isinstance(layer, (NSH_DistilBertLayer, NSH_BertLayer)):
                total_layers += 1
                random_k = layer.sparse_attention.random_k
                total_random_k += random_k
                random_k_counts['k_8'] = random_k_counts.get('k_8', 0) + (1 if random_k == 8 else 0)
        if total_layers == 0:
            total_layers = 1
        return {
            'random_k_distribution': {k: v / total_layers for k, v in random_k_counts.items()},
            'avg_random_k': total_random_k / total_layers
        }

    print(f"\n--- 🚀 Measuring Metrics for {model_name} ({task_type}) ---")
    start_time = time.time()
    num_queries = 0
    num_forward_passes = 0
    tracker = OfflineEmissionsTracker(
        project_name=f"Experiment_{model_name.replace(' ', '_')}",
        measure_power_secs=1,
        output_dir=".",
        log_level='info',
        country_iso_code="USA",
        region="California"
    )
    tracker.start()
    try:
        if run_sst2:
            print("\n--- Fine-tuning on GLUE SST-2 ---")
            model = fine_tune(model, dataloaders['train']['glue_sst2'], dataloaders['validation']['glue_sst2'], device, is_bert, task_name='sst2', task_type=task_type,
                             scheduler_callback=lambda: results.update({'scheduler_metrics': update_scheduler_metrics(model, num_forward_passes)}))
            num_forward_passes += len(dataloaders['train']['glue_sst2'].dataset) * NUM_EPOCHS['sst2']
            print(" Evaluating GLUE SST-2...")
            metrics = evaluate_classification(model, dataloaders['validation']['glue_sst2'], device)
            results['accuracy_metrics']['sst2_accuracy'] = metrics['accuracy']
            results['accuracy_metrics']['sst2_f1'] = metrics['f1']
            num_queries += len(dataloaders['validation']['glue_sst2'].dataset)
            num_forward_passes += len(dataloaders['validation']['glue_sst2'].dataset)
            print(f" SST-2 Accuracy: {metrics['accuracy']:.4f}, F1: {metrics['f1']:.4f}")

        if run_mrpc:
            print("\n--- Fine-tuning on GLUE MRPC ---")
            model = fine_tune(model, dataloaders['train']['glue_mrpc'], dataloaders['validation']['glue_mrpc'], device, is_bert, task_name='mrpc', task_type=task_type,
                             scheduler_callback=lambda: results.update({'scheduler_metrics': update_scheduler_metrics(model, num_forward_passes)}))
            num_forward_passes += len(dataloaders['train']['glue_mrpc'].dataset) * NUM_EPOCHS['mrpc']
            print(" Evaluating GLUE MRPC...")
            metrics = evaluate_classification(model, dataloaders['validation']['glue_mrpc'], device)
            results['accuracy_metrics']['mrpc_accuracy'] = metrics['accuracy']
            results['accuracy_metrics']['mrpc_f1'] = metrics['f1']
            num_queries += len(dataloaders['validation']['glue_mrpc'].dataset)
            num_forward_passes += len(dataloaders['validation']['glue_mrpc'].dataset)
            print(f" MRPC Accuracy: {metrics['accuracy']:.4f}, F1: {metrics['f1']:.4f}")

        if run_rte:
            print("\n--- Fine-tuning on GLUE RTE ---")
            model = fine_tune(model, dataloaders['train']['glue_rte'], dataloaders['validation']['glue_rte'], device, is_bert, task_name='rte', task_type=task_type,
                             scheduler_callback=lambda: results.update({'scheduler_metrics': update_scheduler_metrics(model, num_forward_passes)}))
            num_forward_passes += len(dataloaders['train']['glue_rte'].dataset) * NUM_EPOCHS['rte']
            print(" Evaluating GLUE RTE...")
            metrics = evaluate_classification(model, dataloaders['validation']['glue_rte'], device)
            results['accuracy_metrics']['rte_accuracy'] = metrics['accuracy']
            results['accuracy_metrics']['rte_f1'] = metrics['f1']
            num_queries += len(dataloaders['validation']['glue_rte'].dataset)
            num_forward_passes += len(dataloaders['validation']['glue_rte'].dataset)
            print(f" RTE Accuracy: {metrics['accuracy']:.4f}, F1: {metrics['f1']:.4f}")

        results['scheduler_metrics'] = update_scheduler_metrics(model, num_forward_passes)
        emissions_kwh = tracker.stop() or 0.0
    except Exception as e:
        print(f"🚨 Experiment failed for {model_name}: {e}")
        try:
            emissions_kwh = tracker.stop() or 0.0
        except Exception:
            emissions_kwh = 0.0

    total_duration_s = time.time() - start_time
    total_tokens_processed = num_queries * seq_len
    total_carbon_g = emissions_kwh * CARBON_INTENSITY
    results['performance_metrics'] = {
        'latency_ms_query': (total_duration_s / num_queries) * 1000 if num_queries > 0 else 0,
        'throughput_tokens_sec': total_tokens_processed / total_duration_s if total_duration_s > 0 else 0,
        'energy_wh_token': (emissions_kwh * 1000) / total_tokens_processed if total_tokens_processed > 0 else 0,
        'sci_gco2e_query': total_carbon_g / num_queries if num_queries > 0 else 0,
        'wue_avg_liters_query': (emissions_kwh * WATER_USAGE_FACTORS['average_l_per_kwh']) / num_queries if num_queries > 0 else 0,
        'total_emissions_kgco2eq': total_carbon_g / 1000,
        'total_energy_kwh': emissions_kwh
    }
    print(f"\n--- Results for {model_name} ---")
    print(f" Duration: {total_duration_s:.2f}s | Emissions: {total_carbon_g / 1000:.6f} kg CO2eq | Energy: {emissions_kwh:.6f} kWh | Queries: {num_queries}")
    print(json.dumps(results, indent=2))
    print("-" * 40)
    return results

# --- SECTION: 8. MAIN EXECUTION ---
if __name__ == "__main__":
    DEVICE = get_device()
    set_seed(42)
    batch_size = BATCH_SIZE
    num_layers = NUM_LAYERS
    all_results = []
    model_configs = [
        # Baseline Distilbert Niche wala
    #  {'model_type': 'distilbert', 'task_type': 'classification', 'is_nsh': False, 'name': 'Baseline_DistilBERT_Classification'},
        # Baseline Bert Niche wala
    #  {'model_type': 'bert', 'task_type': 'classification', 'is_nsh': False, 'name': 'Baseline_BERT_Classification'},
    
        # NSH Distilbert isko run krana hai
     {'model_type': 'distilbert', 'task_type': 'classification', 'is_nsh': True, 'name': 'NSH_DistilBERT_Classification'},
        # NSH Bert  ise bhi run krana hai
    #{'model_type': 'bert', 'task_type': 'classification', 'is_nsh': True, 'name': 'NSH_BERT_Classification'}
    ]
    for config in model_configs:
        model_type = config['model_type']
        task_type = config['task_type']
        is_nsh = config['is_nsh']
        model_name = config['name']
        seq_len = SEQ_LENGTH
        dataloaders = get_dataloaders(model_type=model_type, seq_len=seq_len, batch_size=batch_size)
        if is_nsh:
            model = build_nsh_model(device=DEVICE, model_type=model_type, num_layers=num_layers)
        else:
            model = build_baseline_model(device=DEVICE, model_type=model_type)
        results = run_experiment(model=model, model_name=model_name, dataloaders=dataloaders, device=DEVICE,
            is_bert=(model_type == 'bert'),batch_size=batch_size, seq_len=seq_len,
            task_type=task_type, run_sst2=True,run_mrpc=True,run_rte=True)
        all_results.append(results)
    with open('nsh_results.json', 'w') as f:
        json.dump(all_results, f, indent=2)
    df_out = pd.json_normalize(all_results, sep='_')
    df_out.to_csv('nsh_results.csv', index=False)
    print("✅ CSV saved to 'nsh_results.csv'")
    print("\nAll experiments completed. Results saved.")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, BertTokenizer, BertForSequenceClassification, get_polynomial_decay_schedule_with_warmup
from datasets import load_dataset
from codecarbon import OfflineEmissionsTracker
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
import pandas as pd
import time
import json
import random
import logging
import warnings
from tqdm import tqdm
from torch.optim import AdamW
import argparse # New: For easier experimentation
from spikingjelly.activation_based import neuron


from transformers import DistilBertConfig


# --- STRATEGY 1: Import the highly optimized spikingjelly library ---
from spikingjelly.activation_based import neuron
# FIX: Import 'functional' with an alias to prevent name collisions.
from spikingjelly.activation_based import functional as sj_functional

from spikingjelly.activation_based import surrogate

# --- SECTION: 1. CONFIGURATION ---
warnings.filterwarnings("ignore")
logging.getLogger("codecarbon").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

DEVICE_CONFIG = {
    'optimize_for_gpu': True,
    'mixed_precision': True
}
DEVICE = "cuda" if torch.cuda.is_available() and DEVICE_CONFIG['optimize_for_gpu'] else "cpu"

# --- Constants ---
DATASETS = {
    'glue_sst2': {'name': 'glue', 'config': 'sst2', 'split_train': 'train', 'split_val': 'validation'},
    'glue_mrpc': {'name': 'glue', 'config': 'mrpc', 'split_train': 'train', 'split_val': 'validation'},
    'glue_rte': {'name': 'glue', 'config': 'rte', 'split_train': 'train', 'split_val': 'validation'}
}
MAX_SAMPLES = 10000
CARBON_INTENSITY = 250
WATER_USAGE_FACTORS = {"average_l_per_kwh": 1.8}

# --- Hyperparameters ---
BATCH_SIZE = 16
SEQ_LENGTH = 128
NUM_LAYERS = 4
NUM_EPOCHS = {'sst2': 5, 'mrpc': 5, 'rte': 3}
LEARNING_RATE = 3e-5
WEIGHT_DECAY = 0.01
DROPOUT_RATE = 0.1
PATIENCE = 3
GRADIENT_ACCUMULATION_STEPS = 2

# --- SECTION: 2. UTILITY FUNCTIONS ---
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_device():
    device = torch.device(DEVICE)
    logger.info(f"✓ Using {'GPU: ' + torch.cuda.get_device_name(0) if DEVICE == 'cuda' else 'CPU'}")
    return device

# --- SECTION: 3. ADVANCED NEUROMORPHIC & SPARSE LAYERS ---
def gumbel_top_k_select(scores, k, temperature=1.0):
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-10) + 1e-10)
    perturbed_scores = (scores + gumbel_noise) / temperature
    _, top_k_indices = torch.topk(perturbed_scores, k=k, dim=-1)
    mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, top_k_indices, 1)
    return mask

class Heaviside(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        grad = grad_output * torch.sigmoid(x) * (1 - torch.sigmoid(x))
        return grad

class LIFLayer(nn.Module):
    def __init__(self, input_dim, output_dim, threshold=1.0, decay=0.9):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.threshold = threshold
        self.decay = decay
        self.mem = None

    def forward(self, x):
        if self.mem is None or self.mem.shape != x.shape:
            self.mem = torch.zeros_like(x, device=x.device)
        current = self.linear(x)
        self.mem = self.decay * self.mem + current
        spikes = Heaviside.apply(self.mem - self.threshold)
        self.mem = self.mem * (1 - spikes)  # Reset membrane where spikes occur
        return spikes

    def reset(self):
        self.mem = None

class BigBirdStyleAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads, window_size=12, random_k=2, global_tokens=1, seq_len=128):
        super().__init__()
        self.hidden_dim, self.num_heads = hidden_dim, num_heads
        self.head_dim = hidden_dim // num_heads
        self.window_size, self.random_k, self.global_tokens = window_size, random_k, global_tokens
        self.qkv_fused_linear = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.static_mask = self._create_static_mask(seq_len, window_size, global_tokens)

    def _create_static_mask(self, seq_len, window_size, global_tokens):
        window_mask = torch.ones(seq_len, seq_len).triu(window_size + 1) + torch.ones(seq_len, seq_len).tril(-window_size - 1)
        global_mask = torch.ones(seq_len, seq_len)
        global_mask[:global_tokens, :] = 0
        global_mask[:, :global_tokens] = 0
        static_mask = (window_mask + global_mask) > 0
        return static_mask.unsqueeze(0).unsqueeze(1).to(torch.bool)

    def forward(self, x, attn_mask=None, output_attentions=False):
        batch_size, seq_len, _ = x.shape
        qkv = self.qkv_fused_linear(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attention_scores = torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(self.head_dim)
        if self.static_mask.device != x.device:
            self.static_mask = self.static_mask.to(x.device)

        random_mask_bool = gumbel_top_k_select(torch.randn_like(attention_scores), k=self.random_k)
        final_mask = self.static_mask.expand(batch_size, self.num_heads, -1, -1) | random_mask_bool

        if attn_mask is not None:
            final_mask = final_mask | attn_mask

        mask_value = torch.finfo(attention_scores.dtype).min
        attention_scores = attention_scores.masked_fill(final_mask, mask_value)
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        attention_output = torch.matmul(attention_weights, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_dim)
        output = self.output_layer(attention_output)
        return (output,) if not output_attentions else (output, attention_weights)

class NSH_DistilBertLayer(nn.Module):
    def __init__(self, original_layer, seq_len=128):
        super().__init__()
        attn_module = original_layer.attention
        hidden_dim = attn_module.q_lin.in_features
        self.sparse_attention = BigBirdStyleAttention(hidden_dim, attn_module.n_heads, seq_len=seq_len)
        self.ffn = original_layer.ffn
        self.sa_layer_norm = original_layer.sa_layer_norm
        self.output_layer_norm = original_layer.output_layer_norm
        self.snn_layer = LIFLayer(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
        boolean_attn_mask = None
        if attn_mask is not None:
            boolean_attn_mask = (attn_mask == 0).unsqueeze(1).unsqueeze(2)
        normed_x = self.sa_layer_norm(x)
        attn_output_tuple = self.sparse_attention(normed_x, attn_mask=boolean_attn_mask, output_attentions=output_attentions)
        spiking_output = self.snn_layer(attn_output_tuple[0])
        x = x + self.dropout(spiking_output)
        normed_x2 = self.output_layer_norm(x)
        ffn_output = self.ffn(normed_x2)
        x = x + self.dropout(ffn_output)
        return (x,) if not output_attentions else (x, attn_output_tuple[1])

class NSH_BertLayer(nn.Module):
    def __init__(self, original_layer, seq_len=128):
        super().__init__()
        attn_module = original_layer.attention.self
        hidden_dim = attn_module.query.in_features
        self.sparse_attention = BigBirdStyleAttention(hidden_dim, attn_module.num_attention_heads, seq_len=seq_len)
        self.attention_output_dense = original_layer.attention.output
        self.intermediate = original_layer.intermediate
        self.output = original_layer.output
        self.snn_layer = LIFLayer(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        output_attentions = kwargs.get('output_attentions', False)
        boolean_attn_mask = None
        if attention_mask is not None:
            boolean_attn_mask = (attention_mask < 0)
        attn_output_tuple = self.sparse_attention(hidden_states, attn_mask=boolean_attn_mask, output_attentions=output_attentions)
        spiking_output = self.snn_layer(attn_output_tuple[0])
        attention_output = self.attention_output_dense(self.dropout(spiking_output), hidden_states)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(self.dropout(intermediate_output), attention_output)
        return (layer_output,) if not output_attentions else (layer_output, attn_output_tuple[1])

# --- SECTION: 4. MODEL BUILDING ---
def build_baseline_model(device, model_type="distilbert"):
    logger.info(f"Building Baseline {model_type.capitalize()} Model...")
    model_class = DistilBertForSequenceClassification if model_type == "distilbert" else BertForSequenceClassification
    model = model_class.from_pretrained(f"{model_type}-base-uncased", num_labels=2).to(device)
    logger.info(f"Successfully created baseline {model_type.capitalize()} model.")
    return model

def build_nsh_model(device, model_type="distilbert", num_layers=1, seq_len=128):
    logger.info(f"Building NSH {model_type.capitalize()} Model with {num_layers} layer(s)...")
    model_class = DistilBertForSequenceClassification if model_type == "distilbert" else BertForSequenceClassification
    model = model_class.from_pretrained(f"{model_type}-base-uncased", num_labels=2)
    if model_type == "distilbert":
        for i in range(min(num_layers, len(model.distilbert.transformer.layer))):
            model.distilbert.transformer.layer[i] = NSH_DistilBertLayer(model.distilbert.transformer.layer[i], seq_len)
    else:
        for i in range(min(num_layers, len(model.bert.encoder.layer))):
            model.bert.encoder.layer[i] = NSH_BertLayer(model.bert.encoder.layer[i], seq_len)
    model.to(device)
    logger.info(f"Successfully created NSH {model_type.capitalize()} model.")
    return model

# --- SECTION: 5. DATA LOADING ---
def get_dataloaders(model_type, seq_len, batch_size):
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') if model_type == 'distilbert' else BertTokenizer.from_pretrained('bert-base-uncased')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    preprocess_map = {
        'glue_sst2': lambda exs: tokenizer(exs["sentence"], padding="max_length", max_length=seq_len, truncation=True),
        'glue_mrpc': lambda exs: tokenizer(exs["sentence1"], exs["sentence2"], padding="max_length", max_length=seq_len, truncation=True),
        'glue_rte': lambda exs: tokenizer(exs["sentence1"], exs["sentence2"], padding="max_length", max_length=seq_len, truncation=True)
    }
    dataloaders = {'train': {}, 'validation': {}}
    for dataset_name, dataset_info in DATASETS.items():
        train_max_samples = 2490 if dataset_name == 'glue_rte' else (10000 if dataset_name == 'glue_sst2' else MAX_SAMPLES)
        val_max_samples = 277 if dataset_name == 'glue_rte' else MAX_SAMPLES
        train_dataset = load_dataset(dataset_info['name'], dataset_info['config'], split=dataset_info['split_train'])
        val_dataset = load_dataset(dataset_info['name'], dataset_info['config'], split=dataset_info['split_val'])
        train_dataset = train_dataset.select(range(min(len(train_dataset), train_max_samples)))
        val_dataset = val_dataset.select(range(min(len(val_dataset), val_max_samples)))
        train_dataset = train_dataset.map(preprocess_map[dataset_name], batched=True)
        val_dataset = val_dataset.map(preprocess_map[dataset_name], batched=True)
        train_dataset = train_dataset.rename_column("label", "labels")
        val_dataset = val_dataset.rename_column("label", "labels")
        train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        val_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        dataloaders['train'][dataset_name] = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        dataloaders['validation'][dataset_name] = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    logger.info(f"Dataloaders created for {model_type}: {', '.join(dataloaders['validation'].keys())}")
    return dataloaders

# --- SECTION: 6. TRAINING & EVALUATION ---
def evaluate_classification(model, dataloader, device):
    model.eval()
    all_preds, all_labels = [], []
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.cuda.amp.autocast(enabled=DEVICE_CONFIG['mixed_precision']):
                outputs = model(**batch)
            total_loss += outputs.loss.item()
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch['labels'].cpu().numpy())
    return {
        "accuracy": accuracy_score(all_labels, all_preds),
        "f1": f1_score(all_labels, all_preds, average='weighted'),
        "val_loss": total_loss / len(dataloader)
    }

def fine_tune(model, train_dataloader, val_dataloader, device, task_name, epochs, debug=False):
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    total_steps = len(train_dataloader) * epochs[task_name] // GRADIENT_ACCUMULATION_STEPS
    scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=int(total_steps * 0.1), num_training_steps=total_steps)
    scaler = torch.cuda.amp.GradScaler(enabled=DEVICE_CONFIG['mixed_precision'])
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs.get(task_name, 3)):
        logger.info(f"Epoch {epoch+1}/{epochs.get(task_name, 3)} for {task_name}")
        model.train()
        for module in model.modules():
            if hasattr(module, 'reset'):
                module.reset()
        total_train_loss = 0
        progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Training Epoch {epoch+1}")

        for i, batch in progress_bar:
            inputs = {k: v.to(device) for k, v in batch.items()}
            with torch.cuda.amp.autocast(enabled=DEVICE_CONFIG['mixed_precision']):
                outputs = model(**inputs)
                loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS
            try:
                scaler.scale(loss).backward(retain_graph=(i + 1) % GRADIENT_ACCUMULATION_STEPS != 0)
            except Exception as e:
                if debug:
                    logger.error(f"Backward pass failed at batch {i}: {e}")
                raise
            if (i + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                for module in model.modules():
                    if hasattr(module, 'reset'):
                        module.reset()
            total_train_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS
            progress_bar.set_postfix({'loss': total_train_loss / (i + 1)})

        val_metrics = evaluate_classification(model, val_dataloader, device)
        logger.info(f"Validation -> Loss: {val_metrics['val_loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1']:.4f}")
        if val_metrics['val_loss'] < best_val_loss:
            best_val_loss = val_metrics['val_loss']
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                logger.info(f"Early stopping triggered after {epoch+1} epochs.")
                break
    return model

# --- SECTION: 7. EXPERIMENT RUNNER ---
def run_experiment(model, model_name, dataloaders, device, batch_size, seq_len, tasks_to_run):
    results = {
        'model_name': model_name,
        'batch_size': batch_size,
        'seq_length': seq_len,
        'accuracy_metrics': {},
        'performance_metrics': {}
    }

    print(f"\n--- 🚀 Measuring Metrics for {model_name} ---")
    start_time = time.time()
    num_queries = 0
    #tracker = OfflineEmissionsTracker(project_name=f"exp_{model_name.replace(' ', '_')}", output_dir="emissions", country_iso_code="USA", log_level="warning")
    tracker = OfflineEmissionsTracker(
    project_name=f"Experiment_{model_name.replace(' ', '_')}",
    measure_power_secs=1,
    output_dir=".",
    log_level='info',
    gpu_ids=[0],       # Track GPU ID 0 (your Tesla T4)
    # cpu_power=False    # Optional: disable CPU measurement for accurate GPU tracking
)
    tracker.start()
    
    try:
        for task in tasks_to_run:
            print(f"\n--- Running task: {task} ---")
            task_epochs = {'sst2': NUM_EPOCHS['sst2'], 'mrpc': NUM_EPOCHS['mrpc'], 'rte': NUM_EPOCHS['rte']}
            model = fine_tune(model, dataloaders['train'][f'glue_{task}'], dataloaders['validation'][f'glue_{task}'], device, task, task_epochs)
            
            print(f"Final evaluation on GLUE {task}...")
            metrics = evaluate_classification(model, dataloaders['validation'][f'glue_{task}'], device)
            results['accuracy_metrics'][f'{task}_accuracy'] = metrics['accuracy']
            results['accuracy_metrics'][f'{task}_f1'] = metrics['f1']
            num_queries += len(dataloaders['validation'][f'glue_{task}'].dataset)
            print(f"✓ {task.upper()} Results -> Accuracy: {metrics['accuracy']:.4f}, F1: {metrics['f1']:.4f}")

    except Exception as e:
        print(f"🚨 Experiment failed for {model_name}: {e}")
    finally:
        emissions_kwh = tracker.stop() or 0.0

    total_duration_s = time.time() - start_time
    total_tokens_processed = num_queries * seq_len
    total_carbon_g = emissions_kwh * CARBON_INTENSITY

    results['performance_metrics'] = {
        'latency_ms_query': (total_duration_s / num_queries) * 1000 if num_queries > 0 else 0,
        'throughput_tokens_sec': total_tokens_processed / total_duration_s if total_duration_s > 0 else 0,
        'energy_wh_token': (emissions_kwh * 1000) / total_tokens_processed if total_tokens_processed > 0 else 0,
         'wue_avg_liters_query': (emissions_kwh * WATER_USAGE_FACTORS['average_l_per_kwh']) / num_queries if num_queries > 0 else 0,
        'sci_gco2e_query': total_carbon_g / num_queries if num_queries > 0 else 0,
        'total_emissions_kgco2eq': total_carbon_g / 1000,
        'total_energy_kwh': emissions_kwh
    }
    
    print(f"\n--- ✅ Final Results for {model_name} ---")
    print(json.dumps(results, indent=2))
    return results

# --- SECTION: 8. MAIN EXECUTION & PROFILING ---
def main(args):
    DEVICE = get_device()
    set_seed(42)

    dataloaders = get_dataloaders(model_type=args.model_type, seq_len=args.seq_len, batch_size=args.batch_size)
    
    # --- MODEL SELECTION ---

    # === DISTILBERT BRANCH ===
    # Uncomment this block to run DistilBERT models
    if args.model_type == 'distilbert':
    #     
        model = build_nsh_model(DEVICE, 'distilbert', args.num_layers, args.seq_len)
        model_name = "NSH_DistilBERT"
    #     else:
    #         model = build_baseline_model(DEVICE, 'distilbert')
    #         model_name = "Baseline_DistilBERT"

    # === BERT BRANCH ===
    # Uncomment this block to run BERT models
    # if args.model_type == 'bert':
        
    #     model = build_nsh_model(DEVICE, 'bert', args.num_layers, args.seq_len)
    #     model_name = "NSH_BERT"
    
        # else:
        #     model = build_baseline_model(DEVICE, 'bert')
        #     model_name = "Baseline_BERT"
    # --- STRATEGY 2: Add torch.compile() for a massive speedup ---
    if args.use_compile:
        print("\n🔥 Activating torch.compile() for optimized performance...")
        # Note: 'max-autotune' is best for static shapes, which we have.
        model = torch.compile(model, mode="max-autotune")
    # Safety check: if no model was selected
    if 'model' not in locals():
        raise ValueError("No model was selected! Please uncomment a model branch in main().")

    # --- PROFILING SECTION ---
    if args.profile:
        print("\n--- 🔬 Starting Profiler ---")
        sample_batch = next(iter(dataloaders['train']['glue_sst2']))
        sample_batch = {k: v.to(DEVICE) for k, v in sample_batch.items()}

        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True,
            profile_memory=True,
            with_stack=True
        ) as prof:
            with torch.profiler.record_function("model_inference"):
                outputs = model(**sample_batch)
                loss = outputs.loss
                loss.backward()

        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
        print("--- Profiling Complete ---")
        return  # Exit after profiling

  

    # --- Full Experiment ---
    all_results = []
    tasks = [t.strip() for t in args.tasks.split(',')]
    results = run_experiment(model=model, model_name=model_name, dataloaders=dataloaders, device=DEVICE,
                             batch_size=args.batch_size, seq_len=args.seq_len, tasks_to_run=tasks)
    all_results.append(results)
    
    output_filename = "nsh_results"
    with open(f'{output_filename}.json', 'w') as f:
        json.dump(all_results, f, indent=2)
    df = pd.json_normalize(all_results, sep='_')
    df.to_csv(f'{output_filename}.csv', index=False)
    print(f"\n✅ All experiments completed. Results saved to '{output_filename}.json' and '{output_filename}.csv'")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Neuromorphic Sparse Transformer Experiments")
    #parser.add_argument("--model_type", type=str, default="distilbert", choices=["distilbert", "bert"], help="Base model architecture.")
    parser.add_argument("--model_type", type=str, default="distilbert", choices=["distilbert", "bert"])

    parser.add_argument("--is_nsh", action="store_true", help="Flag to use the NSH model instead of the baseline.")
    parser.add_argument("--tasks", type=str, default="sst2,mrpc,rte", help="Comma-separated list of tasks to run (e.g., 'sst2,mrpc').")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size for training and evaluation.")
    parser.add_argument("--seq_len", type=int, default=SEQ_LENGTH, help="Sequence length for tokenization.")
    parser.add_argument("--num_layers", type=int, default=NUM_LAYERS, help="Number of layers to replace with NSH layers.")
    parser.add_argument("--profile", action="store_true", help="Run profiler for one batch and exit.")
    parser.add_argument("--use_compile", action="store_true", help="Enable torch.compile for optimization.")
    
    #args = parser.parse_args()
    # Ignore unknown arguments (like '-f' injected by Jupyter)
    args, _ = parser.parse_known_args()
    main(args)




vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sst2/train-00000-of-00001.parquet:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

sst2/validation-00000-of-00001.parquet:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

sst2/test-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

mrpc/train-00000-of-00001.parquet:   0%|          | 0.00/649k [00:00<?, ?B/s]

mrpc/validation-00000-of-00001.parquet:   0%|          | 0.00/75.7k [00:00<?, ?B/s]

mrpc/test-00000-of-00001.parquet:   0%|          | 0.00/308k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3668 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/408 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1725 [00:00<?, ? examples/s]

Map:   0%|          | 0/3668 [00:00<?, ? examples/s]

Map:   0%|          | 0/408 [00:00<?, ? examples/s]

rte/train-00000-of-00001.parquet:   0%|          | 0.00/584k [00:00<?, ?B/s]

rte/validation-00000-of-00001.parquet:   0%|          | 0.00/69.0k [00:00<?, ?B/s]

rte/test-00000-of-00001.parquet:   0%|          | 0.00/621k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2490 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/277 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2490 [00:00<?, ? examples/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Map:   0%|          | 0/277 [00:00<?, ? examples/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[codecarbon INFO @ 12:35:25] [setup] RAM Tracking...
[codecarbon INFO @ 12:35:25] [setup] CPU Tracking...



--- 🚀 Measuring Metrics for NSH_DistilBERT ---


 Linux OS detected: Please ensure RAPL files exist at /sys/class/powercap/intel-rapl/subsystem to measure CPU

[codecarbon INFO @ 12:35:26] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU @ 2.00GHz
[codecarbon INFO @ 12:35:26] [setup] GPU Tracking...
[codecarbon INFO @ 12:35:26] Tracking Nvidia GPU via pynvml
[codecarbon INFO @ 12:35:26] The below tracking methods have been set up:
                RAM Tracking Method: RAM power estimation model
                CPU Tracking Method: global constant
                GPU Tracking Method: pynvml
            
[codecarbon INFO @ 12:35:26] >>> Tracker's metadata:
[codecarbon INFO @ 12:35:26]   Platform system: Linux-6.6.56+-x86_64-with-glibc2.35
[codecarbon INFO @ 12:35:26]   Python version: 3.11.13
[codecarbon INFO @ 12:35:26]   CodeCarbon version: 3.0.5
[codecarbon INFO @ 12:35:26]   Available RAM : 31.350 GB
[codecarbon INFO @ 12:35:26]   CPU count: 4 thread(s) in 1 physical CPU(s)
[codecarbon INFO @ 12:35:26]   CPU model: Intel


--- Running task: sst2 ---


[codecarbon INFO @ 12:35:27] Energy consumed for RAM : 0.000006 kWh. RAM Power : 20.0 W
[codecarbon INFO @ 12:35:27] Delta energy consumed for CPU with constant : 0.000012 kWh, power : 42.5 W
[codecarbon INFO @ 12:35:27] Energy consumed for All CPU : 0.000012 kWh
[codecarbon INFO @ 12:35:27] Energy consumed for all GPUs : 0.000007 kWh. Total GPU Power : 25.729004702349428 W
[codecarbon INFO @ 12:35:28] 0.000025 kWh of electricity used since the beginning.
[codecarbon INFO @ 12:35:28] Energy consumed for RAM : 0.000011 kWh. RAM Power : 20.0 W
[codecarbon INFO @ 12:35:28] Delta energy consumed for CPU with constant : 0.000011 kWh, power : 42.5 W
[codecarbon INFO @ 12:35:28] Energy consumed for All CPU : 0.000023 kWh
[codecarbon INFO @ 12:35:28] Energy consumed for all GPUs : 0.000015 kWh. Total GPU Power : 28.84172936300526 W
[codecarbon INFO @ 12:35:28] 0.000049 kWh of electricity used since the beginning.
Training Epoch 1:   0%|          | 0/625 [00:00<?, ?it/s][codecarbon INFO @ 12:35

Final evaluation on GLUE sst2...


Evaluating:   7%|▋         | 4/55 [00:00<00:01, 35.95it/s][codecarbon INFO @ 12:40:16] Energy consumed for RAM : 0.001589 kWh. RAM Power : 20.0 W
[codecarbon INFO @ 12:40:16] Delta energy consumed for CPU with constant : 0.000012 kWh, power : 42.5 W
[codecarbon INFO @ 12:40:16] Energy consumed for All CPU : 0.003380 kWh
[codecarbon INFO @ 12:40:16] Energy consumed for all GPUs : 0.005240 kWh. Total GPU Power : 66.61228285970783 W
[codecarbon INFO @ 12:40:16] 0.010209 kWh of electricity used since the beginning.
Evaluating:  73%|███████▎  | 40/55 [00:01<00:00, 35.25it/s][codecarbon INFO @ 12:40:17] Energy consumed for RAM : 0.001594 kWh. RAM Power : 20.0 W
[codecarbon INFO @ 12:40:17] Delta energy consumed for CPU with constant : 0.000012 kWh, power : 42.5 W
[codecarbon INFO @ 12:40:17] Energy consumed for All CPU : 0.003392 kWh
[codecarbon INFO @ 12:40:17] Energy consumed for all GPUs : 0.005256 kWh. Total GPU Power : 59.765847225404755 W
[codecarbon INFO @ 12:40:17] 0.010242 kWh of el

✓ SST2 Results -> Accuracy: 0.7683, F1: 0.7683

--- Running task: mrpc ---


Training Epoch 1:   3%|▎         | 6/230 [00:00<00:19, 11.63it/s, loss=0.995][codecarbon INFO @ 12:40:18] Energy consumed for RAM : 0.001600 kWh. RAM Power : 20.0 W
[codecarbon INFO @ 12:40:18] Delta energy consumed for CPU with constant : 0.000012 kWh, power : 42.5 W
[codecarbon INFO @ 12:40:18] Energy consumed for All CPU : 0.003404 kWh
[codecarbon INFO @ 12:40:18] Energy consumed for all GPUs : 0.005274 kWh. Total GPU Power : 66.14093163241185 W
[codecarbon INFO @ 12:40:18] 0.010278 kWh of electricity used since the beginning.
Training Epoch 1:   7%|▋         | 16/230 [00:01<00:19, 11.05it/s, loss=0.947][codecarbon INFO @ 12:40:19] Energy consumed for RAM : 0.001605 kWh. RAM Power : 20.0 W
[codecarbon INFO @ 12:40:19] Delta energy consumed for CPU with constant : 0.000012 kWh, power : 42.5 W
[codecarbon INFO @ 12:40:19] Energy consumed for All CPU : 0.003415 kWh
[codecarbon INFO @ 12:40:19] Energy consumed for all GPUs : 0.005293 kWh. Total GPU Power : 66.38748509056354 W
[codecarbo