In [1]:
"""02_batch_tensors_and_benchmark.ipynb

## Rationale and Approach

This notebook optimizes data loading by materializing CustomDataset batches into persistent tensor files, then benchmarking the performance improvement.

Key objectives:
1. Export CustomDataset batches to .pt files with metadata preservation
2. Create TensorDataset from saved tensors for faster loading
3. Benchmark performance difference between on-the-fly vs pre-materialized loading
4. Demonstrate training loop compatibility with the new dataset format

Constraints:
- Do not modify meta.ipynb or Double_input_transformer.py; import and reuse
- Maintain full compatibility with existing data structures
- Self-contained export and benchmark tool
"""

import sys
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Add parent directory to path for imports
sys.path.append(str(Path("..").resolve()))

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import time
import json
from typing import Dict, List, Tuple, Optional

# Set up paths
ROOT = Path("..").resolve()
DATA_DIR = ROOT / "data"
TENSOR_DIR = ROOT / "notebooks_sandbox" / "tensor_batches"
TENSOR_DIR.mkdir(parents=True, exist_ok=True)

print(f"=== Batch Tensor Export and Benchmark Setup ===")
print(f"Project root: {ROOT}")
print(f"Data directory: {DATA_DIR}")
print(f"Tensor output directory: {TENSOR_DIR}")
print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")

=== Batch Tensor Export and Benchmark Setup ===
Project root: /home/aymen/Documents/GitHub/Federated-Continual-learning-/New
Data directory: /home/aymen/Documents/GitHub/Federated-Continual-learning-/New/data
Tensor output directory: /home/aymen/Documents/GitHub/Federated-Continual-learning-/New/notebooks_sandbox/tensor_batches
PyTorch version: 2.7.1+cu128
NumPy version: 1.26.4
Pandas version: 2.3.3


In [2]:
# Cell 2: Import CustomDataset and check data availability
print("=== Importing CustomDataset and Checking Data ===")

# Import the CustomDataset from Double_input_transformer.py
exec(open(ROOT / "Double_input_transformer.py").read())

print("CustomDataset imported successfully")

# Check if the required CSV file exists
merged_zoo_path = DATA_DIR / "Merged zoo.csv"
if not merged_zoo_path.exists():
    print(f"‚ùå Merged zoo.csv not found at: {merged_zoo_path}")
    print("Please ensure the data file is available before proceeding.")
else:
    print(f"‚úÖ Merged zoo.csv found at: {merged_zoo_path}")

# Examine the CustomDataset structure
print("\nCustomDataset parameters:")
print("- L_exp: List of experiments (label1, label2, epoch_key, activ_key)")
print("- batch_size: Default 20")
print("- batch_limit: Default 5")
print("- df_path: Path to Merged zoo.csv (IMPORTANT: Must use absolute path)")

# Let's examine the batchify function
print("\nBatchify function:")
print("- Creates batches from L_exp with given batch_size and limit")

# Create a sample L_exp to understand the format
sample_L_exp = [
    (0, 1, 0, 0),  # (label1, label2, epoch_key, activ_key)
    (2, 3, 1, 1),
    (4, 5, 2, 2),
    (6, 7, 3, 3),
    (8, 9, 4, 4),
    (0, 2, 5, 5),
    (1, 3, 0, 0),
    (4, 6, 1, 1),
    (5, 7, 2, 2),
    (8, 0, 3, 3),
]

print(f"\nSample L_exp format: {sample_L_exp[0]}")

# IMPORTANT: Always use absolute path to avoid path resolution issues
print(f"\nüîß Fixing path issue - using absolute path: {merged_zoo_path}")

try:
    # Test CustomDataset with correct absolute path
    sample_dataset = CustomDataset(
        L_exp=sample_L_exp,
        batch_size=4,
        batch_limit=2,
        df_path=str(merged_zoo_path)  # CRITICAL: Always use absolute path
    )
    
    print(f"‚úÖ CustomDataset created successfully!")
    print(f"Epoch mapping: {sample_dataset.D_epoch}")
    print(f"Activation mapping: {sample_dataset.D_activ}")
    print(f"Dataset length: {len(sample_dataset)} batches")
    
    # Check what data is actually available in the CSV
    try:
        df_sample = pd.read_csv(merged_zoo_path, nrows=5)
        print(f"\nüìä CSV Structure:")
        print(f"   Shape: {df_sample.shape}")
        print(f"   Available columns: {list(df_sample.columns[:10])}...")
        
        # Check for activation columns
        activation_cols = [col for col in df_sample.columns if col in ["silu", "gelu", "relu", "leakyrelu", "sigmoid", "tanh"]]
        print(f"   Activation columns found: {activation_cols}")
        
        # Check unique epochs
        if 'epoch' in df_sample.columns:
            unique_epochs = sorted(df_sample['epoch'].unique())
            print(f"   Unique epochs: {unique_epochs}")
        
        # Check unique labels
        if 'label' in df_sample.columns:
            unique_labels = sorted(df_sample['label'].unique())
            print(f"   Unique labels: {unique_labels}")
        
        print("‚úÖ Data structure appears valid for single initialization analysis")
        
    except Exception as e:
        print(f"‚ö†Ô∏è Error examining CSV: {e}")
        
except Exception as e:
    print(f"‚ùå Error creating CustomDataset: {e}")
    print(f"This might be due to:")
    print(f"   - Missing data file")
    print(f"   - Incompatible data format")
    print(f"   - Path resolution issues")
    
    # Still try to examine the CSV structure for debugging
    try:
        df_sample = pd.read_csv(merged_zoo_path, nrows=5)
        print(f"\nüìä CSV Structure (for debugging):")
        print(f"   Shape: {df_sample.shape}")
        print(f"   Available columns: {list(df_sample.columns[:15])}...")
        print(f"   First few rows:")
        print(df_sample.head(2))
    except Exception as csv_error:
        print(f"‚ùå CSV also unreadable: {csv_error}")

print(f"\nüìù IMPORTANT NOTE:")
print(f"   Always use absolute paths when creating CustomDataset instances")
print(f"   The default relative path './data/Merged zoo.csv' only works from project root")
print(f"   From notebooks_sandbox, use: str(DATA_DIR / 'Merged zoo.csv')")

=== Importing CustomDataset and Checking Data ===
CustomDataset imported successfully
‚úÖ Merged zoo.csv found at: /home/aymen/Documents/GitHub/Federated-Continual-learning-/New/data/Merged zoo.csv

