In [None]:
# --- Definitive Installation for PyTorch and Mamba in Colab/Kaggle ---
# This cell removes any preinstalled torch / mamba packages, installs PyTorch from
# the CUDA 12.1 wheel index, and then builds mamba from a fresh git checkout.
# NOTE(review): the clone and %cd targets are under /content (Colab layout) —
# confirm that directory is usable when running on Kaggle.

# 1. Completely uninstall existing versions to prevent conflicts.
# The '-y' flag automatically confirms the uninstall.
# NOTE(review): causal-conv1d is removed here but is not reinstalled anywhere in
# this cell — confirm it is provided by a later step if it is still needed.
print("--- Step 1: Uninstalling existing torch and mamba libraries ---")
!pip uninstall -y torch torchvision torchaudio mamba-ssm causal-conv1d

# 2. Install PyTorch (plus torchvision / torchaudio) from the cu121 wheel index.
# NOTE(review): no version is pinned, so pip picks whatever build the index
# currently resolves to; pin explicit versions if a known-good combination matters.
print("\n--- Step 2: Installing a compatible version of PyTorch ---")
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# 3. Re-install Mamba from source so it is built after the PyTorch install above.
# Re-running this cell will attempt the clone again into the same directory.
print("\n--- Step 3: Compiling and installing Mamba from source ---")
!git clone https://github.com/state-spaces/mamba /content/mamba
%cd /content/mamba
!pip install .
# Return to the parent directory so later cells do not run inside the checkout.
%cd /content/
print("DONE!")

In [None]:
# Install mamba-ssm from the package index together with its causal-conv1d extra,
# then install triton. --no-build-isolation tells pip to build against the packages
# already present in this environment instead of a temporary isolated one.
# NOTE(review): the previous cell already installs mamba from a git checkout; this
# looks like an alternative installation path — confirm which of the two is intended.
!pip install mamba-ssm[causal-conv1d] --no-build-isolation
!pip install triton

In [None]:
# Quietly (-q) upgrade (-U) torch / torchvision / torchaudio from the CUDA 12.6 wheel index.
# NOTE(review): the first cell installs torch from the cu121 index and builds mamba
# afterwards; upgrading torch here, after that build, may leave the compiled mamba
# extension mismatched with the installed torch — confirm the intended cell order.
!pip3 install -q -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import AutoTokenizer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, roc_curve, auc
import numpy as np
import copy
import os
import pickle
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import datetime

# --- 0. Environment Check ---
# The model below needs `Mamba` from mamba_ssm. A failed import is only reported
# here (not re-raised), so later code that references `Mamba` will fail with a
# NameError if the library is missing.
try:
    from mamba_ssm import Mamba
    print("Mamba-SSM library detected successfully.")
except ImportError:
    print("CRITICAL ERROR: 'mamba_ssm' not found.")
    print("Please run the installation script provided earlier (git clone + pip install .).")

# --- 1. Configuration (Upscaled for ~35M Parameters) ---
NUM_CLIENTS = 5          # Number of simulated federated clients (data partitions)
ROUNDS = 5               # Number of federated communication rounds
LOCAL_EPOCHS = 2         # Epochs each client trains per round
BATCH_SIZE = 16          # Reduced to 16 to fit the larger model in memory safely
LEARNING_RATE = 5e-5     # Lower LR is safer for larger models
MAX_SEQ_LENGTH = 128     # Token length every text is padded / truncated to
EMBEDDING_DIM = 640      # Token embedding width, also used as the Mamba d_model
VOCAB_SIZE = 50257       # Not read by the visible code; the model is sized from len(tokenizer)
NUM_CONTEXT_FEATURES = 1 # Numeric side features concatenated to the sequence summary
LAMBDA_SCL = 0.5         # Weight of the contrastive term in the combined loss
SEED = 42                # Seed for the global NumPy / PyTorch random number generators

