## CELL 1: Environment Setup & Load Previous Results

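Load the configuration, class mappings, and train/val split indices produced by Notebook 1, and set up the device and random seeds. (The ViT model checkpoint from Notebook 2 is reloaded in Cell 2.)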


In [1]:
# ============================================================
# CELL 1: ENVIRONMENT SETUP & IMPORTS
# ============================================================

"""
This cell:
1. Imports all required libraries
2. Loads configuration from Notebook 1
3. Loads class mappings and train/val split indices (the ViT checkpoint is loaded in Cell 2)
4. Sets up device and random seeds
"""

import os
import sys
import json
import pickle
import random
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Vision
import torchvision
from torchvision import transforms
from PIL import Image

# Data & ML
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import csr_matrix

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# timm for ViT
import timm

_RULE = "=" * 80

print(_RULE)
print("NOTEBOOK 3: NEURAL STRUCTURED LEARNING & GRAPH CONSTRUCTION")
print(_RULE)
cuda_available = torch.cuda.is_available()
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {cuda_available}")
print(f"Device Count: {torch.cuda.device_count()}")
if cuda_available:
    print(f"Active GPU: {torch.cuda.get_device_name(0)}")
print(_RULE)

# Primary compute device: first GPU when CUDA is present, otherwise CPU.
device = torch.device('cuda:0') if cuda_available else torch.device('cpu')
print(f"\n✓ Using device: {device}")

# Seed every RNG this notebook touches (torch, numpy, stdlib random, then CUDA).
RANDOM_SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(RANDOM_SEED)
if cuda_available:
    torch.cuda.manual_seed(RANDOM_SEED)
print(f"✓ Random seeds set to {RANDOM_SEED}")

# ============================================================
# LOAD CONFIGURATION FROM NOTEBOOK 1
# ============================================================

base_dir = Path('./novelty_files')
splits_dir = base_dir / 'splits'
config_path = base_dir / 'configs' / 'notebook_01_config.json'

if not config_path.exists():
    raise FileNotFoundError(
        f"Configuration file not found: {config_path}\n"
        f"Please run Notebook 1 first to create the configuration."
    )

with config_path.open('r') as config_file:
    CONFIG = json.load(config_file)
print(f"\n✓ Loaded configuration from {config_path}")

# Class-name <-> index mappings. JSON object keys are always strings, so the
# idx_to_class keys are cast back to int.
dist_path = splits_dir / 'class_distribution.json'
with dist_path.open('r') as dist_file:
    dist_data = json.load(dist_file)
class_to_idx = dist_data['class_to_idx']
idx_to_class = {int(idx): name for idx, name in dist_data['idx_to_class'].items()}
print(f"✓ Loaded class mappings ({len(class_to_idx)} classes)")


def _load_split_indices(split_name):
    """Unpickle ``<base_dir>/splits/<split_name>_indices.pkl``.

    NOTE: pickle.load can execute arbitrary code; only load files this
    pipeline produced.
    """
    with (splits_dir / f'{split_name}_indices.pkl').open('rb') as split_file:
        return pickle.load(split_file)


train_indices = _load_split_indices('train')
val_indices = _load_split_indices('val')
print(f"✓ Loaded splits: {len(train_indices):,} train, {len(val_indices):,} val")

print("\n" + _RULE)
print("INITIALIZATION COMPLETE")
print(_RULE)

NOTEBOOK 3: NEURAL STRUCTURED LEARNING & GRAPH CONSTRUCTION
PyTorch Version: 2.9.1+cu128
CUDA Available: True
Device Count: 8
Active GPU: NVIDIA H200

✓ Using device: cuda:0
✓ Random seeds set to 42

✓ Loaded configuration from novelty_files/configs/notebook_01_config.json
✓ Loaded class mappings (8 classes)
✓ Loaded splits: 53,097 train, 11,379 val

INITIALIZATION COMPLETE


## CELL 2: Reload Dataset & ViT Model

Reload the dataset and ViT-Base model checkpoint.


In [2]:
# ============================================================
# CELL 2: RELOAD DATASET & VIT MODEL
# ============================================================

"""
Cell 2 responsibilities:
1. Rebuild the HMDB51 Fight sample list (train folder, then test folder)
2. Restore the ViT-Base weights saved by the Notebook 2 checkpoint
3. Put the model on the device in eval mode for feature extraction
"""

for _banner_line in ("\n" + "=" * 80, "RELOADING DATASET & MODEL", "=" * 80):
    print(_banner_line)

# ============================================================
# RELOAD DATASET SAMPLES
# ============================================================

class HMDB51FightDataset(Dataset):
    """
    Index of image files laid out as ``<root_dir>/<split>/<class_name>/*.{jpg,png}``.

    Only ``samples`` and ``__len__`` are provided (there is no ``__getitem__``);
    this class is used to rebuild the flat sample list, not to load images.

    Each entry of ``self.samples`` is a dict with keys ``'path'`` (str),
    ``'label'`` (int from ``class_to_idx``) and ``'class_name'``. Classes are
    visited in ``class_to_idx`` iteration order, and missing class folders are
    skipped. Within a class, all ``*.jpg`` files come before all ``*.png`` files,
    each group in ``Path.glob`` order.

    NOTE(review): glob order is filesystem-dependent and unsorted. The split
    indices loaded in Cell 1 address positions in this list, so they are only
    valid if the order matches the one used when the splits were written.
    """

    def __init__(self, root_dir: str, split: str, class_to_idx: Dict[str, int]):
        self.root_dir = Path(root_dir)
        self.split = split
        self.class_to_idx = class_to_idx
        self.samples = []

        split_root = self.root_dir / split
        for name, label in class_to_idx.items():
            folder = split_root / name
            if not folder.exists():
                continue
            for pattern in ('*.jpg', '*.png'):
                self.samples.extend(
                    {'path': str(file_path), 'label': label, 'class_name': name}
                    for file_path in folder.glob(pattern)
                )

    def __len__(self):
        return len(self.samples)

# Load dataset
dataset_path = CONFIG['dataset_path']
train_dataset_loader = HMDB51FightDataset(dataset_path, 'train', class_to_idx)
test_dataset_loader = HMDB51FightDataset(dataset_path, 'test', class_to_idx)

# Global sample order = train-folder samples followed by test-folder samples.
# The split indices loaded in Cell 1 are positions in this list.
all_samples = train_dataset_loader.samples + test_dataset_loader.samples

print(f"✓ Reloaded {len(all_samples):,} samples")

# Fail fast if the split indices cannot address this sample list (e.g. the
# dataset folder changed since the splits were written). This catches a list
# that is too short; it cannot detect a reordering.
max_split_index = max(max(train_indices, default=-1), max(val_indices, default=-1))
if max_split_index >= len(all_samples):
    raise IndexError(
        f"Split index {max_split_index} is out of range for {len(all_samples)} "
        f"reloaded samples. Re-run Notebook 1 to regenerate the splits."
    )

# ============================================================
# LOAD VIT MODEL FROM NOTEBOOK 2 CHECKPOINT
# ============================================================

vit_checkpoint_path = base_dir / 'checkpoints' / 'vit_baseline.pt'

if not vit_checkpoint_path.exists():
    raise FileNotFoundError(
        f"ViT checkpoint not found: {vit_checkpoint_path}\n"
        f"Please run Notebook 2 first to train the ViT baseline."
    )

print(f"\nLoading ViT model from: {vit_checkpoint_path}")

# Create ViT-Base model. The head size comes from the class mapping rather than
# a hard-coded constant, so it cannot drift from the dataset.
model = timm.create_model(
    'vit_base_patch16_224', pretrained=False, num_classes=len(class_to_idx)
)

# Load checkpoint onto CPU first; the model is moved to `device` below.
checkpoint = torch.load(vit_checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

val_acc = checkpoint['val_accuracy']
print(f"✓ Loaded ViT checkpoint (val_acc={val_acc:.2f}%)")

# Move to device
model = model.to(device)
model.eval()  # Set to evaluation mode

print(f"✓ Model moved to {device} and set to eval mode")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"✓ Model parameters: {total_params/1e6:.1f}M")

print("\n" + "="*80)


RELOADING DATASET & MODEL
✓ Reloaded 75,855 samples

Loading ViT model from: novelty_files/checkpoints/vit_baseline.pt
✓ Loaded ViT checkpoint (val_acc=97.39%)
✓ Model moved to cuda:0 and set to eval mode
✓ Model parameters: 85.8M



## CELL 3: Extract Features from ViT Backbone (RESUME-SAFE)

Extract feature embeddings from the ViT backbone (before classification head).  
**CRITICAL**: Checks if features already exist before extracting.


In [3]:
# ============================================================
# CELL 3: EXTRACT FEATURES (RESUME-SAFE)
# ============================================================

"""
This cell extracts feature embeddings from the ViT backbone.

Feature extraction:
- Takes images as input
- Runs the backbone through timm's public `forward_features` method
  (patch embedding + transformer blocks + final norm)
- Extracts [CLS] token representation (768-dim for ViT-Base)
- Saves features to disk for graph construction

RESUME-SAFE: Checks if features already exist before extraction, and
validates that cached features still match the current train/val splits.
"""

print("\n" + "="*80)
print("FEATURE EXTRACTION FROM VIT BACKBONE")
print("="*80)

# ============================================================
# FEATURE-EXTRACTION HELPERS
# (defined unconditionally so they are also available on resume)
# ============================================================

class FeatureExtractor(nn.Module):
    """
    Wrapper to extract features from ViT backbone.

    Returns token 0 ([CLS]) of the final-normed token sequence, i.e. the
    representation before the classification head.
    For ViT-Base, this is a 768-dimensional vector.
    """
    def __init__(self, vit_model):
        super().__init__()
        self.vit = vit_model

    def forward(self, x):
        # forward_features is timm's public backbone entry point. Unlike
        # calling patch_embed / the private _pos_embed / blocks / norm by hand,
        # it also runs any patch_drop / norm_pre stages the model defines.
        tokens = self.vit.forward_features(x)
        return tokens[:, 0]


class HMDB51Dataset(Dataset):
    """Yields ``(image, label)`` for ``samples[i]`` for each ``i`` in ``indices``, in that order."""
    def __init__(self, samples, indices, transform=None):
        self.samples = [samples[i] for i in indices]
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # `with` closes the file handle deterministically
        with Image.open(sample['path']) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, sample['label']


def extract_features(extractor, loader, num_samples, log_every):
    """
    Run ``extractor`` over every batch of ``loader`` (in loader order) under
    ``torch.no_grad`` and return the concatenated outputs as a CPU tensor.

    Prints progress every ``log_every`` batches, counting samples actually seen.
    """
    chunks = []
    seen = 0
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(loader):
            images = images.to(device, non_blocking=True)
            chunks.append(extractor(images).cpu())
            seen += images.shape[0]

            if (batch_idx + 1) % log_every == 0:
                print(f"  Processed {seen:,} / {num_samples:,} samples")

    return torch.cat(chunks, dim=0)

# ============================================================
# CHECK IF FEATURES ALREADY EXIST
# ============================================================

features_dir = base_dir / 'features'
train_features_path = features_dir / 'train_features.pt'
val_features_path = features_dir / 'val_features.pt'

if train_features_path.exists() and val_features_path.exists():
    print("✓ Found existing feature files, loading instead of extracting...")
    
    train_features = torch.load(train_features_path, map_location='cpu')
    val_features = torch.load(val_features_path, map_location='cpu')
    
    print(f"✓ Loaded train features: {train_features.shape}")
    print(f"✓ Loaded val features: {val_features.shape}")
    
    # Feature row i corresponds to split index i. A cache written for a
    # different split would silently misalign every downstream graph node.
    if len(train_features) != len(train_indices) or len(val_features) != len(val_indices):
        raise ValueError(
            f"Cached features ({len(train_features):,} train / {len(val_features):,} val) "
            f"do not match the current splits ({len(train_indices):,} / {len(val_indices):,}). "
            f"Delete {features_dir} and re-run this cell."
        )
    
else:
    print("No existing features found. Extracting features...")
    print("This will take approximately 10-15 minutes.")
    
    feature_extractor = FeatureExtractor(model).to(device)
    feature_extractor.eval()
    
    print("✓ Feature extractor created")
    
    # ============================================================
    # CREATE DATASETS / DATALOADERS FOR FEATURE EXTRACTION
    # ============================================================
    
    # Define transform (same as validation)
    # NOTE(review): must match Notebook 2's eval preprocessing (resize and
    # mean/std), otherwise the features are off-distribution; confirm.
    vit_mean = [0.485, 0.456, 0.406]
    vit_std = [0.229, 0.224, 0.225]
    
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=vit_mean, std=vit_std),
    ])
    
    train_dataset = HMDB51Dataset(all_samples, train_indices, transform=transform)
    val_dataset = HMDB51Dataset(all_samples, val_indices, transform=transform)
    
    batch_size = 128  # Larger batch for feature extraction
    
    # shuffle=False keeps row order aligned with the split indices
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    print(f"✓ Created dataloaders (batch_size={batch_size})")
    
    # ============================================================
    # EXTRACT TRAIN / VAL FEATURES
    # ============================================================
    
    print(f"\nExtracting train features ({len(train_dataset):,} samples)...")
    train_features = extract_features(
        feature_extractor, train_loader, len(train_dataset), log_every=100
    )
    print(f"✓ Train features extracted: {train_features.shape}")
    
    print(f"\nExtracting val features ({len(val_dataset):,} samples)...")
    val_features = extract_features(
        feature_extractor, val_loader, len(val_dataset), log_every=50
    )
    print(f"✓ Val features extracted: {val_features.shape}")
    
    # ============================================================
    # SAVE FEATURES
    # ============================================================
    
    print(f"\nSaving features...")
    
    # torch.save does not create parent directories
    features_dir.mkdir(parents=True, exist_ok=True)
    
    torch.save(train_features, train_features_path)
    print(f"✓ Saved train features to: {train_features_path}")
    
    torch.save(val_features, val_features_path)
    print(f"✓ Saved val features to: {val_features_path}")

# ============================================================
# VERIFY FEATURE SHAPES
# ============================================================

print("\n" + "-"*80)
print("FEATURE STATISTICS")
print("-"*80)
print(f"Train features shape: {train_features.shape}")
print(f"Val features shape:   {val_features.shape}")
print(f"Feature dimension:    {train_features.shape[1]} (ViT-Base embedding)")

# Check feature statistics
print(f"\nTrain features statistics:")
print(f"  Mean: {train_features.mean():.4f}")
print(f"  Std:  {train_features.std():.4f}")
print(f"  Min:  {train_features.min():.4f}")
print(f"  Max:  {train_features.max():.4f}")

print("\n" + "="*80)
print("✓ FEATURE EXTRACTION COMPLETE")
print("="*80)


FEATURE EXTRACTION FROM VIT BACKBONE
No existing features found. Extracting features...
This will take approximately 10-15 minutes.
✓ Feature extractor created
✓ Created dataloaders (batch_size=128)

Extracting train features (53,097 samples)...
  Processed 12,800 / 53,097 samples
  Processed 25,600 / 53,097 samples
  Processed 38,400 / 53,097 samples
  Processed 51,200 / 53,097 samples
✓ Train features extracted: torch.Size([53097, 768])

Extracting val features (11,379 samples)...
  Processed 6,400 / 11,379 samples
✓ Val features extracted: torch.Size([11379, 768])

Saving features...
✓ Saved train features to: novelty_files/features/train_features.pt
✓ Saved val features to: novelty_files/features/val_features.pt

--------------------------------------------------------------------------------
FEATURE STATISTICS
--------------------------------------------------------------------------------
Train features shape: torch.Size([53097, 768])
Val features shape:   torch.Size([11379, 768

## CELL 4: Build k-NN Graph (RESUME-SAFE)

Construct k-nearest neighbors graph for Neural Structured Learning.  
**k=5** neighbors per sample.


In [4]:
# ============================================================
# CELL 4: BUILD K-NN GRAPH (RESUME-SAFE)
# ============================================================

"""
This cell constructs a k-nearest neighbors graph for NSL regularization.

Graph construction:
- For each training sample, find k=5 nearest neighbors in feature space
  (cosine distance), excluding the sample itself
- Create edge list connecting each node to its neighbors
- Save graph structure for use in training

RESUME-SAFE: Checks if graph already exists before construction, and
verifies that a cached graph matches the current training features.
"""

print("\n" + "="*80)
print("K-NN GRAPH CONSTRUCTION")
print("="*80)

# ============================================================
# CHECK IF GRAPH ALREADY EXISTS
# ============================================================

graphs_dir = base_dir / 'graphs'
graph_path = graphs_dir / 'train_knn_graph.pkl'
neighbors_path = graphs_dir / 'train_neighbors.npy'

if graph_path.exists() and neighbors_path.exists():
    print("✓ Found existing graph files, loading instead of constructing...")
    
    # NOTE: pickle.load can execute arbitrary code; only load graph files
    # that this pipeline wrote.
    with open(graph_path, 'rb') as f:
        graph_data = pickle.load(f)
    
    neighbors_indices = np.load(neighbors_path)
    
    print(f"✓ Loaded graph data:")
    print(f"  Num nodes: {graph_data['num_nodes']}")
    print(f"  Num edges: {graph_data['num_edges']}")
    print(f"  k neighbors: {graph_data['k']}")
    print(f"  Neighbors shape: {neighbors_indices.shape}")
    
    # A cached graph is only valid for the feature matrix it was built from.
    num_train = len(train_features)
    if graph_data['num_nodes'] != num_train or len(neighbors_indices) != num_train:
        raise ValueError(
            f"Cached graph has {graph_data['num_nodes']:,} nodes / "
            f"{len(neighbors_indices):,} neighbor rows, but there are {num_train:,} "
            f"training features. Delete {graphs_dir} and re-run this cell."
        )
    
else:
    print("No existing graph found. Constructing k-NN graph...")
    print("This will take approximately 5-10 minutes.")
    
    # ============================================================
    # GRAPH CONSTRUCTION PARAMETERS
    # ============================================================
    
    k = 5  # Number of nearest neighbors
    print(f"\nGraph parameters:")
    print(f"  k (neighbors per node): {k}")
    print(f"  Number of nodes: {len(train_features):,}")
    
    # ============================================================
    # BUILD K-NN GRAPH USING SKLEARN
    # ============================================================
    
    print(f"\nFitting k-NN model on train features...")
    
    # Cosine distance is scale-invariant, so the features do not need to be
    # L2-normalized first.
    train_np = train_features.numpy()
    nbrs = NearestNeighbors(
        n_neighbors=k,
        algorithm='auto',
        metric='cosine',
        n_jobs=-1  # Use all CPU cores
    )
    nbrs.fit(train_np)
    
    print(f"✓ k-NN model fitted")
    
    # ============================================================
    # FIND NEIGHBORS FOR ALL TRAINING SAMPLES
    # ============================================================
    
    print(f"\nFinding {k} nearest neighbors for each sample...")
    
    # Calling kneighbors() with no query makes sklearn exclude each fitted
    # point from its own neighbor list by index. The previous approach (query
    # k+1 and drop column 0) breaks when features contain exact duplicates,
    # e.g. repeated video frames: a duplicate can tie at distance 0 and sort
    # ahead of the point itself, leaving a self-loop in the graph.
    neighbors_distances, neighbors_indices = nbrs.kneighbors()  # each (num_samples, k)
    
    print(f"✓ Neighbors found")
    print(f"  Neighbors shape: {neighbors_indices.shape}")
    print(f"  Distances shape: {neighbors_distances.shape}")
    
    # ============================================================
    # CREATE EDGE LIST
    # ============================================================
    
    print(f"\nCreating edge list...")
    
    num_nodes = len(train_features)
    
    # Directed edges (source, neighbor, cosine_distance), ordered by source node
    # and then by neighbor rank. Built vectorized; entries are plain Python
    # int/float values.
    edge_sources = np.repeat(np.arange(num_nodes), k)
    edge_list = list(zip(
        edge_sources.tolist(),
        neighbors_indices.ravel().tolist(),
        neighbors_distances.ravel().tolist(),
    ))
    
    num_edges = len(edge_list)
    
    print(f"✓ Edge list created")
    print(f"  Number of edges: {num_edges:,}")
    print(f"  Average degree: {num_edges / num_nodes:.1f}")
    
    # ============================================================
    # SAVE GRAPH DATA
    # ============================================================
    
    print(f"\nSaving graph data...")
    
    graph_data = {
        'num_nodes': num_nodes,
        'num_edges': num_edges,
        'k': k,
        'edge_list': edge_list,
        'metric': 'cosine',
    }
    
    # open()/np.save do not create parent directories
    graphs_dir.mkdir(parents=True, exist_ok=True)
    
    with open(graph_path, 'wb') as f:
        pickle.dump(graph_data, f)
    print(f"✓ Saved graph data to: {graph_path}")
    
    np.save(neighbors_path, neighbors_indices)
    print(f"✓ Saved neighbor indices to: {neighbors_path}")

# ============================================================
# VERIFY GRAPH STRUCTURE
# ============================================================

print("\n" + "-"*80)
print("GRAPH STATISTICS")
print("-"*80)
print(f"Total nodes:  {graph_data['num_nodes']:,}")
print(f"Total edges:  {graph_data['num_edges']:,}")
print(f"k neighbors:  {graph_data['k']}")
print(f"Metric:       {graph_data['metric']}")

# Sample a few nodes and show their neighbors
print(f"\nSample neighbors (first 5 nodes):")
for i in range(min(5, len(neighbors_indices))):
    print(f"  Node {i}: neighbors = {neighbors_indices[i].tolist()}")

print("\n" + "="*80)
print("✓ GRAPH CONSTRUCTION COMPLETE")
print("="*80)


K-NN GRAPH CONSTRUCTION
No existing graph found. Constructing k-NN graph...
This will take approximately 5-10 minutes.

Graph parameters:
  k (neighbors per node): 5
  Number of nodes: 53,097

Fitting k-NN model on train features...
✓ k-NN model fitted

Finding 5 nearest neighbors for each sample...
✓ Neighbors found
  Neighbors shape: (53097, 5)
  Distances shape: (53097, 5)

Creating edge list...
✓ Edge list created
  Number of edges: 265,485
  Average degree: 5.0

Saving graph data...
✓ Saved graph data to: novelty_files/graphs/train_knn_graph.pkl
✓ Saved neighbor indices to: novelty_files/graphs/train_neighbors.npy

--------------------------------------------------------------------------------
GRAPH STATISTICS
--------------------------------------------------------------------------------
Total nodes:  53,097
Total edges:  265,485
k neighbors:  5
Metric:       cosine

Sample neighbors (first 5 nodes):
  Node 0: neighbors = [8124, 18642, 51242, 23757, 42690]
  Node 1: neighbors 

## CELL 5: Implement NSL Loss Functions

Implement Virtual Adversarial Training (VAT) and L2 Neighbor Regularization losses.


In [5]:
# ============================================================
# CELL 5: NSL LOSS FUNCTIONS
# ============================================================

"""
Neural Structured Learning (NSL) loss components defined in this cell:

1. Virtual Adversarial Training (VAT) loss
   - Perturbs inputs along an adversarial direction found by power iteration
   - Penalizes the KL divergence between clean and perturbed predictions

2. L2 neighbor regularization
   - Penalizes squared differences between predictions on a sample and its neighbor
   - Neighbors are intended to come from the k-NN graph built in Cell 4

3. Combined NSL loss
   - Total loss = λ_ce * CE_loss + λ_vat * VAT_loss + λ_neighbor * L2_loss
"""

_cell5_rule = "=" * 80
print("\n" + _cell5_rule)
print("NSL LOSS FUNCTIONS")
print(_cell5_rule)

# ============================================================
# VIRTUAL ADVERSARIAL TRAINING (VAT) LOSS
# ============================================================

def _unit_normalize_per_sample(v):
    """
    Rescale each sample of ``v`` (dim 0 = batch) to unit L2 norm over all
    remaining dims. All-zero samples are returned as zeros.
    """
    flat = v.flatten(start_dim=1)
    tiny = torch.finfo(flat.dtype).tiny
    # Pre-scale by the per-sample max |v| so that squaring very small gradient
    # entries inside the norm cannot underflow to 0 in float32.
    flat = flat / flat.abs().amax(dim=1, keepdim=True).clamp_min(tiny)
    flat = flat / flat.norm(p=2, dim=1, keepdim=True).clamp_min(tiny)
    return flat.reshape(v.shape)


def virtual_adversarial_loss(model, x, eps=8.0/255.0, xi=1e-6, num_iters=1):
    """
    Compute Virtual Adversarial Training (VAT) loss.

    VAT encourages the model to produce consistent predictions when a small
    adversarial perturbation is added to the input. ``num_iters`` power-iteration
    steps, starting from a random direction, estimate the adversarial direction
    ``d``. The returned loss is KL(p(x) || p(x + eps * d)), where p(x) is
    computed without gradient and acts as a fixed target.

    Args:
        model: Neural network model returning logits of shape (batch, classes).
        x: Input batch of shape (batch, ...), any rank >= 2,
            e.g. (batch_size, 3, 224, 224).
        eps: Per-sample L2 radius of the final perturbation, measured over all
            non-batch dims in the units of ``x``.
            NOTE(review): for 3x224x224 inputs the default spreads to roughly
            8e-5 per element (RMS). If ``x`` is mean/std-normalized, this is not
            an 8/255 per-pixel budget; confirm the intended scale.
        xi: Finite-difference step for the power iteration.
            NOTE(review): for 3x224x224 inputs each element of ``xi * d`` is about
            2.6e-9 at the default 1e-6. That is below float32 resolution for
            values of order 1, so the estimated direction may be mostly noise.
            Consider a larger xi and verify empirically.
        num_iters: Number of power iterations for adversarial direction.

    Returns:
        vat_loss: VAT regularization loss (scalar tensor). It is differentiable
        w.r.t. the model parameters through the final perturbed forward pass.
    """
    # Clean prediction as log-probabilities: a fixed target with no gradient.
    with torch.no_grad():
        log_p_clean = F.log_softmax(model(x), dim=1)

    # Random unit direction per sample
    d = _unit_normalize_per_sample(torch.randn_like(x))

    # Power iteration: d <- normalize(grad_d KL(p(x) || p(x + xi * d)))
    for _ in range(num_iters):
        d.requires_grad_(True)
        log_p_probe = F.log_softmax(model(x + xi * d), dim=1)

        # log_softmax + log_target keeps the KL finite when the softmax
        # saturates. softmax(...).log() gives -inf there, and the KL (and
        # hence d) becomes NaN.
        probe_kl = F.kl_div(
            log_p_probe,
            log_p_clean,
            reduction='batchmean',
            log_target=True
        )

        # Gradient w.r.t. the perturbation only; parameter .grad is untouched.
        grad = torch.autograd.grad(probe_kl, d)[0]
        d = _unit_normalize_per_sample(grad.detach())

    # Final VAT loss at the full perturbation radius
    log_p_adv = F.log_softmax(model(x + eps * d), dim=1)

    vat_loss = F.kl_div(
        log_p_adv,
        log_p_clean,
        reduction='batchmean',
        log_target=True
    )

    return vat_loss


# ============================================================
# L2 NEIGHBOR REGULARIZATION LOSS
# ============================================================

def l2_neighbor_loss(model, x, neighbor_x, temperature=1.0):
    """
    Mean squared difference between the model's class distributions on a batch
    and on the matching batch of graph neighbors.

    Each side is ``softmax(model(inputs) / temperature)`` over dim 1. Pairs are
    matched row by row, and ``F.mse_loss`` with ``reduction='mean'`` averages over
    every (sample, class) entry.

    Args:
        model: Neural network model returning logits (batch, classes).
        x: Input batch, e.g. (batch_size, 3, 224, 224).
        neighbor_x: Neighbor batch with the same shape as ``x``.
        temperature: Logits are divided by this before the softmax.

    Returns:
        Scalar tensor; gradients flow through both forward passes.
    """
    def _class_probs(batch):
        return F.softmax(model(batch) / temperature, dim=1)

    # x is evaluated before neighbor_x, as in the original ordering
    probs_self, probs_neighbor = _class_probs(x), _class_probs(neighbor_x)
    return F.mse_loss(probs_self, probs_neighbor, reduction='mean')


# ============================================================
# COMBINED NSL LOSS
# ============================================================

def nsl_loss(model, x, labels, neighbor_x=None, 
             lambda_ce=1.0, lambda_vat=0.1, lambda_neighbor=0.1):
    """
    Weighted sum of classification, virtual-adversarial and neighbor losses.

        total = lambda_ce * CE + lambda_vat * VAT + lambda_neighbor * L2_neighbor

    The VAT term is always computed. The neighbor term is a zero tensor on
    ``x.device`` when ``neighbor_x`` is None.

    Args:
        model: Neural network model returning logits.
        x: Input batch.
        labels: Integer class targets for ``x``.
        neighbor_x: Optional neighbor batch for the L2 neighbor term.
        lambda_ce: Weight on the cross-entropy term.
        lambda_vat: Weight on the VAT term.
        lambda_neighbor: Weight on the neighbor term.

    Returns:
        (total_loss, loss_dict): total_loss is a differentiable scalar tensor.
        loss_dict holds Python floats under 'total', 'ce', 'vat' and 'neighbor'.
    """
    use_neighbors = neighbor_x is not None

    logits = model(x)
    ce_term = F.cross_entropy(logits, labels)
    vat_term = virtual_adversarial_loss(model, x)
    if use_neighbors:
        neighbor_term = l2_neighbor_loss(model, x, neighbor_x)
    else:
        neighbor_term = torch.tensor(0.0, device=x.device)

    weighted_ce = lambda_ce * ce_term
    weighted_vat = lambda_vat * vat_term
    weighted_neighbor = lambda_neighbor * neighbor_term
    total_loss = weighted_ce + weighted_vat + weighted_neighbor

    loss_dict = dict(
        total=total_loss.item(),
        ce=ce_term.item(),
        vat=vat_term.item(),
        neighbor=neighbor_term.item() if use_neighbors else 0.0,
    )
    return total_loss, loss_dict


_cell5_summary = (
    "✓ NSL loss functions defined:",
    "  • virtual_adversarial_loss() - VAT regularization",
    "  • l2_neighbor_loss() - Neighbor consistency",
    "  • nsl_loss() - Combined NSL loss",
    "\n" + "=" * 80,
    "✓ NSL COMPONENTS READY",
    "=" * 80,
)
for _summary_line in _cell5_summary:
    print(_summary_line)


NSL LOSS FUNCTIONS
✓ NSL loss functions defined:
  • virtual_adversarial_loss() - VAT regularization
  • l2_neighbor_loss() - Neighbor consistency
  • nsl_loss() - Combined NSL loss

✓ NSL COMPONENTS READY


## CELL 6: Notebook 3 Summary & Completion

Summary of NSL components and next steps.


In [6]:
# ============================================================
# CELL 6: NOTEBOOK 3 COMPLETION SUMMARY
# ============================================================

print("\n" + "="*80)
print("NOTEBOOK 3: NSL & GRAPHS - COMPLETION SUMMARY")
print("="*80)

# Verify all required files exist
required_files = {
    'Train Features': base_dir / 'features' / 'train_features.pt',
    'Val Features': base_dir / 'features' / 'val_features.pt',
    'k-NN Graph': base_dir / 'graphs' / 'train_knn_graph.pkl',
    'Neighbor Indices': base_dir / 'graphs' / 'train_neighbors.npy',
}

print("\nVerifying required files:")
all_files_exist = True

for file_desc, file_path in required_files.items():
    exists = file_path.exists()
    status = "✓" if exists else "✗"
    print(f"  {status} {file_desc}: {file_path}")
    if not exists:
        all_files_exist = False

if all_files_exist:
    print("\n✓ All required files successfully created!")
else:
    print("\n✗ ERROR: Some required files are missing!")

# Summary of outputs
print("\n" + "-"*80)
print("NSL COMPONENTS SUMMARY")
print("-"*80)

print(f"\n✓ Feature Extraction:")
print(f"  • Train features: {train_features.shape}")
print(f"  • Val features: {val_features.shape}")
print(f"  • Feature dimension: {train_features.shape[1]} (ViT-Base)")

print(f"\n✓ Graph Construction:")
print(f"  • Nodes: {graph_data['num_nodes']:,} (training samples)")
print(f"  • Edges: {graph_data['num_edges']:,}")
print(f"  • k-NN: {graph_data['k']} neighbors per node")
print(f"  • Metric: {graph_data['metric']}")

print(f"\n✓ NSL Loss Functions:")
print(f"  • virtual_adversarial_loss() - Implemented")
print(f"  • l2_neighbor_loss() - Implemented")
print(f"  • nsl_loss() - Combined loss function")

# Expected impact
print("\n" + "-"*80)
print("EXPECTED IMPACT")
print("-"*80)
print("NSL regularization typically provides:")
print("  • +2-3% accuracy improvement")
print("  • Better generalization on unseen data")
print("  • More robust predictions")
print("  • Smoother decision boundaries")

# Next steps
print("\n" + "-"*80)
print("NEXT STEPS")
print("-"*80)
print("✓ Notebook 3 COMPLETE: NSL & Graph Construction")
print("\nNSL components are ready for use in Notebook 4 training!")
print("\nNote: Notebook 4 (DDP Training) is already running.")
print("      These NSL components will be used in Notebook 6")
print("      for adversarial fine-tuning.")

print("\n" + "="*80)
print("NOTEBOOK 3: ✓ SUCCESSFULLY COMPLETED")
print("="*80)

# Save completion status
completion_status = {
    'notebook': 'Notebook 3: NSL & Graphs',
    # Record the actual file check instead of unconditionally claiming success,
    # so downstream notebooks reading this flag are not misled.
    'completed': all_files_exist,
    'timestamp': pd.Timestamp.now().isoformat(),
    'outputs': {
        'train_features_shape': list(train_features.shape),
        'val_features_shape': list(val_features.shape),
        'graph_nodes': int(graph_data['num_nodes']),
        'graph_edges': int(graph_data['num_edges']),
        # int() so a numpy integer (e.g. from a reloaded pickle) stays JSON-serializable
        'k_neighbors': int(graph_data['k']),
    },
    'nsl_functions': [
        'virtual_adversarial_loss',
        'l2_neighbor_loss',
        'nsl_loss'
    ]
}

completion_path = base_dir / 'logs' / 'notebook_03_completion.json'
# open(..., 'w') does not create missing parent directories
completion_path.parent.mkdir(parents=True, exist_ok=True)
with open(completion_path, 'w') as f:
    json.dump(completion_status, f, indent=2)

print(f"\n✓ Completion status saved to: {completion_path}")


NOTEBOOK 3: NSL & GRAPHS - COMPLETION SUMMARY

Verifying required files:
  ✓ Train Features: novelty_files/features/train_features.pt
  ✓ Val Features: novelty_files/features/val_features.pt
  ✓ k-NN Graph: novelty_files/graphs/train_knn_graph.pkl
  ✓ Neighbor Indices: novelty_files/graphs/train_neighbors.npy

✓ All required files successfully created!

--------------------------------------------------------------------------------
NSL COMPONENTS SUMMARY
--------------------------------------------------------------------------------

✓ Feature Extraction:
  • Train features: torch.Size([53097, 768])
  • Val features: torch.Size([11379, 768])
  • Feature dimension: 768 (ViT-Base)

✓ Graph Construction:
  • Nodes: 53,097 (training samples)
  • Edges: 265,485
  • k-NN: 5 neighbors per node
  • Metric: cosine

✓ NSL Loss Functions:
  • virtual_adversarial_loss() - Implemented
  • l2_neighbor_loss() - Implemented
  • nsl_loss() - Combined loss function

----------------------------------