CustomDataset parameters:
- L_exp: List of experiments (label1, label2, epoch_key, activ_key)
- batch_size: Default 20
- batch_limit: Default 5
- df_path: Path to Merged zoo.csv (IMPORTANT: Must use absolute path)

Batchify function:
- Creates batches from L_exp with given batch_size and limit

Sample L_exp format: (0, 1, 0, 0)

üîß Fixing path issue - using absolute path: /home/aymen/Documents/GitHub/Federated-Continual-learning-/New/data/Merged zoo.csv
‚úÖ CustomDataset created successfully!
Epoch mapping: {'0': '11', '1': '16', '2': '21', '3': '26', '4': '31', '5': '36'}
Activation mapping: {'0': 'gelu', '1': 'relu', '2': 'silu', '3': 'leakyrelu', '4': 'sigmoid', '5': 'tanh'}
Dataset length: 2 batches

üìä CSV Structure:
   Shape: (5, 2483)
   Available columns: ['label', '0', '1', '2'

In [3]:
# Cell 3: Single Initialization Data Preparation
print("=== Single Initialization Data Preparation ===")

# Create a sample L_exp optimized for single initialization analysis
def create_single_init_L_exp(num_samples=50):
    """Create L_exp focused on single initialization scenarios"""
    L_exp = []
    
    # For single initialization, focus on primary activation and key epochs
    labels = list(range(10))
    epoch_keys = [0, 1, 2]  # Focus on early-mid training epochs
    activ_key = 0  # Focus on primary activation (will be detected automatically)
    
    for i in range(num_samples):
        # Create balanced label pairs
        label1 = labels[i % len(labels)]
        label2 = labels[(i + 1) % len(labels)]
        epoch_key = epoch_keys[i % len(epoch_keys)]
        
        L_exp.append((label1, label2, epoch_key, activ_key))
    
    return L_exp

# Create single initialization dataset
single_init_L_exp = create_single_init_L_exp(30)
print(f"Created single initialization L_exp with {len(single_init_L_exp)} samples")

# Create dataset with single initialization optimization
single_init_dataset = CustomDataset(
    L_exp=single_init_L_exp,
    batch_size=6,
    batch_limit=10,
    df_path=str(merged_zoo_path)  # Use absolute path
)

print(f"Single initialization dataset length: {len(single_init_dataset)}")
print(f"Number of batches: {len(single_init_dataset.batchs)}")

# Analyze the dataset structure
print(f"\nüìä Single Initialization Dataset Analysis:")
print(f"   Primary activation: Will be detected from data")
print(f"   Epoch range: Early-mid training phases")
print(f"   Label pairs: Balanced across 10 classes")
print(f"   Batch configuration: {single_init_dataset.batch_size} samples per batch")

=== Single Initialization Data Preparation ===
Created single initialization L_exp with 30 samples
Single initialization dataset length: 5
Number of batches: 5

üìä Single Initialization Dataset Analysis:
   Primary activation: Will be detected from data
   Epoch range: Early-mid training phases
   Label pairs: Balanced across 10 classes
   Batch configuration: 6 samples per batch


In [4]:
# Cell 3: Single Initialization Data Preparation for Training
print("=== Single Initialization Data Preparation for Training ===")

# Create a sample L_exp optimized for single initialization analysis
def create_training_L_exp(num_samples=180):  # 180 samples = 5 batches of 36
    """Create L_exp for training with batch_size=36"""
    L_exp = []
    
    # For single initialization, focus on primary activation and key epochs
    labels = list(range(10))
    epoch_keys = [0, 1, 2]  # Focus on early-mid training epochs
    activ_key = 0  # Focus on primary activation (will be detected automatically)
    
    for i in range(num_samples):
        # Create balanced label pairs for training
        label1 = labels[i % len(labels)]
        label2 = labels[(i + 1) % len(labels)]
        epoch_key = epoch_keys[i % len(epoch_keys)]
        
        L_exp.append((label1, label2, epoch_key, activ_key))
    
    return L_exp

# Create training dataset with batch_size=36 to match meta.ipynb
training_L_exp = create_training_L_exp(180)  # 180 samples for 5 batches of 36
print(f"Created training L_exp with {len(training_L_exp)} samples")

# Create dataset with training batch configuration (matching meta.ipynb)
training_dataset = CustomDataset(
    L_exp=training_L_exp,
    batch_size=36,  # CRITICAL: Match batch size from meta.ipynb
    batch_limit=5,   # Create 5 batches for demonstration
    df_path=str(merged_zoo_path)  # Use absolute path
)

print(f"Training dataset length: {len(training_dataset)}")
print(f"Number of batches: {len(training_dataset.batchs)}")

# Analyze the training dataset structure
print(f"\nüìä Training Dataset Analysis:")
print(f"   Primary activation: Will be detected from data")
print(f"   Epoch range: Early-mid training phases")
print(f"   Label pairs: Balanced across 10 classes")
print(f"   Batch configuration: {training_dataset.batch_size} samples per batch (matches meta.ipynb)")
print(f"   Total samples: {len(training_L_exp)} ‚Üí {len(training_dataset)} batches of 36 pairs each")

# Verify batch structure for training
if len(training_dataset.batchs) > 0:
    sample_batch = training_dataset.batchs[0]
    print(f"\nüîç Sample Batch Structure:")
    print(f"   Batch type: {type(sample_batch)}")
    print(f"   Batch length: {len(sample_batch) if sample_batch else 'N/A'}")
    print(f"   Expected format: 36 pairs of (1D vectors + targets)")

=== Single Initialization Data Preparation for Training ===
Created training L_exp with 180 samples
Training dataset length: 5
Number of batches: 5

üìä Training Dataset Analysis:
   Primary activation: Will be detected from data
   Epoch range: Early-mid training phases
   Label pairs: Balanced across 10 classes
   Batch configuration: 36 samples per batch (matches meta.ipynb)
   Total samples: 180 ‚Üí 5 batches of 36 pairs each

üîç Sample Batch Structure:
   Batch type: <class 'list'>
   Batch length: 36
   Expected format: 36 pairs of (1D vectors + targets)


In [5]:
# Cell 4: Setup Tensor Batch Directory Structure
import os
from pathlib import Path

# Base directory for tensor batches
BASE_TENSOR_DIR = Path("/home/aymen/Documents/GitHub/Federated-Continual-learning-/New/notebooks_sandbox/tensor_batches")

# Scenario-specific directories
SCENARIOS = {
    'train': BASE_TENSOR_DIR / "train_pair_scenario",
    'val': BASE_TENSOR_DIR / "val_pair_scenario", 
    'test': BASE_TENSOR_DIR / "test_pair_scenario"
}

# Create directories if they don't exist
for scenario_name, scenario_dir in SCENARIOS.items():
    scenario_dir.mkdir(parents=True, exist_ok=True)

# Current scenario for this notebook
CURRENT_SCENARIO = 'train'
TENSOR_DIR = SCENARIOS[CURRENT_SCENARIO]

print(f"Tensor batch directories organized by scenario:")
for scenario_name, scenario_dir in SCENARIOS.items():
    file_count = len(list(scenario_dir.glob("*.pt")))
    print(f"  {scenario_name}: {scenario_dir} ({file_count} files)")

print(f"Current scenario: {CURRENT_SCENARIO}")
print(f"Using directory: {TENSOR_DIR}")

Tensor batch directories organized by scenario:
  train: /home/aymen/Documents/GitHub/Federated-Continual-learning-/New/notebooks_sandbox/tensor_batches/train_pair_scenario (21 files)
  val: /home/aymen/Documents/GitHub/Federated-Continual-learning-/New/notebooks_sandbox/tensor_batches/val_pair_scenario (0 files)
  test: /home/aymen/Documents/GitHub/Federated-Continual-learning-/New/notebooks_sandbox/tensor_batches/test_pair_scenario (0 files)
Current scenario: train
Using directory: /home/aymen/Documents/GitHub/Federated-Continual-learning-/New/notebooks_sandbox/tensor_batches/train_pair_scenario


In [6]:
# Cell 5: Create Training Dataset with Proper MNIST Subset Labels
print("=== Create Training Dataset with Proper MNIST Subset Labels ===")

def get_actual_mnist_labels():
    """Get the actual MNIST subset labels from the CSV"""
    
    try:
        import pandas as pd
        df = pd.read_csv(merged_zoo_path)
        unique_labels = df['label'].unique()
        
        print(f"Found {len(unique_labels)} unique MNIST subset labels:")
        for i, label in enumerate(unique_labels):
            print(f"  {i}: {repr(label)}")
        
        return list(unique_labels)
        
    except Exception as e:
        print(f"Error reading labels: {e}")
        return []

def create_mnist_training_L_exp(num_samples=360):
    """Create L_exp using actual MNIST subset labels from the CSV"""
    
    # Get actual labels
    mnist_labels = get_actual_mnist_labels()
    
    if not mnist_labels:
        print("No labels found - cannot create dataset")
        return []
    
    print(f"\nCreating training L_exp with {num_samples} samples using actual MNIST labels...")
    
    L_exp = []
    epoch_keys = list(range(6))  # 0-5 for different epochs
    activ_keys = list(range(6))  # 0-5 for different activations
    
    for i in range(num_samples):
        # Use actual MNIST subset labels from CSV
        label1 = mnist_labels[i % len(mnist_labels)]
        label2 = mnist_labels[(i + 1) % len(mnist_labels)]
        epoch_key = epoch_keys[i % len(epoch_keys)]
        activ_key = activ_keys[i % len(activ_keys)]
        
        L_exp.append((label1, label2, epoch_key, activ_key))
    
    print(f"Created L_exp with {len(L_exp)} samples")
    print(f"Sample entry: {L_exp[0]}")
    print(f"Label types: {type(L_exp[0][0])}, {type(L_exp[0][1])}")
    
    return L_exp

def create_mnist_training_dataset():
    """Create the training dataset with proper MNIST labels"""
    
    print(f"\n=== Creating MNIST Training Dataset ===")
    
    # Create L_exp with actual MNIST subset labels
    mnist_training_L_exp = create_mnist_training_L_exp(360)  # 10 batches of 36
    
    if not mnist_training_L_exp:
        print("‚ùå Failed to create L_exp")
        return None
    
    try:
        # Create dataset with MNIST labels
        mnist_training_dataset = CustomDataset(
            L_exp=mnist_training_L_exp,
            batch_size=36,  # Match meta.ipynb
            batch_limit=10,  # 10 batches
            df_path=str(merged_zoo_path)
        )
        
        print(f"‚úÖ MNIST training dataset created successfully!")
        print(f"   Dataset length: {len(mnist_training_dataset)} batches")
        print(f"   Expected samples per batch: 36")
        print(f"   Total samples in L_exp: {len(mnist_training_L_exp)}")
        
        # Test accessing first few batches
        print(f"\n=== Testing Batch Access ===")
        for i in range(min(3, len(mnist_training_dataset))):
            try:
                loaded, batch, L_ACC, L_indexes = mnist_training_dataset[i]
                print(f"‚úÖ Batch {i}: loaded.shape={loaded.shape}, batch_len={len(batch)}")
                
                # Show sample of actual labels being used
                if i == 0 and len(batch) > 0:
                    sample_item = batch[0]
                    print(f"   Sample batch item type: {type(sample_item)}")
                    if hasattr(sample_item, '__len__') and len(sample_item) > 0:
                        print(f"   Sample item length: {len(sample_item)}")
                
            except Exception as e:
                print(f"‚ùå Error accessing batch {i}: {e}")
                return None
        
        return mnist_training_dataset
        
    except Exception as e:
        print(f"‚ùå Error creating MNIST training dataset: {e}")
        import traceback
        traceback.print_exc()
        return None

# Create the proper MNIST training dataset
full_training_dataset = create_mnist_training_dataset()

if full_training_dataset:
    print(f"\n‚úÖ Success! Ready for batch saving in next cell")
    print(f"   Dataset has {len(full_training_dataset)} batches of 36 pairs each")
    print(f"   Using actual MNIST subset labels from the trained CNN")
else:
    print(f"\n‚ùå Failed to create training dataset")
    print(f"   Check the error messages above")

=== Create Training Dataset with Proper MNIST Subset Labels ===

=== Creating MNIST Training Dataset ===
Found 1013 unique MNIST subset labels:
  0: '[0, 1]'
  1: '[0, 2]'
  2: '[0, 3]'
  3: '[0, 4]'
  4: '[0, 5]'
  5: '[0, 6]'
  6: '[0, 7]'
  7: '[0, 8]'
  8: '[0, 9]'
  9: '[1, 2]'
  10: '[1, 3]'
  11: '[1, 4]'
  12: '[1, 5]'
  13: '[1, 6]'
  14: '[1, 7]'
  15: '[1, 8]'
  16: '[1, 9]'
  17: '[2, 3]'
  18: '[2, 4]'
  19: '[2, 5]'
  20: '[2, 6]'
  21: '[2, 7]'
  22: '[2, 8]'
  23: '[2, 9]'
  24: '[3, 4]'
  25: '[3, 5]'
  26: '[3, 6]'
  27: '[3, 7]'
  28: '[3, 8]'
  29: '[3, 9]'
  30: '[4, 5]'
  31: '[4, 6]'
  32: '[4, 7]'
  33: '[4, 8]'
  34: '[4, 9]'
  35: '[5, 6]'
  36: '[5, 7]'
  37: '[5, 8]'
  38: '[5, 9]'
  39: '[6, 7]'
  40: '[6, 8]'
  41: '[6, 9]'
  42: '[7, 8]'
  43: '[7, 9]'
  44: '[8, 9]'
  45: '[0, 1, 2]'
  46: '[0, 1, 3]'
  47: '[0, 1, 4]'
  48: '[0, 1, 5]'
  49: '[0, 1, 6]'
  50: '[0, 1, 7]'
  51: '[0, 1, 8]'
  52: '[0, 1, 9]'
  53: '[0, 2, 3]'
  54: '[0, 2, 4]'
  55: '[0, 

In [7]:
# Cell 6: Test MNIST Label Parsing and Save Batches
print("=== Test MNIST Label Parsing and Save Batches ===")

# Reload the updated CustomDataset
exec(open(ROOT / "Double_input_transformer.py").read())

def test_mnist_label_parsing():
    """Test that MNIST label parsing works correctly"""
    
    print("=== Testing MNIST Label Parsing ===")
    
    # Test the parsing function directly
    def parse_mnist_label(label_str):
        """Parse MNIST subset label string like '[0, 1]' into list [0, 1]"""
        try:
            import ast
            parsed = ast.literal_eval(label_str)
            return list(parsed) if isinstance(parsed, (list, tuple)) else [parsed]
        except:
            # Fallback: try to extract numbers manually
            import re
            numbers = re.findall(r'\d+', str(label_str))
            return [int(n) for n in numbers] if numbers else [0]
    
    # Test with sample labels
    test_labels = ["[0, 1]", "[2, 3, 4]", "[5, 6, 7, 8]"]
    for label in test_labels:
        parsed = parse_mnist_label(label)
        print(f"  {label} ‚Üí {parsed}")
    
    print("‚úÖ Label parsing function works correctly")

def save_training_batches_working(dataset, output_dir, prefix="training"):
    """Save training batches with the fixed label parsing"""
    
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    saved_files = {
        'loaded_tensors': [],
        'metadata': [],
        'batch_info': []
    }
    
    print(f"Processing {len(dataset)} training batches (each with 36 pairs)...")
    
    for i in tqdm(range(len(dataset)), desc="Saving training batches"):
        try:
            # Get batch data - this should work now with fixed label parsing
            loaded, batch, L_ACC, L_indexes = dataset[i]
            
            # Verify batch size for training
            expected_batch_size = 36
            actual_batch_size = len(batch) if hasattr(batch, '__len__') else 0
            
            if actual_batch_size != expected_batch_size:
                print(f"‚ö†Ô∏è Warning: Batch {i} has {actual_batch_size} samples, expected {expected_batch_size}")
            
            # Save main tensor
            tensor_filename = f"{prefix}_batch_{i:03d}_loaded.pt"
            tensor_path = output_path / tensor_filename
            torch.save(loaded, tensor_path)
            saved_files['loaded_tensors'].append(str(tensor_path))
            
            # Save metadata with proper serialization
            metadata = {
                'batch_idx': int(i),
                'batch_info': str(batch) if not isinstance(batch, (list, tuple)) else list(batch),
                'L_ACC': float(L_ACC) if isinstance(L_ACC, (int, float)) else str(L_ACC),
                'L_indexes': list(L_indexes) if hasattr(L_indexes, '__iter__') else [int(L_indexes)],
                'tensor_shape': tuple(loaded.shape) if hasattr(loaded, 'shape') else str(loaded.shape),
                'tensor_filename': str(tensor_filename),
                'batch_size': int(actual_batch_size),
                'expected_batch_size': int(expected_batch_size),
                'is_training_batch': True
            }
            
            metadata_filename = f"{prefix}_batch_{i:03d}_metadata.pt"
            metadata_path = output_path / metadata_filename
            torch.save(metadata, metadata_path)
            saved_files['metadata'].append(str(metadata_path))
            
            # Store batch info
            saved_files['batch_info'].append({
                'batch_idx': int(i),
                'batch_size': int(actual_batch_size),
                'tensor_shape': tuple(loaded.shape) if hasattr(loaded, 'shape') else str(loaded.shape),
                'is_training_ready': actual_batch_size == expected_batch_size
            })
            
            if i < 2:
                print(f"  ‚úì Successfully saved batch {i}: {actual_batch_size} pairs")
            
        except Exception as e:
            print(f"Error saving training batch {i}: {e}")
            import traceback
            traceback.print_exc()
            continue
    
    # Save summary
    summary = {
        'total_batches': len(saved_files['loaded_tensors']),
        'batch_files': saved_files['loaded_tensors'],
        'metadata_files': saved_files['metadata'],
        'batch_info': saved_files['batch_info'],
        'dataset_config': {
            'L_exp_length': len(dataset.L_exp),
            'batch_size': int(dataset.batch_size),
            'batch_limit': int(dataset.batch_limit),
            'df_path': str(dataset.df_path),
            'target_batch_size': 36,
            'is_training_dataset': True
        },
        'training_info': {
            'pairs_per_batch': 36,
            'total_pairs': len(saved_files['loaded_tensors']) * 36,
            'ready_for_meta_training': len(saved_files['loaded_tensors']) > 0
        }
    }
    
    summary_path = output_path / f"{prefix}_summary.pt"
    torch.save(summary, summary_path)
    
    print(f"Saved {len(saved_files['loaded_tensors'])} training batches to {output_path}")
    return summary

# Test label parsing
test_mnist_label_parsing()

# Save the training batches
if full_training_dataset:
    print(f"\n=== Saving Training Batches ===")
    try:
        training_summary = save_training_batches_working(full_training_dataset, TENSOR_DIR, "training")
        
        if training_summary['total_batches'] > 0:
            print(f"‚úÖ Successfully saved training batches!")
            print(f"   Total batches: {training_summary['total_batches']}")
            print(f"   Pairs per batch: 36 (matches meta.ipynb)")
            print(f"   Total pairs: {training_summary['training_info']['total_pairs']}")
            print(f"   Ready for meta training: {training_summary['training_info']['ready_for_meta_training']}")
        else:
            print(f"‚ùå No batches were saved successfully")
            
    except Exception as e:
        print(f"‚ùå Error saving training batches: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f"‚ùå No training dataset available - run cell 5 first")

=== Test MNIST Label Parsing and Save Batches ===
=== Testing MNIST Label Parsing ===
  [0, 1] ‚Üí [0, 1]
  [2, 3, 4] ‚Üí [2, 3, 4]
  [5, 6, 7, 8] ‚Üí [5, 6, 7, 8]
‚úÖ Label parsing function works correctly

=== Saving Training Batches ===
Processing 10 training batches (each with 36 pairs)...


Saving training batches:  10%|‚ñà         | 1/10 [00:00<00:02,  4.32it/s]

  ‚úì Successfully saved batch 0: 36 pairs


Saving training batches:  20%|‚ñà‚ñà        | 2/10 [00:00<00:01,  4.33it/s]

  ‚úì Successfully saved batch 1: 36 pairs


Saving training batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10/10 [00:02<00:00,  4.26it/s]

Saved 10 training batches to /home/aymen/Documents/GitHub/Federated-Continual-learning-/New/notebooks_sandbox/tensor_batches/train_pair_scenario
‚úÖ Successfully saved training batches!
   Total batches: 10
   Pairs per batch: 36 (matches meta.ipynb)
   Total pairs: 360
   Ready for meta training: True





In [8]:
# Cell 7: Create Training TensorDataset for Double Input Transformer
def create_training_tensor_dataset(summary_path):
    """Create TensorDataset from saved training batches for double input transformer"""
    
    try:
        # Add safe globals for PyTorch 2.6 compatibility
        import numpy.core.multiarray
        torch.serialization.add_safe_globals([numpy.core.multiarray.scalar])
        
        summary = torch.load(summary_path, map_location='cpu', weights_only=False)
    except Exception as e:
        print(f"Error loading summary: {e}")
        return None, None, None
    
    if summary['total_batches'] == 0:
        print("No batches were saved")
        return None, None, None
    
    print(f"Loading {summary['total_batches']} training batches...")
    
    # Load all batch tensors
    all_loaded_tensors = []
    all_metadata = []
    
    for i, (tensor_file, metadata_file) in enumerate(zip(
        summary['batch_files'], 
        summary['metadata_files']
    )):
        try:
            loaded = torch.load(tensor_file, map_location='cpu', weights_only=False)
            metadata = torch.load(metadata_file, map_location='cpu', weights_only=False)
            
            expected_size = 36
            actual_size = metadata.get('batch_size', 0)
            if actual_size != expected_size:
                print(f"Batch {i}: size {actual_size} != expected {expected_size}")
            
            all_loaded_tensors.append(loaded)
            all_metadata.append(metadata)
                
        except Exception as e:
            print(f"Error loading batch {i}: {e}")
            continue
    
    if not all_loaded_tensors:
        print("No tensors loaded successfully")
        return None, None, None
    
    # Stack all tensors for TensorDataset
    stacked_tensors = torch.stack(all_loaded_tensors)  # [10, 36, 3, 2464]
    feature_dim = stacked_tensors.shape[-1]  # 2464
    
    print(f"Stacked tensors shape: {stacked_tensors.shape}")
    print(f"Feature dimension: {feature_dim}")
    
    # Reshape for easier processing: [10*36, 3, 2464]
    all_samples = stacked_tensors.view(-1, 3, feature_dim)  # [360, 3, 2464]
    
    # Extract inputs and targets for double input transformer
    stream1_features = all_samples[:, 0, :]  # [360, 2464] - parent1 weight vectors
    stream2_features = all_samples[:, 1, :]  # [360, 2464] - parent2 weight vectors  
    target_vectors = all_samples[:, 2, :]    # [360, 2464] - target weight vectors
    
    print(f"Stream 1 shape: {stream1_features.shape}")
    print(f"Stream 2 shape: {stream2_features.shape}")
    print(f"Target vectors shape: {target_vectors.shape}")
    print(f"Target range: [{target_vectors.min():.3f}, {target_vectors.max():.3f}]")
    
    # Create TensorDataset for double input transformer
    training_tensor_dataset = torch.utils.data.TensorDataset(
        stream1_features, 
        stream2_features, 
        target_vectors
    )
    
    dataset_info = {
        'total_samples': len(training_tensor_dataset),
        'stream1_dim': stream1_features.shape[1],
        'stream2_dim': stream2_features.shape[1],
        'target_dim': target_vectors.shape[1],
        'original_batches': len(all_loaded_tensors),
        'samples_per_batch': 36,
        'feature_dim': feature_dim,
        'dataset_type': 'double_input_transformer_full_vectors',
        'scenario': CURRENT_SCENARIO,
        'target_vector_stats': {
            'min': float(target_vectors.min()),
            'max': float(target_vectors.max()),
            'mean': float(target_vectors.mean()),
            'std': float(target_vectors.std())
        }
    }
    
    return training_tensor_dataset, dataset_info, summary

def create_tensor_dataset_from_custom(custom_dataset):
    """Create TensorDataset directly from CustomDataset if file loading fails"""
    
    if custom_dataset is None:
        print("CustomDataset not available")
        return None, None
    
    try:
        all_stream1_features = []
        all_stream2_features = []
        all_target_vectors = []
        
        print("Creating TensorDataset from CustomDataset...")
        
        for i in range(len(custom_dataset)):
            try:
                loaded, batch, L_ACC, L_indexes = custom_dataset[i]
                
                if loaded.dim() == 3 and loaded.shape[1] == 3:
                    feature_dim = loaded.shape[-1]  # 2464
                    
                    for j in range(loaded.shape[0]):  # 36 samples
                        stream1_feat = loaded[j, 0, :]  # Full 2464-dim parent1 vector
                        stream2_feat = loaded[j, 1, :]  # Full 2464-dim parent2 vector
                        target_vector = loaded[j, 2, :]  # Full 2464-dim target vector
                        
                        all_stream1_features.append(stream1_feat)
                        all_stream2_features.append(stream2_feat)
                        all_target_vectors.append(target_vector)
                    
            except Exception as e:
                print(f"Error processing batch {i}: {e}")
                continue
        
        if not all_stream1_features:
            print("No samples collected from CustomDataset")
            return None, None
        
        # Convert to tensors
        stream1_features = torch.stack(all_stream1_features)
        stream2_features = torch.stack(all_stream2_features)
        target_vectors = torch.stack(all_target_vectors)
        
        print(f"Collected {len(all_stream1_features)} samples")
        print(f"Target range: [{target_vectors.min():.3f}, {target_vectors.max():.3f}]")
        
        # Create TensorDataset
        training_tensor_dataset = torch.utils.data.TensorDataset(
            stream1_features, 
            stream2_features, 
            target_vectors
        )
        
        dataset_info = {
            'total_samples': len(training_tensor_dataset),
            'stream1_dim': stream1_features.shape[1],
            'stream2_dim': stream2_features.shape[1],
            'target_dim': target_vectors.shape[1],
            'original_batches': len(custom_dataset),
            'samples_per_batch': 36,
            'feature_dim': stream1_features.shape[1],
            'dataset_type': 'double_input_transformer_full_vectors_from_custom',
            'scenario': CURRENT_SCENARIO,
            'target_vector_stats': {
                'min': float(target_vectors.min()),
                'max': float(target_vectors.max()),
                'mean': float(target_vectors.mean()),
                'std': float(target_vectors.std())
            }
        }
        
        return training_tensor_dataset, dataset_info
        
    except Exception as e:
        print(f"Error creating TensorDataset from CustomDataset: {e}")
        return None, None

# Create the TensorDataset
try:
    training_summary_path = TENSOR_DIR / "training_summary.pt"
    
    if not training_summary_path.exists():
        print(f"Summary file not found: {training_summary_path}")
        print("Run cell 6 first to create the summary file")
        training_tensor_dataset = None
    else:
        # Try to load from saved files first
        training_tensor_dataset, dataset_info, training_summary = create_training_tensor_dataset(training_summary_path)
        
        # If that fails, try creating directly from CustomDataset
        if training_tensor_dataset is None:
            print("Loading from saved files failed, trying direct approach...")
            training_tensor_dataset, dataset_info = create_tensor_dataset_from_custom(full_training_dataset)
        
        if training_tensor_dataset:
            print(f"TensorDataset created successfully: {len(training_tensor_dataset)} samples")
            print(f"Scenario: {dataset_info['scenario']}")
            
            # Test accessing a sample
            sample_stream1, sample_stream2, sample_target = training_tensor_dataset[0]
            print(f"Sample shapes: stream1={sample_stream1.shape}, stream2={sample_stream2.shape}, target={sample_target.shape}")
        else:
            print("Failed to create TensorDataset")
            
except Exception as e:
    print(f"Error creating TensorDataset: {e}")
    training_tensor_dataset = None

Loading 10 training batches...
Stacked tensors shape: torch.Size([10, 36, 3, 2464])
Feature dimension: 2464
Stream 1 shape: torch.Size([360, 2464])
Stream 2 shape: torch.Size([360, 2464])
Target vectors shape: torch.Size([360, 2464])
Target range: [-5.607, 4.186]
TensorDataset created successfully: 360 samples
Scenario: train
Sample shapes: stream1=torch.Size([2464]), stream2=torch.Size([2464]), target=torch.Size([2464])


In [9]:
# Cell 8: Benchmark and Create Toy Double Input Transformer Model
import torch.nn as nn
import torch.nn.functional as F
import time

class ToyDoubleInputTransformer(nn.Module):
    """Toy Double Input Transformer matching TransformerAE architecture for testing"""
    
    def __init__(self, input_dim=2464, d_model=100, max_seq_len=50, neck=20, N=1, heads=1, d_ff=100):
        super().__init__()
        
        self.input_dim = input_dim
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.neck = neck
        self.N = N
        self.heads = heads
        self.d_ff = d_ff
        
        # Simplified embedding layer (mimics EmbedderNeuronGroup)
        self.embed_stream1 = nn.Linear(input_dim, max_seq_len * d_model)
        self.embed_stream2 = nn.Linear(input_dim, max_seq_len * d_model)
        
        # Simplified positional encoding
        self.pos_encoder = nn.Parameter(torch.randn(1, max_seq_len, d_model) * 0.1)
        
        # Simplified encoder layers (one for each stream)
        self.encoder_stream1 = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=heads, dim_feedforward=d_ff, dropout=0.1),
            num_layers=N
        )
        self.encoder_stream2 = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=heads, dim_feedforward=d_ff, dropout=0.1),
            num_layers=N
        )
        
        # Fusion layer (mimics vec2neck in TransformerAE)
        self.fusion = nn.Linear(2 * d_model * max_seq_len, neck)
        
        # Simplified decoder (mimics DecoderNeuronGroup + Seq2Vec)
        self.decoder = nn.Sequential(
            nn.Linear(neck, d_ff),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_ff, input_dim)  # Output 2464-dimensional vector
        )
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        
    def forward(self, stream1, stream2):
        """
        Args:
            stream1: [batch_size, 2464] - parent model weights
            stream2: [batch_size, 2464] - parent model weights
        Returns:
            output: [batch_size, 2464] - predicted target model weights
        """
        batch_size = stream1.size(0)
        
        # Embed streams to sequences
        stream1_seq = self.embed_stream1(stream1).view(batch_size, self.max_seq_len, self.d_model)
        stream2_seq = self.embed_stream2(stream2).view(batch_size, self.max_seq_len, self.d_model)
        
        # Add positional encoding
        stream1_seq = stream1_seq + self.pos_encoder
        stream2_seq = stream2_seq + self.pos_encoder
        
        # Encode each stream independently
        stream1_encoded = self.encoder_stream1(stream1_seq.transpose(0, 1)).transpose(0, 1)
        stream2_encoded = self.encoder_stream2(stream2_seq.transpose(0, 1)).transpose(0, 1)
        
        # Concatenate encoded streams
        concatenated = torch.cat([stream1_encoded, stream2_encoded], dim=2)
        
        # Flatten and fuse
        flattened = concatenated.view(batch_size, -1)
        neck_rep = torch.tanh(self.fusion(flattened))
        
        # Decode to final output
        output = self.decoder(neck_rep)
        
        return output