# Seed the global RNGs so that data shuffling, weight initialisation and
# DataLoader shuffling are repeatable across "Restart & Run All".
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# --- 2. The Novelty: Supervised Contrastive Loss (SupConLoss) ---
class SupConLoss(nn.Module):
    """
    Supervised Contrastive Learning Loss.
    Encourages the model to pull embeddings of the same class together
    and push embeddings of different classes apart in vector space.

    Args:
        temperature: Divisor applied to the pairwise cosine similarities;
            smaller values sharpen the softmax over the batch.
    """
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """
        Compute the loss for one batch.

        Args:
            features: 2-D tensor, one embedding per row (rows are L2-normalised here).
            labels: Tensor with one class label per row of `features`.

        Returns:
            Scalar tensor: mean over anchors of the negative average
            log-probability of their positives. Anchors without any positive
            in the batch contribute 0.
        """
        # Build every helper tensor on the device of the embeddings rather than
        # relying on a module-level `device` global.
        target_device = features.device

        features = F.normalize(features, dim=1)
        batch_size = features.shape[0]
        labels = labels.view(-1, 1).to(target_device)

        # Mask: 1 if labels match, 0 if not
        mask = torch.eq(labels, labels.T).float()

        # Similarity matrix (cosine similarity scaled by temperature)
        anchor_dot_contrast = torch.matmul(features, features.T) / self.temperature

        # Numerical stability: subtract the per-row max before exponentiating
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Mask out self-contrast: zeros on the diagonal, ones elsewhere
        logits_mask = 1.0 - torch.eye(batch_size, device=target_device, dtype=mask.dtype)
        mask = mask * logits_mask

        # Compute log-likelihood of each pair against all non-self pairs
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)

        # Average over positives; rows with no positives divide by 1 instead of 0
        mask_sum = mask.sum(1)
        mask_sum = torch.where(mask_sum == 0, torch.ones_like(mask_sum), mask_sum)

        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_sum

        return (-mean_log_prob_pos).mean()