def benchmark_datasets(training_dataset, training_tensor_dataset):
    """Benchmark CustomDataset vs TensorDataset"""
    
    if training_dataset is None or training_tensor_dataset is None:
        print("One or both datasets not available")
        return None
    
    # Test CustomDataset
    custom_times = []
    custom_loader = torch.utils.data.DataLoader(training_dataset, batch_size=1)
    
    start_time = time.time()
    for i, batch in enumerate(custom_loader):
        if i >= 5:
            break
        loaded, batch_data, L_ACC, L_indexes = batch
        custom_times.append(time.time() - start_time)
        start_time = time.time()
    
    custom_avg_time = sum(custom_times) / len(custom_times) if custom_times else 0
    
    # Test TensorDataset  
    tensor_times = []
    tensor_loader = torch.utils.data.DataLoader(training_tensor_dataset, batch_size=1)
    
    start_time = time.time()
    for i, (stream1, stream2, target_vector) in enumerate(tensor_loader):
        if i >= 5:
            break
        tensor_times.append(time.time() - start_time)
        start_time = time.time()
    
    tensor_avg_time = sum(tensor_times) / len(tensor_times) if tensor_times else 0
    
    print(f"CustomDataset avg time: {custom_avg_time:.4f}s")
    print(f"TensorDataset avg time: {tensor_avg_time:.4f}s")
    
    if tensor_avg_time > 0:
        speedup = custom_avg_time / tensor_avg_time
        print(f"TensorDataset speedup: {speedup:.2f}x")
    
    return {
        'custom_avg_time': custom_avg_time,
        'tensor_avg_time': tensor_avg_time,
        'speedup': custom_avg_time / tensor_avg_time if tensor_avg_time > 0 else 0
    }

def test_toy_model(training_tensor_dataset):
    """Test the toy double input transformer model with 2464-dim vectors"""
    
    if training_tensor_dataset is None:
        print("TensorDataset not available")
        return None
    
    # Get dataset info
    sample_stream1, sample_stream2, sample_target = training_tensor_dataset[0]
    stream1_dim = sample_stream1.shape[0]
    stream2_dim = sample_stream2.shape[0]
    target_dim = sample_target.shape[0]  # Should be 2464
    
    print(f"Dataset dimensions: stream1={stream1_dim}, stream2={stream2_dim}, target={target_dim}")
    
    if target_dim != 2464:
        print(f"Warning: Expected target_dim=2464, got {target_dim}")
    
    # Create model matching TransformerAE architecture
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ToyDoubleInputTransformer(
        input_dim=stream1_dim,  # 2464
        d_model=100,            # Match TransformerAE default
        max_seq_len=50,         # Match EmbedderNeuronGroup output (26+24=50)
        neck=20,                # Match TransformerAE default
        N=1,                    # Match TransformerAE default
        heads=1,                # Match TransformerAE default
        d_ff=100                # Match TransformerAE default
    ).to(device)
    
    print(f"Model created on device: {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Test forward pass
    model.eval()
    with torch.no_grad():
        # Create DataLoader
        test_loader = torch.utils.data.DataLoader(training_tensor_dataset, batch_size=4, shuffle=True)
        
        for i, (stream1_batch, stream2_batch, target_batch) in enumerate(test_loader):
            if i >= 2:
                break
                
            # Move to device
            stream1_batch = stream1_batch.to(device)
            stream2_batch = stream2_batch.to(device)
            target_batch = target_batch.to(device)
            
            # Forward pass
            output = model(stream1_batch, stream2_batch)
            
            print(f"Batch {i}:")
            print(f"  Input shapes: stream1={stream1_batch.shape}, stream2={stream2_batch.shape}")
            print(f"  Target shape: {target_batch.shape}")
            print(f"  Output shape: {output.shape}")
            print(f"  Target range: [{target_batch.min():.3f}, {target_batch.max():.3f}]")
            print(f"  Output range: [{output.min():.3f}, {output.max():.3f}]")
            
            # Calculate MSE for this batch
            mse = F.mse_loss(output, target_batch)
            print(f"  MSE: {mse.item():.6f}")
    
    return model

# Run benchmarks and tests
try:
    # Benchmark datasets
    benchmark_results = benchmark_datasets(full_training_dataset, training_tensor_dataset)
    
    # Test toy model
    toy_model = test_toy_model(training_tensor_dataset)
    
    if toy_model:
        print("All tests completed successfully")
        print("TensorDataset ready for training with 2464-dim vectors")
        print("Toy double input transformer matching TransformerAE architecture")
    else:
        print("Model test failed")
        
except Exception as e:
    print(f"Error in benchmarking/testing: {e}")
    import traceback
    traceback.print_exc()

CustomDataset avg time: 0.2308s
TensorDataset avg time: 0.0001s
TensorDataset speedup: 3371.02x
Dataset dimensions: stream1=2464, stream2=2464, target=2464
Model created on device: cuda
Model parameters: 25,227,984
Batch 0:
  Input shapes: stream1=torch.Size([4, 2464]), stream2=torch.Size([4, 2464])
  Target shape: torch.Size([4, 2464])
  Output shape: torch.Size([4, 2464])
  Target range: [-1.402, 1.086]
  Output range: [-0.343, 0.380]
  MSE: 0.071182
Batch 1:
  Input shapes: stream1=torch.Size([4, 2464]), stream2=torch.Size([4, 2464])
  Target shape: torch.Size([4, 2464])
  Output shape: torch.Size([4, 2464])
  Target range: [-4.644, 3.030]
  Output range: [-0.384, 0.360]
  MSE: 0.244562
All tests completed successfully
TensorDataset ready for training with 2464-dim vectors
Toy double input transformer matching TransformerAE architecture


In [10]:
# Cell 9: Training Loop Demo with Double Input Transformer
def train_double_input_transformer(model, dataset, epochs=5, batch_size=32, learning_rate=1e-4):
    """Training loop for double input transformer"""
    
    if model is None or dataset is None:
        print("Model or dataset not available")
        return None, None
    
    device = next(model.parameters()).device
    print(f"Training on device: {device}")
    
    # Create DataLoader
    train_loader = torch.utils.data.DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=True,
        drop_last=True
    )
    
    print(f"Training samples: {len(dataset)}")
    print(f"Batch size: {batch_size}")
    print(f"Batches per epoch: {len(train_loader)}")
    
    # Setup training
    criterion = nn.MSELoss()  # Regression loss
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)
    
    # Training history
    train_losses = []
    
    print("Starting training...")
    
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        
        for batch_idx, (stream1_batch, stream2_batch, target_batch) in enumerate(train_loader):
            # Move to device
            stream1_batch = stream1_batch.to(device)
            stream2_batch = stream2_batch.to(device)
            target_batch = target_batch.to(device).float()
            
            # Forward pass
            optimizer.zero_grad()
            output = model(stream1_batch, stream2_batch)
            
            # Calculate loss
            loss = criterion(output.squeeze(), target_batch.squeeze())
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            epoch_loss += loss.item()
            num_batches += 1
            
            # Print progress
            if batch_idx % 2 == 0:
                print(f"  Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.6f}")
        
        # Calculate epoch metrics
        avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0
        train_losses.append(avg_epoch_loss)
        
        # Learning rate scheduling
        scheduler.step(avg_epoch_loss)
        
        # Print epoch summary
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{epochs} completed:")
        print(f"  Average Loss: {avg_epoch_loss:.6f}")
        print(f"  Learning Rate: {current_lr:.6f}")
        print(f"  Batches Processed: {num_batches}")
        print("-" * 50)
    
    print("Training completed")
    return train_losses, model