# --- 3. The Novel Architecture: Fed-Mamba-SCL (Larger) ---
class FedMambaSCL(nn.Module):
    """
    Token-sequence classifier built from an embedding layer, one Mamba block and
    two heads that share the same pooled features:

    * ``projector``  -> 128-d vector used by the contrastive loss
    * ``classifier`` -> single sigmoid probability used for prediction

    Args:
        vocab_size: Number of rows in the token embedding table.
        num_context_features: Width of the numeric side-feature vector that is
            concatenated to the pooled sequence representation.
    """
    def __init__(self, vocab_size, num_context_features):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, EMBEDDING_DIM)

        # Backbone: Mamba (State Space Model)
        self.mamba = Mamba(
            d_model=EMBEDDING_DIM,
            d_state=64,   # Doubled from 32
            d_conv=4,
            expand=4
        )

        # Head 1: The Projector (Contrastive)
        # Scaled up linear layers to match embedding dim
        self.projector = nn.Sequential(
            nn.Linear(EMBEDDING_DIM + num_context_features, 256),
            nn.ReLU(),
            nn.Linear(256, 128) # Larger projection vector
        )

        # Head 2: The Classifier (Prediction)
        # Ends in Sigmoid, so its output is a probability (paired with BCELoss).
        self.classifier = nn.Sequential(
            nn.Linear(EMBEDDING_DIM + num_context_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids, context_features, attention_mask=None):
        """
        Args:
            input_ids: (batch, seq_len) integer token ids.
            context_features: (batch, num_context_features) float side features.
            attention_mask: Optional (batch, seq_len) tensor with 1 for real tokens
                and 0 for padding. When given, the sequence is summarised at the
                last real token of each row instead of at the final position.
                When omitted, the final position is used, which keeps existing
                two-argument calls unchanged.
                NOTE(review): inputs are padded to a fixed length upstream; if the
                padding is on the right, the final position is a pad token —
                consider passing the mask from the callers.

        Returns:
            (proj, pred): ``proj`` is (batch, 128); ``pred`` is (batch,) with
            sigmoid probabilities.
        """
        embedded_seq = self.embedding(input_ids)
        mamba_output = self.mamba(embedded_seq)

        if attention_mask is None:
            sequence_summary = mamba_output[:, -1, :]
        else:
            # Index of the last position whose mask is 1, for either padding side.
            # Rows whose mask is all zeros fall back to position 0.
            n_rows, seq_len = mamba_output.shape[0], mamba_output.shape[1]
            positions = torch.arange(seq_len, device=mamba_output.device)
            mask = attention_mask.to(mamba_output.device).long()
            last_real = (mask * positions).argmax(dim=1)
            row_idx = torch.arange(n_rows, device=mamba_output.device)
            sequence_summary = mamba_output[row_idx, last_real]

        combined_features = torch.cat((sequence_summary, context_features), dim=1)

        proj = self.projector(combined_features)
        pred = self.classifier(combined_features)

        return proj, pred.squeeze(-1)

In [None]:
# --- 4. Data Preparation ---
class SqlSsmDataset(Dataset):
    """
    Map-style dataset that pairs tokenizer output with numeric context features
    and labels.

    Each item is a dict holding one row of every entry in ``encodings`` plus the
    keys ``'context_features'`` and ``'labels'`` (both stored as float32 tensors).
    """

    def __init__(self, encodings, context_features, labels):
        self.encodings = encodings
        self.context_features = torch.tensor(context_features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        # The label tensor defines how many samples there are.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = values[idx]
        sample['context_features'] = self.context_features[idx]
        sample['labels'] = self.labels[idx]
        return sample

def get_client_datasets(filepath, num_clients, random_state=None):
    """
    Load the training CSV, shuffle it and split it into per-client datasets.

    The CSV must provide the columns ``text``, ``flow_duration`` and ``label``.
    A single StandardScaler is fitted on ``flow_duration`` of the *whole* file
    (i.e. on the pooled data of all clients) and written to
    ``global_scaler.pkl`` in the working directory so evaluation can reuse it.

    Args:
        filepath: Path of the training CSV.
        num_clients: Number of partitions to create. Partition sizes differ by
            at most one row.
        random_state: Optional seed forwarded to the shuffle. ``None`` keeps the
            previous behaviour of using the global NumPy random state.

    Returns:
        (client_data, tokenizer): a list with one SqlSsmDataset per client and
        the "distilgpt2" tokenizer (pad token set to EOS) used to encode them.

    Raises:
        KeyError: If a required column is missing from the CSV.
    """
    print(f"Loading data from '{filepath}' and partitioning for {num_clients} clients...")
    df = pd.read_csv(filepath)

    missing = {'text', 'flow_duration', 'label'} - set(df.columns)
    if missing:
        raise KeyError(f"'{filepath}' is missing required columns: {sorted(missing)}")

    df = df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    # Split row positions rather than the DataFrame itself: np.array_split on a
    # DataFrame goes through DataFrame.swapaxes, which recent pandas deprecates.
    index_chunks = np.array_split(np.arange(len(df)), num_clients)
    client_dfs = [df.iloc[chunk] for chunk in index_chunks]
    client_data = []

    global_scaler = StandardScaler()
    all_context = df['flow_duration'].values.reshape(-1, 1)
    global_scaler.fit(all_context)

    with open("global_scaler.pkl", "wb") as f:
        pickle.dump(global_scaler, f)

    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    # The tokenizer is given EOS as its padding token before fixed-length padding below.
    tokenizer.pad_token = tokenizer.eos_token

    for i, local_df in enumerate(client_dfs):
        texts = local_df['text'].astype(str).tolist()
        context = local_df['flow_duration'].values.reshape(-1, 1)
        labels = local_df['label'].values

        context_scaled = global_scaler.transform(context)
        encodings = tokenizer(texts, truncation=True, padding="max_length", max_length=MAX_SEQ_LENGTH, return_tensors='pt')

        dataset = SqlSsmDataset(encodings, context_scaled, labels)
        client_data.append(dataset)
        print(f"  - Client {i+1} prepared: {len(dataset)} samples.")

    return client_data, tokenizer

In [None]:
# --- 5. Client Training Logic ---
def client_update(client_model, train_loader, epochs):
    """
    Train one client's copy of the model on its local data.

    Each step minimises ``BCE(pred, labels) + LAMBDA_SCL * SupCon(proj, labels)``
    with a fresh AdamW optimizer (learning rate ``LEARNING_RATE``).

    Args:
        client_model: Model to train in place; it is switched to train mode.
        train_loader: Iterable of batches with the keys ``'input_ids'``,
            ``'context_features'`` and ``'labels'``.
        epochs: Number of passes over ``train_loader``.

    Returns:
        dict mapping parameter/buffer names to detached CPU copies of the trained
        weights. Copies are returned (instead of the live ``state_dict()``) so
        that the caller does not keep the client's GPU tensors alive after it
        deletes the model.
    """
    client_model.train()
    optimizer = AdamW(client_model.parameters(), lr=LEARNING_RATE)

    criterion_scl = SupConLoss(temperature=0.1)
    criterion_ce = nn.BCELoss()

    for _ in range(epochs):
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            context = batch['context_features'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()

            proj, pred = client_model(input_ids, context)

            loss_contrastive = criterion_scl(proj, labels)
            loss_classification = criterion_ce(pred, labels)

            total_loss = loss_classification + (LAMBDA_SCL * loss_contrastive)

            total_loss.backward()
            optimizer.step()

    # state_dict() tensors share storage with the parameters; copy them off the
    # training device so the model's memory can actually be released.
    return {
        name: tensor.detach().to("cpu", copy=True)
        for name, tensor in client_model.state_dict().items()
    }

# --- 6. Server Aggregation ---
def server_aggregate(global_model, client_weights):
    """
    Replace the global model's weights with the unweighted mean of the clients'.

    Args:
        global_model: Model whose state-dict keys define what gets averaged; it is
            updated in place.
        client_weights: Sequence of state dicts, one per client, each containing
            every key of the global model.

    Returns:
        The same ``global_model`` instance, now holding the averaged weights.
    """
    averaged_state = global_model.state_dict()
    for entry_name in averaged_state:
        # Cast to float so every entry can be averaged.
        per_client = [weights[entry_name].float() for weights in client_weights]
        averaged_state[entry_name] = torch.stack(per_client, dim=0).mean(dim=0)
    global_model.load_state_dict(averaged_state)
    return global_model

# --- 7. Evaluation Utility ---
def evaluate_global(model, test_filepath):
    """
    Score the global model on a held-out CSV, print its accuracy and plot a ROC curve.

    The CSV is expected to have the columns ``text``, ``flow_duration`` and
    ``label``. ``flow_duration`` is scaled with the scaler stored in
    ``global_scaler.pkl`` and the text is encoded with the "distilgpt2" tokenizer.
    If the file does not exist a message is printed and nothing else happens.

    Args:
        model: Model returning ``(projection, probability)`` per batch; it is put
            into eval mode and left there.
        test_filepath: Path of the test CSV.

    Returns:
        None.
    """
    if not os.path.exists(test_filepath):
        print("  ! Test file not found, skipping evaluation.")
        return

    test_df = pd.read_csv(test_filepath)
    texts = test_df['text'].astype(str).tolist()

    # NOTE(review): pickle.load executes arbitrary code from the file; this is
    # only safe because the file is expected to be the one this notebook wrote.
    with open("global_scaler.pkl", "rb") as scaler_file:
        scaler = pickle.load(scaler_file)
    raw_context = test_df['flow_duration'].values.reshape(-1, 1)
    context_scaled = scaler.transform(raw_context)
    labels = test_df['label'].values

    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    tokenizer.pad_token = tokenizer.eos_token
    encodings = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=MAX_SEQ_LENGTH,
        return_tensors='pt',
    )

    loader = DataLoader(SqlSsmDataset(encodings, context_scaled, labels), batch_size=64)

    model.eval()
    all_preds, all_labels, all_probs = [], [], []

    with torch.no_grad():
        for batch in loader:
            batch_ids = batch['input_ids'].to(device)
            batch_context = batch['context_features'].to(device)
            batch_labels = batch['labels'].to(device)
            _, batch_probs = model(batch_ids, batch_context)

            # Raw probabilities feed the ROC curve; thresholding at 0.5 gives the class.
            all_probs.extend(batch_probs.cpu().numpy())
            all_preds.extend((batch_probs > 0.5).long().cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    print(f"  > Global Model Accuracy on Test Set: {acc:.2%}")

    # --- Generate ROC Curve ---
    # Plotting is best-effort: any failure is reported and evaluation still completes.
    try:
        fpr, tpr, _ = roc_curve(all_labels, all_probs)
        roc_auc = auc(fpr, tpr)

        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
        # Diagonal reference line
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (Global Model)')
        plt.legend(loc="lower right")

        plt.show()  # Display the plot directly
        plt.close()
    except Exception as e:
        print(f"  > Could not generate ROC curve: {e}")



In [None]:
# --- 8. Main Orchestration Loop ---
def run_federated_scl(
    train_path='/kaggle/input/ssm-sqli/train_ssm.csv',
    test_path='/kaggle/input/ssm-sqli/test_ssm.csv',
    output_dir="./fed-mamba-scl-model-large",
):
    """
    Run the whole federated training loop and save the resulting global model.

    Per round, every client trains a deep copy of the global model on its own
    partition, the server averages the returned weights, and the averaged model
    is evaluated on the test set.

    Args:
        train_path: CSV used to build the per-client training datasets. If it
            does not exist, an error is printed and the function returns early.
        test_path: CSV passed to ``evaluate_global`` after every round.
        output_dir: Directory (created if needed) that receives
            ``model_state_dict.pth`` and the tokenizer files.

    Returns:
        None.
    """
    print("="*50)
    print("STARTING FEDERATED MAMBA SCL")
    print("="*50)

    if not os.path.exists(train_path):
        print(f"Error: training data not found at '{train_path}'.")
        return

    client_datasets, tokenizer = get_client_datasets(train_path, NUM_CLIENTS)

    # Initialize Larger Global Model (embedding table sized from the tokenizer)
    global_model = FedMambaSCL(vocab_size=len(tokenizer), num_context_features=NUM_CONTEXT_FEATURES).to(device)

    # --- Parameter Count Verification ---
    total_params = sum(p.numel() for p in global_model.parameters() if p.requires_grad)
    print("\n[INFO] Model Architecture Upscaled.")
    print(f"[INFO] Total Trainable Parameters: {total_params:,}")
    if total_params < 30000000:
        print("[WARNING] Parameter count is below 30M target.")
    else:
        print("[SUCCESS] Parameter count meets >30M target.")
    # ------------------------------------

    for round_num in range(ROUNDS):
        print(f"\n--- Round {round_num + 1}/{ROUNDS} ---")
        local_weights = []

        for i in range(NUM_CLIENTS):
            # Every client starts the round from the current global weights.
            client_model = copy.deepcopy(global_model)
            # drop_last avoids a trailing undersized batch.
            loader = DataLoader(client_datasets[i], batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

            print(f"  [Client {i+1}] Training locally...")
            w_local = client_update(client_model, loader, LOCAL_EPOCHS)
            local_weights.append(w_local)

            # Release the client copy before the next client is created.
            del client_model
            torch.cuda.empty_cache()

        print("  [Server] Aggregating weights (FedAvg)...")
        global_model = server_aggregate(global_model, local_weights)
        evaluate_global(global_model, test_path)

    os.makedirs(output_dir, exist_ok=True)
    torch.save(global_model.state_dict(), f"{output_dir}/model_state_dict.pth")
    tokenizer.save_pretrained(output_dir)
    print(f"\nFederated Learning Complete. Saved to {output_dir}")

# Entry point: start the full federated training run with the default paths
# when this code is executed as the main module.
if __name__ == "__main__":
    run_federated_scl()