def evaluate_model(model, dataset, batch_size=32):
    """Evaluate the trained model"""
    
    if model is None or dataset is None:
        print("Model or dataset not available")
        return None
    
    device = next(model.parameters()).device
    model.eval()
    
    # Create DataLoader
    eval_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    total_loss = 0.0
    num_samples = 0
    predictions = []
    actuals = []
    
    criterion = nn.MSELoss()
    
    print("Model evaluation...")
    
    with torch.no_grad():
        for batch_idx, (stream1_batch, stream2_batch, target_batch) in enumerate(eval_loader):
            # Move to device
            stream1_batch = stream1_batch.to(device)
            stream2_batch = stream2_batch.to(device)
            target_batch = target_batch.to(device).float()
            
            # Forward pass
            output = model(stream1_batch, stream2_batch)
            
            # Calculate loss
            loss = criterion(output.squeeze(), target_batch.squeeze())
            total_loss += loss.item() * target_batch.size(0)
            num_samples += target_batch.size(0)
            
            # Store predictions and actuals
            predictions.extend(output.squeeze().cpu().numpy())
            actuals.extend(target_batch.squeeze().cpu().numpy())
            
            if batch_idx < 2:
                print(f"Batch {batch_idx}:")
                print(f"  Target range: [{target_batch.min():.3f}, {target_batch.max():.3f}]")
                print(f"  Prediction range: [{output.min():.3f}, {output.max():.3f}]")
                print(f"  Loss: {loss.item():.6f}")
    
    # Calculate metrics
    avg_loss = total_loss / num_samples if num_samples > 0 else 0
    
    predictions = np.array(predictions)
    actuals = np.array(actuals)
    
    # Calculate correlation
    if len(predictions) > 1 and len(actuals) > 1:
        correlation = np.corrcoef(predictions, actuals)[0, 1]
        mae = np.mean(np.abs(predictions - actuals))
        mse = np.mean((predictions - actuals) ** 2)
        rmse = np.sqrt(mse)
    else:
        correlation = 0.0
        mae = 0.0
        mse = 0.0
        rmse = 0.0
    
    print("Evaluation results:")
    print(f"  Total samples: {num_samples}")
    print(f"  Average Loss (MSE): {avg_loss:.6f}")
    print(f"  Root Mean Square Error: {rmse:.6f}")
    print(f"  Mean Absolute Error: {mae:.6f}")
    print(f"  Correlation: {correlation:.6f}")
    
    return {
        'avg_loss': avg_loss,
        'rmse': rmse,
        'mae': mae,
        'correlation': correlation,
        'num_samples': num_samples
    }

# Run training and evaluation
try:
    if training_tensor_dataset and toy_model:
        # Train the model
        train_losses, trained_model = train_double_input_transformer(
            toy_model, 
            training_tensor_dataset, 
            epochs=5, 
            batch_size=32, 
            learning_rate=1e-4
        )
        
        # Evaluate the model
        eval_results = evaluate_model(trained_model, training_tensor_dataset, batch_size=32)
        
        print("Final summary:")
        print(f"  Created training batches with 36 pairs each (matches meta.ipynb)")
        print(f"  Exported training tensors to scenario: {CURRENT_SCENARIO}")
        print(f"  Built and trained toy double input transformer model")
        print(f"  Demonstrated end-to-end training pipeline")
        print(f"  Ready for integration with meta.ipynb training pipeline")
        
        if eval_results:
            print(f"  Model evaluation completed with correlation: {eval_results['correlation']:.4f}")
        
    else:
        print("Cannot run training - missing TensorDataset or model")
        print(f"  training_tensor_dataset available: {training_tensor_dataset is not None}")
        print(f"  toy_model available: {toy_model is not None}")
        
except Exception as e:
    print(f"Error during training/evaluation: {e}")
    import traceback
    traceback.print_exc()

Training on device: cuda:0
Training samples: 360
Batch size: 32
Batches per epoch: 11
Starting training...
  Epoch 1/5, Batch 0/11, Loss: 0.202068
  Epoch 1/5, Batch 2/11, Loss: 0.186866
  Epoch 1/5, Batch 4/11, Loss: 0.165638
  Epoch 1/5, Batch 6/11, Loss: 0.133759
  Epoch 1/5, Batch 8/11, Loss: 0.176621
  Epoch 1/5, Batch 10/11, Loss: 0.170474
Epoch 1/5 completed:
  Average Loss: 0.185348
  Learning Rate: 0.000100
  Batches Processed: 11
--------------------------------------------------
  Epoch 2/5, Batch 0/11, Loss: 0.240868
  Epoch 2/5, Batch 2/11, Loss: 0.141679
  Epoch 2/5, Batch 4/11, Loss: 0.155945
  Epoch 2/5, Batch 6/11, Loss: 0.230693
  Epoch 2/5, Batch 8/11, Loss: 0.181479
  Epoch 2/5, Batch 10/11, Loss: 0.180808
Epoch 2/5 completed:
  Average Loss: 0.184185
  Learning Rate: 0.000100
  Batches Processed: 11
--------------------------------------------------
  Epoch 3/5, Batch 0/11, Loss: 0.170619
  Epoch 3/5, Batch 2/11, Loss: 0.189560
  Epoch 3/5, Batch 4/11, Loss: 0.2006