# Training

### Unified Reasoning Module

In [None]:
import os
import json
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.notebook import tqdm
from task import Task
from unified_module import UnifiedReasoningModule
import random
import torch.multiprocessing as mp
import traceback
from trainer import train_model, load_precomputed_tasks

# Seed PyTorch, NumPy and the stdlib `random` module with the same fixed value
# so that repeated runs of this cell start from identical random state.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Paths used by this cell:
#   MODEL_DIR   - passed to train_model as save_dir
#   DATA_DIR    - precomputed tasks read by load_precomputed_tasks
#   METRICS_DIR - created here; not referenced again in this cell
MODEL_DIR = "output/models/unified_shape"
DATA_DIR = "precomputed_tasks/training"
METRICS_DIR = "output/metrics"
# exist_ok=True makes re-running the cell safe when the directories already exist.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(METRICS_DIR, exist_ok=True)

# Main function
def main():
    """Train the unified reasoning model on the precomputed tasks.

    Tasks are loaded from ``DATA_DIR``; a ``UnifiedReasoningModule`` is built,
    moved to ``device`` and handed to ``train_model``.

    Returns:
        The ``(trained_model, history)`` pair produced by ``train_model``, or
        ``(None, None)`` when model construction or training raises (the
        error and its traceback are printed).
    """
    print("Loading tasks...")
    tasks = load_precomputed_tasks(DATA_DIR)
    print(f"Loaded {len(tasks)} tasks")

    try:
        print("Initializing unified model...")
        model_config = dict(input_dim=3, hidden_dim=40, output_dim=11, device=device)
        model = UnifiedReasoningModule(**model_config)
        model.model = model.model.to(device)

        print("Starting unified model training...")
        training_config = dict(
            num_epochs=10,
            learning_rate=0.02,
            weight_decay=1e-5,
            save_dir=MODEL_DIR,
            model_name="unified_shape",
            batch_size=8,
            device=device,
        )
        outcome = train_model(model=model, tasks=tasks, **training_config)
        trained_model, history = outcome

        print("Training complete!")
        return trained_model, history

    except Exception as e:
        # Top-level boundary: report the failure and hand back empty results
        # instead of propagating.
        print(f"Error in main function: {e}")
        traceback.print_exc()
        return None, None

# Run the main function when executed
if __name__ == "__main__":
    # Select the 'spawn' start method for torch.multiprocessing before training;
    # force=True replaces a start method that has already been set.
    mp.set_start_method('spawn', force=True)
    model, history = main()

In [None]:
import os
import json
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from task import Task
from unified_module import UnifiedReasoningModule
import random
import traceback
from meta_learning import run_meta_learning, MAMLTrainer, ProtoNetTrainer

# Seed PyTorch, NumPy and the stdlib `random` module with the same fixed value
# so that repeated runs of this cell start from identical random state.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Paths used by this cell:
#   pretrained_model_path - checkpoint loaded into the model in main()
#   MODEL_DIR             - parent of the per-method log directories
#   TRAIN_DIR / VAL_DIR   - JSON task files read by load_tasks()
#   METRICS_DIR           - where test_meta_learning() writes its JSON metrics
pretrained_model_path = "output/models/unified_reasoning_model_final.pt"
MODEL_DIR = "output/models/unified"
TRAIN_DIR = "data/training"
VAL_DIR = "data/evaluation"
METRICS_DIR = "output/metrics/unified_meta"
# exist_ok=True makes re-running the cell safe when the directories already exist.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(METRICS_DIR, exist_ok=True)

# Load tasks
def load_tasks(directory):
    """Load every ``*.json`` task file found under ``directory``.

    The tree is walked in sorted order so the returned list has the same
    order on every machine; otherwise the seeded shuffle applied later would
    not be reproducible. Files that cannot be parsed, or that lack a
    ``"train"`` or ``"test"`` entry, are reported with a warning and skipped.

    Args:
        directory: Root directory to search recursively.

    Returns:
        List of ``Task`` objects, one per valid task file. Empty when the
        directory does not exist or holds no valid task files.
    """
    tasks = []
    for root, dirs, files in os.walk(directory):
        # Sorting `dirs` in place fixes the order in which os.walk descends.
        dirs.sort()
        for file_name in sorted(files):
            if not file_name.endswith(".json"):
                continue

            file_path = os.path.join(root, file_name)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError) as exc:
                # One corrupt file should not abort loading the whole dataset.
                print(f"Warning: Could not parse {file_path}: {exc}")
                continue

            if not isinstance(data, dict) or "train" not in data or "test" not in data:
                print(f"Warning: Invalid task format in {file_path}")
                continue

            tasks.append(
                Task(
                    task_id=os.path.basename(file_path),
                    train_pairs=[(pair["input"], pair["output"]) for pair in data["train"]],
                    test_pairs=[(pair["input"], pair["output"]) for pair in data["test"]],
                )
            )
    return tasks

# Split tasks into training and validation sets
def split_tasks(tasks, val_ratio=0.2):
    """Randomly split tasks into a training part and a held-out part.

    The caller's list is left untouched: a copy is shuffled with the global
    ``random`` generator, so the result is reproducible after ``random.seed``.

    Args:
        tasks: Sequence of tasks to split.
        val_ratio: Fraction of tasks placed in the held-out part; the count is
            ``int(len(tasks) * val_ratio)``, i.e. rounded down.

    Returns:
        Tuple ``(train_tasks, val_tasks)``.
    """
    shuffled = list(tasks)
    random.shuffle(shuffled)
    val_size = int(len(shuffled) * val_ratio)
    return shuffled[val_size:], shuffled[:val_size]

# Train meta-learning models
def train_meta_learning(model, train_tasks, val_tasks, method="both"):
    """
    Run the requested meta-learning method(s) through ``run_meta_learning``.

    Args:
        model: UnifiedReasoningModule instance
        train_tasks: List of training tasks
        val_tasks: List of validation tasks
        method: "maml", "proto", or "both" (case-insensitive); any other
            value trains nothing

    Returns:
        Dictionary mapping "maml" / "proto" to whatever ``run_meta_learning``
        returned for that method
    """
    # NOTE(review): this cell passes the model as `unified_model=` while later
    # cells use `reasoning_module=` - confirm against run_meta_learning.
    requested = method.lower()
    banners = (
        ("maml", "Training MAML model..."),
        ("proto", "Training Prototypical Networks model..."),
    )

    trained_models = {}
    for key, banner in banners:
        if requested != key and requested != "both":
            continue
        print(banner)
        trained_models[key] = run_meta_learning(
            unified_model=model,
            train_tasks=train_tasks,
            val_tasks=val_tasks,
            method=key,
            epochs=50,
            lr=0.001,
            weight_decay=1e-5,
            log_dir=os.path.join(MODEL_DIR, key)
        )

    return trained_models

# Test on new tasks
def test_meta_learning(trainer, test_tasks):
    """
    Test a meta-learning trainer on new tasks and save the metrics as JSON.

    Each task's ``train_graphs`` form the support set and its ``test_graphs``
    the query set. Tasks with an empty set are skipped; a task that raises is
    reported and excluded from the average.

    Args:
        trainer: Meta-learning trainer instance exposing
            ``validate_task(support_batch, query_batch)``
        test_tasks: List of test tasks

    Returns:
        Dictionary with "avg_accuracy" (0.0 when no task was evaluated) and
        "task_metrics" (one entry per evaluated task)
    """
    # Batch lives in the separate `torch_geometric` package; `torch` has no
    # `geometric` attribute. Imported locally because this cell's import block
    # does not provide it.
    from torch_geometric.data import Batch

    print(f"Testing {trainer.__class__.__name__} on {len(test_tasks)} tasks...")

    task_metrics = []

    for i, task in enumerate(test_tasks):
        try:
            # Prepare support and query sets
            support_graphs = task.train_graphs
            query_graphs = task.test_graphs

            if not support_graphs or not query_graphs:
                print(f"Skipping task {i} due to empty graphs")
                continue

            # Create batches
            support_batch = Batch.from_data_list(support_graphs).to(device)
            query_batch = Batch.from_data_list(query_graphs).to(device)

            # Evaluate with few-shot adaptation. Coerce to a plain float so
            # json.dump below does not depend on the scalar type returned.
            accuracy = float(trainer.validate_task(support_batch, query_batch))

            print(f"Task {i} accuracy: {accuracy:.4f}")
            task_metrics.append({
                "task_id": task.task_id,
                "accuracy": accuracy,
                "num_train": len(support_graphs),
                "num_test": len(query_graphs)
            })

        except Exception as e:
            # Best effort: one failing task must not abort the evaluation.
            print(f"Error testing task {i}: {e}")
            traceback.print_exc()

    if task_metrics:
        avg_accuracy = sum(m["accuracy"] for m in task_metrics) / len(task_metrics)
    else:
        avg_accuracy = 0.0
    print(f"Average test accuracy: {avg_accuracy:.4f}")

    results = {
        "avg_accuracy": avg_accuracy,
        "task_metrics": task_metrics
    }

    # Save metrics; the file name comes from the trainer class name with
    # "Trainer" removed and lower-cased.
    method_name = trainer.__class__.__name__.replace("Trainer", "").lower()
    metrics_path = os.path.join(METRICS_DIR, f"{method_name}_test_metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(results, f, indent=2)

    return results

# Main function
def main():
    """Meta-train the pretrained unified model and evaluate it on held-out tasks.

    Training tasks come from ``TRAIN_DIR``. The tasks from ``VAL_DIR`` are
    split in half by ``split_tasks`` into a validation set and a test set.
    Both MAML and Prototypical Networks are trained and then tested.

    Returns:
        ``(trained_models, test_results)``, two dictionaries keyed by method
        name, or ``(None, None)`` if model setup, training or testing raised.
    """
    print("Loading tasks...")
    train_tasks = load_tasks(TRAIN_DIR)
    held_out_tasks = load_tasks(VAL_DIR)
    total_loaded = len(train_tasks) + len(held_out_tasks)
    print(f"Loaded {total_loaded} tasks")

    # Half of the evaluation tasks become validation, the other half test.
    val_tasks, test_tasks = split_tasks(held_out_tasks, val_ratio=0.5)
    print(f"Split into {len(train_tasks)} training, {len(val_tasks)} validation, and {len(test_tasks)} test tasks")

    try:
        print("Initializing unified model...")
        model = UnifiedReasoningModule(input_dim=3, hidden_dim=128, output_dim=11, device=device)
        model.load_complete_state(pretrained_model_path)
        model.model = model.model.to(device)

        print("Training meta-learning models...")
        # "both" trains MAML and Prototypical Networks.
        trained_models = train_meta_learning(
            model=model, train_tasks=train_tasks, val_tasks=val_tasks, method="both"
        )

        print("Testing meta-learning models...")
        test_results = {
            name: test_meta_learning(fitted, test_tasks)
            for name, fitted in trained_models.items()
        }

        print("\nComparison of meta-learning methods:")
        for name, outcome in test_results.items():
            print(f"{name.upper()}: Average accuracy = {outcome['avg_accuracy']:.4f}")

        print("Meta-learning training and evaluation complete!")
        return trained_models, test_results

    except Exception as e:
        # Top-level boundary: report the failure and hand back empty results.
        print(f"Error in main function: {e}")
        traceback.print_exc()
        return None, None

# Run the main function when executed; the trainers and their test metrics
# stay available as `models` and `results` (both None if main() failed).
if __name__ == "__main__":
    models, results = main()

### NLM Reasoning Module

In [None]:
import os
import json
import numpy as np
import torch
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from task import Task
from nlm_module import NLMReasoningModule
import random
import torch.multiprocessing as mp
import traceback
from functools import partial
from trainer import train_model, load_precomputed_tasks

# Seed PyTorch, NumPy and the stdlib `random` module with the same fixed value
# so that repeated runs of this cell start from identical random state.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Paths used by this cell:
#   MODEL_DIR   - passed to train_model as save_dir
#   DATA_DIR    - precomputed tasks read by load_precomputed_tasks
#   METRICS_DIR - created here; not referenced again in this cell
MODEL_DIR = "output/models/nlm_shape"
DATA_DIR = "precomputed_tasks/training"
METRICS_DIR = "output/metrics"
# exist_ok=True makes re-running the cell safe when the directories already exist.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(METRICS_DIR, exist_ok=True)

# Main function
def main():
    """Train the NLM reasoning model on the precomputed tasks.

    Tasks are loaded from ``DATA_DIR``; an ``NLMReasoningModule`` is built,
    moved to ``device`` and handed to ``train_model``.

    Returns:
        The ``(trained_model, history)`` pair produced by ``train_model``, or
        ``(None, None)`` when model construction or training raises (the
        error and its traceback are printed).
    """
    print("Loading tasks...")
    tasks = load_precomputed_tasks(DATA_DIR)
    print(f"Loaded {len(tasks)} tasks")

    try:
        print("Initializing nlm model...")
        model_config = dict(input_dim=3, hidden_dim=128, output_dim=11, device=device)
        model = NLMReasoningModule(**model_config)
        model.model = model.model.to(device)

        print("Starting nlm model training...")
        training_config = dict(
            num_epochs=6,
            learning_rate=0.5,
            weight_decay=1e-5,
            save_dir=MODEL_DIR,
            model_name="nlm_shape",
            batch_size=8,
            device=device,
        )
        outcome = train_model(model=model, tasks=tasks, **training_config)
        trained_model, history = outcome

        print("Training complete!")
        return trained_model, history

    except Exception as e:
        # Top-level boundary: report the failure and hand back empty results
        # instead of propagating.
        print(f"Error in main function: {e}")
        traceback.print_exc()
        return None, None

# Run the main function when executed
if __name__ == "__main__":
    # Select the 'spawn' start method for torch.multiprocessing before training;
    # force=True replaces a start method that has already been set.
    mp.set_start_method('spawn', force=True)
    model, history = main()

In [None]:
import os
import json
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from task import Task, Blackboard
from unified_module import UnifiedReasoningModule
from trainer import load_precomputed_tasks
import random
import traceback
from meta_learning import run_meta_learning, MAMLTrainer, ProtoNetTrainer

# Seed PyTorch, NumPy and the stdlib `random` module with the same fixed value
# so that repeated runs of this cell start from identical random state.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Paths used by this cell:
#   MODEL_DIR            - pretrained checkpoint location and parent of the
#                          per-method log directories
#   TRAIN_DIR / EVAL_DIR - precomputed tasks read by load_precomputed_tasks
#   METRICS_DIR          - where test_meta_learning() writes its JSON metrics
MODEL_DIR = "output/models/nlm"
TRAIN_DIR = "precomputed_tasks/training_400"
EVAL_DIR = "precomputed_tasks/evaluation_400"
METRICS_DIR = "output/metrics/nlm"
# exist_ok=True makes re-running the cell safe when the directories already exist.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(METRICS_DIR, exist_ok=True)

# Split tasks into training and validation sets
def split_tasks(tasks, val_ratio=0.2):
    """Randomly split tasks into training and validation sets.

    A shuffled copy is split, so the list passed in keeps its original order.
    Shuffling uses the global ``random`` generator and is therefore
    reproducible after ``random.seed``.

    Args:
        tasks: Sequence of tasks to split.
        val_ratio: Fraction of tasks assigned to validation; the count is
            ``int(len(tasks) * val_ratio)``, i.e. rounded down.

    Returns:
        Tuple ``(train_tasks, val_tasks)``.
    """
    shuffled = list(tasks)
    random.shuffle(shuffled)
    val_size = int(len(shuffled) * val_ratio)
    train_tasks = shuffled[val_size:]
    val_tasks = shuffled[:val_size]
    return train_tasks, val_tasks

# Train meta-learning models
def train_meta_learning(model, train_tasks, val_tasks, method="both"):
    """
    Run the requested meta-learning method(s) through ``run_meta_learning``.

    Args:
        model: Reasoning module instance
        train_tasks: List of training tasks
        val_tasks: List of validation tasks
        method: "maml", "proto", or "both" (case-insensitive); any other
            value trains nothing

    Returns:
        Dictionary mapping "maml" / "proto" to whatever ``run_meta_learning``
        returned for that method
    """
    requested = method.lower()
    banners = (
        ("maml", "Training MAML model..."),
        ("proto", "Training Prototypical Networks model..."),
    )

    trained_models = {}
    for key, banner in banners:
        if requested != key and requested != "both":
            continue
        print(banner)
        trained_models[key] = run_meta_learning(
            reasoning_module=model,
            train_tasks=train_tasks,
            val_tasks=val_tasks,
            method=key,
            epochs=50,
            lr=0.001,
            weight_decay=1e-5,
            log_dir=os.path.join(MODEL_DIR, key)
        )

    return trained_models

# Test on new tasks
def test_meta_learning(trainer, test_tasks):
    """
    Test a meta-learning trainer on new tasks and save the metrics as JSON.

    Each task's ``train_graphs`` form the support set and its ``test_graphs``
    the query set. Tasks with an empty set are skipped; a task that raises is
    reported and excluded from the average.

    Args:
        trainer: Meta-learning trainer instance exposing
            ``validate_task(support_batch, query_batch)``
        test_tasks: List of test tasks

    Returns:
        Dictionary with "avg_accuracy" (0.0 when no task was evaluated) and
        "task_metrics" (one entry per evaluated task)
    """
    # This cell's import block does not bring `Batch` into scope, so import it
    # here rather than relying on another notebook cell having run first.
    from torch_geometric.data import Batch

    print(f"Testing {trainer.__class__.__name__} on {len(test_tasks)} tasks...")

    task_metrics = []

    for i, task in enumerate(test_tasks):
        try:
            # Prepare support and query sets
            support_graphs = task.train_graphs
            query_graphs = task.test_graphs

            if not support_graphs or not query_graphs:
                print(f"Skipping task {i} due to empty graphs")
                continue

            # Create batches
            support_batch = Batch.from_data_list(support_graphs).to(device)
            query_batch = Batch.from_data_list(query_graphs).to(device)

            # Evaluate with few-shot adaptation. Coerce to a plain float so
            # json.dump below does not depend on the scalar type returned.
            accuracy = float(trainer.validate_task(support_batch, query_batch))

            print(f"Task {i} accuracy: {accuracy:.4f}")
            task_metrics.append({
                "task_id": task.task_id,
                "accuracy": accuracy,
                "num_train": len(support_graphs),
                "num_test": len(query_graphs)
            })

        except Exception as e:
            # Best effort: one failing task must not abort the evaluation.
            print(f"Error testing task {i}: {e}")
            traceback.print_exc()

    if task_metrics:
        avg_accuracy = sum(m["accuracy"] for m in task_metrics) / len(task_metrics)
    else:
        avg_accuracy = 0.0
    print(f"Average test accuracy: {avg_accuracy:.4f}")

    results = {
        "avg_accuracy": avg_accuracy,
        "task_metrics": task_metrics
    }

    # Save metrics; the file name comes from the trainer class name with
    # "Trainer" removed and lower-cased.
    method_name = trainer.__class__.__name__.replace("Trainer", "").lower()
    metrics_path = os.path.join(METRICS_DIR, f"{method_name}_test_metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(results, f, indent=2)

    return results

# Main function
def main():
    """Meta-train a pretrained model with MAML and test it on evaluation tasks.

    Tasks from ``TRAIN_DIR`` are split 80/20 into training and validation
    sets by ``split_tasks``; every task from ``EVAL_DIR`` is used for testing.

    Returns:
        ``(trained_models, test_results)``, two dictionaries keyed by method
        name, or ``(None, None)`` if model setup, training or testing raised.
    """
    print("Loading tasks...")
    tasks = load_precomputed_tasks(TRAIN_DIR)
    test_tasks = load_precomputed_tasks(EVAL_DIR)

    train_tasks, val_tasks = split_tasks(tasks, val_ratio=0.2)
    print(f"Split into {len(train_tasks)} training, {len(val_tasks)} validation, and {len(test_tasks)} test tasks")

    try:
        # NOTE(review): this cell uses the NLM model/metrics directories but
        # builds a UnifiedReasoningModule - confirm that is intended.
        print("Initializing unified model...")
        model = UnifiedReasoningModule(input_dim=3, hidden_dim=128, output_dim=11, device=device)
        checkpoint_path = os.path.join(MODEL_DIR, "unified_reasoning_model_final.pt")
        model.load_complete_state(checkpoint_path)
        model.model = model.model.to(device)

        print("Training meta-learning models...")
        trained_models = train_meta_learning(
            model=model, train_tasks=train_tasks, val_tasks=val_tasks, method="maml"
        )

        print("Testing meta-learning models...")
        test_results = {
            name: test_meta_learning(fitted, test_tasks)
            for name, fitted in trained_models.items()
        }

        print("\nComparison of meta-learning methods:")
        for name, outcome in test_results.items():
            print(f"{name.upper()}: Average accuracy = {outcome['avg_accuracy']:.4f}")

        print("Meta-learning training and evaluation complete!")
        return trained_models, test_results

    except Exception as e:
        # Top-level boundary: report the failure and hand back empty results.
        print(f"Error in main function: {e}")
        traceback.print_exc()
        return None, None

# Run the main function when executed; the trainers and their test metrics
# stay available as `models` and `results` (both None if main() failed).
if __name__ == "__main__":
    models, results = main()

In [None]:
import os
import json
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from task4 import Task
from unified_module import UnifiedReasoningModule
from trainer import load_precomputed_tasks
import random
import traceback
from meta_learning import run_meta_learning, MAMLTrainer, ProtoNetTrainer
from sklearn.model_selection import KFold

# Seed PyTorch, NumPy and the stdlib `random` module with the same fixed value
# so that repeated runs of this cell start from identical random state.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Paths used by this cell:
#   MODEL_DIR            - pretrained checkpoint location, per-fold log
#                          directories and the saved best model
#   TRAIN_DIR / EVAL_DIR - precomputed tasks read by load_precomputed_tasks
#   METRICS_DIR          - where the k-fold metrics JSON is written
MODEL_DIR = "output/models/unified"
TRAIN_DIR = "precomputed_tasks/training_400"
EVAL_DIR = "precomputed_tasks/evaluation_400"
METRICS_DIR = "output/metrics/unified"
# exist_ok=True makes re-running the cell safe when the directories already exist.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(METRICS_DIR, exist_ok=True)

# Function to create k-fold splits
def create_kfold_splits(tasks, k=5, random_state=42):
    """Partition ``tasks`` into ``k`` cross-validation folds.

    The split is delegated to scikit-learn's shuffled ``KFold`` over the task
    positions, and the resulting index arrays are mapped back to task objects.

    Returns:
        List of ``k`` tuples ``(fold_train_tasks, fold_val_tasks)``.
    """
    splitter = KFold(n_splits=k, shuffle=True, random_state=random_state)
    positions = np.arange(len(tasks))

    return [
        ([tasks[i] for i in train_positions], [tasks[i] for i in held_out_positions])
        for train_positions, held_out_positions in splitter.split(positions)
    ]

# Train meta-learning models with k-fold cross-validation
def train_meta_learning_kfold(model_creator, tasks, test_tasks, method="maml", k=5, epochs=30):
    """
    Train meta-learning models using k-fold cross-validation.

    A fresh model from ``model_creator`` is trained on each fold via
    ``run_meta_learning``. The fold trainer with the highest validation
    accuracy is evaluated on ``test_tasks`` and checkpointed to ``MODEL_DIR``,
    and a summary is written to ``METRICS_DIR`` as JSON. A fold whose training
    or evaluation raises is reported and left out of the metrics.

    Args:
        model_creator: Function to create a fresh model instance
        tasks: List of all training tasks
        test_tasks: List of test tasks for final evaluation
        method: Which method to use ("maml" or "proto")
        k: Number of folds for cross-validation
        epochs: Number of training epochs per fold

    Returns:
        Tuple ``(metrics, best_model)``: the metrics dictionary that was
        written to disk, and the best fold trainer (None if every fold failed)
    """
    # Create k-fold splits
    splits = create_kfold_splits(tasks, k=k)

    # Store metrics for each fold
    fold_metrics = []
    best_model = None
    best_val_acc = 0.0

    # Train and evaluate on each fold
    for fold_idx, (fold_train_tasks, fold_val_tasks) in enumerate(splits):
        print(f"\n=== Training Fold {fold_idx+1}/{k} ===")
        print(f"Train tasks: {len(fold_train_tasks)}, Validation tasks: {len(fold_val_tasks)}")

        # Create a fresh model instance for this fold
        fold_model = model_creator()

        # Train the model on this fold
        try:
            fold_trainer = run_meta_learning(
                reasoning_module=fold_model,
                train_tasks=fold_train_tasks,
                val_tasks=fold_val_tasks,
                method=method,
                epochs=epochs,
                lr=0.001,
                weight_decay=1e-5,
                log_dir=os.path.join(MODEL_DIR, f"{method}_fold{fold_idx+1}")
            )

            # Evaluate on validation set
            val_accuracy = float(evaluate_tasks(fold_trainer, fold_val_tasks))

            # Save metrics for this fold
            fold_metrics.append({
                "fold": fold_idx+1,
                "val_accuracy": val_accuracy,
                "train_tasks": len(fold_train_tasks),
                "val_tasks": len(fold_val_tasks),
                "train_losses": fold_trainer.train_losses[-10:],  # Last 10 epochs
                "val_accuracies": fold_trainer.val_accuracies[-10:]  # Last 10 epochs
            })

            print(f"Fold {fold_idx+1} validation accuracy: {val_accuracy:.4f}")

            # Keep track of best model. The first successful fold always
            # qualifies, so a run whose accuracies are all 0.0 still yields a
            # model to test and save.
            if best_model is None or val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                best_model = fold_trainer
                print(f"New best model with validation accuracy: {val_accuracy:.4f}")

        except Exception as e:
            print(f"Error training fold {fold_idx+1}: {e}")
            traceback.print_exc()

    # Compute average metrics across folds (0.0 rather than NaN when every
    # fold failed, so the metrics file stays valid JSON).
    if fold_metrics:
        avg_val_accuracy = float(np.mean([m["val_accuracy"] for m in fold_metrics]))
    else:
        avg_val_accuracy = 0.0
    print(f"\nAverage validation accuracy across {k} folds: {avg_val_accuracy:.4f}")

    # Evaluate best model on test set. This must happen before the checkpoint
    # is written because the checkpoint records the test accuracy.
    if best_model is not None:
        test_accuracy = float(evaluate_tasks(best_model, test_tasks))
        print(f"Best model test accuracy: {test_accuracy:.4f}")
    else:
        test_accuracy = 0.0
        print("No valid model found for testing")

    # Save best model
    if best_model is not None:
        best_model_path = os.path.join(MODEL_DIR, f"{method}_best_model.pt")
        torch.save({
            'model_state_dict': best_model.reasoning_module.model.state_dict(),
            'model_type': type(best_model.reasoning_module).__name__,
            'meta_learning_method': method,
            'val_accuracy': float(best_val_acc),
            'test_accuracy': float(test_accuracy),
            'train_losses': best_model.train_losses,
            'train_accuracies': best_model.train_accuracies,
            'val_accuracies': best_model.val_accuracies,
        }, best_model_path)
        print(f"Best {method} model saved to {best_model_path}")

    # Save comprehensive metrics
    metrics = {
        "method": method,
        "avg_val_accuracy": float(avg_val_accuracy),
        "test_accuracy": float(test_accuracy),
        "fold_metrics": fold_metrics,
        "best_fold_val_accuracy": float(best_val_acc),
        "k": k
    }

    metrics_path = os.path.join(METRICS_DIR, f"{method}_kfold_metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)

    return metrics, best_model

# Evaluate model on a set of tasks
def evaluate_tasks(trainer, tasks):
    """
    Evaluate a meta-learning model on a set of tasks.

    Each task's ``train_graphs`` form the support set and its ``test_graphs``
    the query set. Tasks with an empty set are skipped; a task that raises is
    reported and left out of the average.

    Args:
        trainer: Meta-learning trainer instance exposing
            ``validate_task(support_batch, query_batch, task)``
        tasks: List of tasks to evaluate on

    Returns:
        Average accuracy across the evaluated tasks, or 0.0 if none could be
        evaluated
    """
    # This cell's import block does not bring `Batch` into scope, so import it
    # here rather than relying on another notebook cell having run first.
    from torch_geometric.data import Batch

    accuracies = []

    for task in tqdm(tasks, desc="Evaluating tasks"):
        try:
            # Prepare support and query sets
            support_graphs = task.train_graphs
            query_graphs = task.test_graphs

            if not support_graphs or not query_graphs:
                continue

            # Create batches
            support_batch = Batch.from_data_list(support_graphs).to(device)
            query_batch = Batch.from_data_list(query_graphs).to(device)

            # Evaluate with few-shot adaptation
            accuracy = trainer.validate_task(support_batch, query_batch, task)
            accuracies.append(accuracy)

        except Exception as e:
            # Best effort: one failing task must not abort the evaluation.
            print(f"Error evaluating task: {e}")

    # Calculate average accuracy
    avg_accuracy = np.mean(accuracies) if accuracies else 0.0
    return avg_accuracy

# Main function
def main():
    """Run 5-fold cross-validated MAML training on the precomputed tasks.

    Training tasks come from ``TRAIN_DIR`` and the final test tasks from
    ``EVAL_DIR``. Each fold starts from a fresh ``UnifiedReasoningModule``
    loaded with the pretrained checkpoint in ``MODEL_DIR``.

    Returns:
        Dictionary with keys "maml" and "proto", each mapping to a
        ``(best_trainer, metrics)`` tuple. "proto" is ``(None, None)`` while
        Prototypical Networks training is disabled. Returns None if training
        raised.
    """
    # Load tasks
    print("Loading tasks...")
    all_train_tasks = load_precomputed_tasks(TRAIN_DIR)
    test_tasks = load_precomputed_tasks(EVAL_DIR)
    print(f"Loaded {len(all_train_tasks)} training tasks and {len(test_tasks)} evaluation tasks")

    # Shuffle tasks for better randomization
    random.shuffle(all_train_tasks)

    try:
        # Function to create a fresh model instance
        def create_model():
            model = UnifiedReasoningModule(
                input_dim=3,
                hidden_dim=128,
                output_dim=11,
                device=device
            )
            # Load pre-trained weights
            model.load_complete_state(os.path.join(MODEL_DIR, "unified_reasoning_model_final.pt"))
            model.model = model.model.to(device)
            return model

        # Train MAML with 5-fold cross-validation
        print("\n=== Training MAML with 5-fold cross-validation ===")
        maml_metrics, best_maml = train_meta_learning_kfold(
            model_creator=create_model,
            tasks=all_train_tasks,
            test_tasks=test_tasks,
            method="maml",
            k=5,
            epochs=50
        )

        # Prototypical Networks training is disabled by default. These
        # placeholders keep the returned dictionary well-formed; the optional
        # block below overwrites them when it is enabled.
        best_proto, proto_metrics = None, None

        # # Optionally train Prototypical Networks with 5-fold cross-validation
        # print("\n=== Training Prototypical Networks with 5-fold cross-validation ===")
        # proto_metrics, best_proto = train_meta_learning_kfold(
        #     model_creator=create_model,
        #     tasks=all_train_tasks,
        #     test_tasks=test_tasks,
        #     method="proto",
        #     k=5,
        #     epochs=30
        # )

        # Compare results
        print("\n=== Final Results ===")
        print(f"MAML: Average validation accuracy = {maml_metrics['avg_val_accuracy']:.4f}, Test accuracy = {maml_metrics['test_accuracy']:.4f}")
        # print(f"ProtoNet: Average validation accuracy = {proto_metrics['avg_val_accuracy']:.4f}, Test accuracy = {proto_metrics['test_accuracy']:.4f}")

        # Return best models and metrics
        return {
            "maml": (best_maml, maml_metrics),
            "proto": (best_proto, proto_metrics)
        }

    except Exception as e:
        # Top-level boundary: report the failure and return None.
        print(f"Error in main function: {e}")
        traceback.print_exc()
        return None

# Run the main function when executed; `results` holds the dictionary
# returned by main(), or None if it failed.
if __name__ == "__main__":
    results = main()

In [None]:
import os
import json
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from task4 import Task, Blackboard
from unified_module import UnifiedReasoningModule
from nlm_module import NLMReasoningModule
from trainer import load_precomputed_tasks
import random
import traceback
from meta_learning import run_meta_learning, MAMLTrainer, ProtoNetTrainer
from torch_geometric.data import Batch

# Seed PyTorch, NumPy and the stdlib `random` module with the same fixed value
# so that repeated runs of this cell start from identical random state.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Paths used by this cell:
#   MODEL_DIR / METRICS_DIR - roots under which train_kfold_meta_learning
#                             creates one sub-directory per model type
#   TRAIN_DIR / EVAL_DIR    - precomputed task directories
MODEL_DIR = "output/models"
TRAIN_DIR = "precomputed_tasks/training_400"
EVAL_DIR = "precomputed_tasks/evaluation_400"
METRICS_DIR = "output/metrics"
# exist_ok=True makes re-running the cell safe when the directories already exist.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(METRICS_DIR, exist_ok=True)

# Create k-fold split of tasks
def create_kfold_splits(tasks, k=5):
    """Create k-fold splits of tasks.

    A shuffled copy of ``tasks`` is cut into ``k`` consecutive validation
    slices of ``len(tasks) // k`` items; the last slice also takes the
    remainder. The caller's list is not modified. Shuffling uses the global
    ``random`` generator and is therefore reproducible after ``random.seed``.

    Args:
        tasks: Sequence of tasks to split.
        k: Number of folds.

    Returns:
        List of ``k`` tuples ``(train_tasks, val_tasks)``.
    """
    shuffled = list(tasks)
    random.shuffle(shuffled)
    fold_size = len(shuffled) // k
    folds = []

    for i in range(k):
        start_idx = i * fold_size
        # The final fold runs to the end so no task is dropped.
        end_idx = (i + 1) * fold_size if i < k - 1 else len(shuffled)
        val_tasks = shuffled[start_idx:end_idx]
        train_tasks = shuffled[:start_idx] + shuffled[end_idx:]
        folds.append((train_tasks, val_tasks))

    return folds

# Initialize a new model for each fold
def initialize_model(model_type, model_path, input_dim=3, hidden_dim=128, output_dim=11):
    """Build a reasoning module of the requested type and load saved weights.

    Args:
        model_type: "unified" for ``UnifiedReasoningModule`` or "nlm" for
            ``NLMReasoningModule``.
        model_path: Checkpoint passed to the module's ``load_complete_state``.
        input_dim, hidden_dim, output_dim: Forwarded to the module constructor.

    Returns:
        The module, with its inner ``model`` moved to ``device``.

    Raises:
        ValueError: If ``model_type`` is neither "unified" nor "nlm".
    """
    # Guard clause: reject unsupported types before constructing anything.
    if model_type not in ("unified", "nlm"):
        raise ValueError(f"Unknown model type: {model_type}")

    module_cls = UnifiedReasoningModule if model_type == "unified" else NLMReasoningModule
    model = module_cls(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        device=device
    )
    model.load_complete_state(model_path)
    print("Model weights loaded successfully")

    model.model = model.model.to(device)
    return model

# Bring graph node features into the layout used for the NLM model
def _coerce_graphs_for_nlm(tasks):
    """Rewrite ``graph.x`` (and possibly ``graph.pos``) in place for every graph in ``tasks``.

    For each graph in ``task.train_graphs + task.test_graphs`` that has an
    ``x`` attribute:

    * a 2-D ``x`` with exactly 11 columns is replaced by its per-row argmax
      (presumably a one-hot colour encoding -- confirm against the data);
    * ``x`` is cast to ``long`` and a 1-D ``x`` becomes a single column;
    * if the graph has no ``pos`` attribute and ``x`` has at least 3 columns,
      columns 1 and 2 of ``x`` are stored as float ``pos``.

    Args:
        tasks: Iterable of task objects exposing ``train_graphs`` and
            ``test_graphs`` lists.
    """
    for task in tasks:
        for graph in task.train_graphs + task.test_graphs:
            if not hasattr(graph, 'x'):
                continue

            if graph.x.dim() == 2 and graph.x.size(1) == 11:
                graph.x = graph.x.argmax(dim=1)

            # Convert to long type and reshape if needed
            graph.x = graph.x.long()
            if graph.x.dim() == 1:
                graph.x = graph.x.unsqueeze(1)

            # NOTE(review): if the graph type always exposes a `pos` attribute
            # (possibly holding None), this branch never runs -- confirm
            # against torch_geometric's Data class.
            if not hasattr(graph, 'pos') and graph.x.size(1) >= 3:
                graph.pos = graph.x[:, 1:3].float()


# Train meta-learning models with k-fold cross-validation
def train_kfold_meta_learning(model_type, model_path, train_dir, test_dir, k=5, method="maml"):
    """
    Train meta-learning models using k-fold cross-validation.

    Fold membership is decided once, by shuffled task index, and the tasks are
    reloaded from ``train_dir`` for every fold. This assumes that
    ``load_precomputed_tasks`` returns the tasks in the same order on every
    call -- confirm against its implementation.

    Side effects: creates ``MODEL_DIR/<model_type>/<method>/foldN`` log
    directories, writes a ``foldN_best.pt`` checkpoint whenever a fold beats
    the best validation accuracy seen so far, and writes
    ``METRICS_DIR/<model_type>/<method>_metrics.json``.

    Args:
        model_type: Type of model to use ("unified" or "nlm")
        model_path: Checkpoint passed to ``initialize_model`` for every fold
        train_dir: Directory containing training tasks
        test_dir: Directory containing test tasks
        k: Number of folds (must be at least 1)
        method: Which meta-learning method to use ("maml" or "proto")

    Returns:
        Dictionary with keys "model_type", "method", "best_model" (the trainer
        of the best fold, or None if no fold scored above 0.0),
        "avg_val_accuracy", "test_accuracy" and "fold_results".

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    # Create model-specific directories
    model_dir = os.path.join(MODEL_DIR, model_type)
    metrics_dir = os.path.join(METRICS_DIR, model_type)
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)

    # Log directory for this specific model-method combination
    method_log_dir = os.path.join(model_dir, method)
    os.makedirs(method_log_dir, exist_ok=True)

    # Only the task count is needed here; the loaded list is not kept alive
    num_tasks = len(load_precomputed_tasks(train_dir))

    # Create k-fold splits (just the task indices, not the actual task objects)
    all_indices = list(range(num_tasks))
    random.shuffle(all_indices)

    fold_size = num_tasks // k
    fold_indices = []

    for i in range(k):
        start_idx = i * fold_size
        # The last fold also takes the remainder
        end_idx = (i + 1) * fold_size if i < k - 1 else num_tasks
        val_indices = all_indices[start_idx:end_idx]
        # Indices are unique, so everything outside the validation slice is
        # exactly the training set (same order as filtering, without O(n^2))
        train_indices = all_indices[:start_idx] + all_indices[end_idx:]
        fold_indices.append((train_indices, val_indices))

    fold_results = []
    best_model = None
    best_val_accuracy = 0.0

    # Train on each fold
    for fold_idx, (train_indices, val_indices) in enumerate(fold_indices):
        print(f"\n--- Training {model_type.upper()} with {method.upper()} on Fold {fold_idx+1}/{k} ---")

        # Reload all tasks fresh for this fold: the NLM preprocessing below
        # mutates graphs in place, so each fold starts from unmodified data
        all_tasks = load_precomputed_tasks(train_dir)

        # Split into train and validation based on indices
        train_tasks = [all_tasks[i] for i in train_indices]
        val_tasks = [all_tasks[i] for i in val_indices]

        print(f"Train tasks: {len(train_tasks)}, Validation tasks: {len(val_tasks)}")

        # Initialize a model for this fold from the saved checkpoint
        model = initialize_model(model_type, model_path)

        # Create a specific log directory for this fold
        fold_log_dir = os.path.join(method_log_dir, f"fold{fold_idx+1}")
        os.makedirs(fold_log_dir, exist_ok=True)

        # Before running meta-learning, bring the graphs into the NLM layout.
        # Each task list is processed exactly once, and validation tasks are
        # handled even when the training list is empty.
        if model_type == "nlm" and hasattr(model, '_prepare_training_data'):
            _coerce_graphs_for_nlm(train_tasks)
            _coerce_graphs_for_nlm(val_tasks)

        # Train on this fold
        trainer = run_meta_learning(
            reasoning_module=model,
            train_tasks=train_tasks,
            val_tasks=val_tasks,
            method=method,
            epochs=50,
            lr=0.001,
            weight_decay=1e-5,
            log_dir=fold_log_dir
        )

        # Evaluate on validation set; float() keeps the value JSON-serializable
        val_accuracy = float(evaluate_model(trainer, val_tasks, method=method))
        print(f"Fold {fold_idx+1} validation accuracy: {val_accuracy:.4f}")

        # Save fold results
        fold_results.append({
            "fold": fold_idx + 1,
            "model_type": model_type,
            "method": method,
            "val_accuracy": val_accuracy,
            "num_train": len(train_tasks),
            "num_val": len(val_tasks)
        })

        # Keep track of best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model = trainer
            # Save the best model for this fold
            best_model_path = os.path.join(method_log_dir, f"fold{fold_idx+1}_best.pt")
            torch.save({
                "model_type": model_type,
                "method": method,
                "fold": fold_idx + 1,
                "accuracy": val_accuracy,
                "model_state_dict": model.model.state_dict(),
                "trainer_type": type(trainer).__name__
            }, best_model_path)
            print(f"New best model with validation accuracy: {val_accuracy:.4f}, saved to {best_model_path}")

        # Clean up memory
        del train_tasks
        del val_tasks
        del all_tasks
        torch.cuda.empty_cache()  # Clear CUDA cache if using GPU

    # Calculate average validation accuracy across folds
    avg_val_accuracy = float(np.mean([fold["val_accuracy"] for fold in fold_results]))
    print(f"\n{model_type.upper()} with {method.upper()}: Average validation accuracy across {k} folds: {avg_val_accuracy:.4f}")

    # Load fresh test tasks and evaluate the best model
    test_tasks = load_precomputed_tasks(test_dir)
    print(f"Evaluating best {model_type.upper()} model with {method.upper()} on {len(test_tasks)} test tasks...")

    if best_model is not None:
        test_accuracy = float(evaluate_model(best_model, test_tasks, method=method))
        print(f"Best {model_type.upper()} model with {method.upper()} test accuracy: {test_accuracy:.4f}")
    else:
        test_accuracy = 0.0
        print(f"No best {model_type.upper()} model found")

    # Save k-fold metrics
    metrics_path = os.path.join(metrics_dir, f"{method}_metrics.json")
    metrics_data = {
        "model_type": model_type,
        "method": method,
        "avg_val_accuracy": avg_val_accuracy,
        "test_accuracy": test_accuracy,
        "fold_results": fold_results
    }

    with open(metrics_path, "w") as f:
        json.dump(metrics_data, f, indent=2)
    print(f"Metrics saved to {metrics_path}")

    return {
        "model_type": model_type,
        "method": method,
        "best_model": best_model,
        "avg_val_accuracy": avg_val_accuracy,
        "test_accuracy": test_accuracy,
        "fold_results": fold_results
    }

# Evaluate model on a set of tasks
def evaluate_model(trainer, tasks, method="maml"):
    """
    Evaluate a meta-learning model on a set of tasks.

    Each task's ``train_graphs`` form the support set and its ``test_graphs``
    the query set; both are batched, moved to the module-level ``device`` and
    passed to ``trainer.validate_task``. Tasks with an empty support or query
    set are skipped, and a task whose evaluation raises is reported on stdout
    and left out of the average (best-effort evaluation).

    Args:
        trainer: Meta-learning trainer instance exposing
            ``validate_task(support_batch, query_batch, task)``
        tasks: List of tasks to evaluate on
        method: Meta-learning method name (accepted for call-site symmetry;
            not used by this function)

    Returns:
        Average accuracy (a Python float) across the tasks that were
        evaluated successfully, or 0.0 if there were none
    """
    total_accuracy = 0.0
    valid_tasks = 0

    for task in tasks:
        try:
            # Prepare support and query sets
            support_graphs = task.train_graphs
            query_graphs = task.test_graphs

            if not support_graphs or not query_graphs:
                continue

            # Create batches
            support_batch = Batch.from_data_list(support_graphs).to(device)
            query_batch = Batch.from_data_list(query_graphs).to(device)

            # Evaluate with few-shot adaptation
            accuracy = trainer.validate_task(support_batch, query_batch, task)
            # Accumulate as a plain float so the average stays a Python number
            total_accuracy += float(accuracy)
            valid_tasks += 1

        except Exception as e:
            # The handler must not raise itself: a malformed task may be
            # missing `task_id` as well, so fall back to a placeholder.
            task_label = getattr(task, "task_id", "<unknown>")
            print(f"Error evaluating task {task_label}: {e}")

    return total_accuracy / valid_tasks if valid_tasks > 0 else 0.0

# Plot k-fold results
def plot_kfold_results(fold_results, model_type, method, save_path=None):
    """Draw a bar chart with the validation accuracy of every fold.

    Args:
        fold_results: List of dicts, each holding a "fold" number and a
            "val_accuracy" value.
        model_type: Model name; upper-cased for the plot title.
        method: Meta-learning method name; upper-cased for the plot title.
        save_path: If truthy, the figure is also written to this path.
    """
    plt.figure(figsize=(10, 6))

    # One bar per fold, positioned at the fold number
    fold_numbers = [entry["fold"] for entry in fold_results]
    accuracies = [entry["val_accuracy"] for entry in fold_results]
    plt.bar(fold_numbers, accuracies, alpha=0.7)

    # Dashed horizontal line marking the mean over all folds
    mean_accuracy = np.mean(accuracies)
    plt.axhline(
        y=mean_accuracy,
        color='r',
        linestyle='--',
        label=f'Average: {mean_accuracy:.4f}',
    )

    # Axis labels, title and cosmetics
    plt.xlabel('Fold')
    plt.ylabel('Validation Accuracy')
    plt.title(f'{model_type.upper()} with {method.upper()} K-Fold Cross-Validation Results')
    plt.xticks(fold_numbers)
    plt.ylim([0, 1.0])
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Persist the figure before showing it, when a target path was given
    if save_path:
        plt.savefig(save_path)
        print(f"Plot saved to {save_path}")

    plt.show()

# Compare results across different models and methods
def compare_results(results, save_path=None):
    """
    Plot validation and test accuracy side by side for every experiment.

    Args:
        results: Dictionary mapping (model_type, method) to result dicts that
            contain "avg_val_accuracy" and "test_accuracy"
        save_path: Path to save the comparison plot (skipped when falsy)
    """
    plt.figure(figsize=(12, 8))

    # One (label, validation accuracy, test accuracy) row per experiment
    rows = [
        (f"{model_type}-{method}", outcome["avg_val_accuracy"], outcome["test_accuracy"])
        for (model_type, method), outcome in results.items()
    ]
    labels = [row[0] for row in rows]
    val_scores = [row[1] for row in rows]
    test_scores = [row[2] for row in rows]

    # Each experiment gets two bars, shifted half a bar width to either side
    positions = np.arange(len(labels))
    bar_width = 0.35
    half_width = bar_width / 2
    plt.bar(positions - half_width, val_scores, bar_width, label='Validation Accuracy', color='skyblue')
    plt.bar(positions + half_width, test_scores, bar_width, label='Test Accuracy', color='lightcoral')

    # Axis labels, title and cosmetics
    plt.xlabel('Model-Method')
    plt.ylabel('Accuracy')
    plt.title('Comparison of Models and Meta-Learning Methods')
    plt.xticks(positions, labels, rotation=45, ha='right')
    plt.ylim([0, 1.0])
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    # Persist the figure before showing it, when a target path was given
    if save_path:
        plt.savefig(save_path)
        print(f"Comparison plot saved to {save_path}")

    plt.show()

# Main function
def main():
    """Run the k-fold MAML experiments for both model types and summarise them.

    For every configured experiment this trains with
    ``train_kfold_meta_learning`` (5 folds), plots the per-fold results, and
    finally draws the cross-experiment comparison and prints a summary.

    Returns:
        Dictionary mapping (model_type, method) to the result dict of that
        experiment, or None if any step raised (the error and traceback are
        printed).
    """
    try:
        print("Starting k-fold cross-validation...")

        # (model_type, method, pretrained checkpoint) of every experiment.
        # Prototypical-network ("proto") runs are currently disabled; the
        # earlier commented-out calls for them did not pass the required
        # model_path argument.
        experiments = (
            ("nlm", "maml", "output/models/nlm/nlm_reasoning_module_final.pt"),
            ("unified", "maml", "output/models/unified/unified_reasoning_model_final.pt"),
        )

        all_results = {}
        for model_type, method, checkpoint in experiments:
            # Run k-fold training for this model/method pair
            outcome = train_kfold_meta_learning(
                model_type=model_type,
                model_path=checkpoint,
                train_dir=TRAIN_DIR,
                test_dir=EVAL_DIR,
                k=5,
                method=method,
            )
            all_results[(model_type, method)] = outcome

            # Plot the per-fold results next to the metrics of this model
            plot_kfold_results(
                outcome["fold_results"],
                model_type,
                method,
                save_path=os.path.join(METRICS_DIR, model_type, f"{method}_results.png"),
            )

        # Compare all results
        compare_results(all_results, save_path=os.path.join(METRICS_DIR, "all_models_comparison.png"))

        # Print summary of results
        print("\n----- RESULTS SUMMARY -----")
        for (kind, algo), summary in all_results.items():
            print(f"{kind.upper()} with {algo.upper()}: Val Acc = {summary['avg_val_accuracy']:.4f}, Test Acc = {summary['test_accuracy']:.4f}")

        print("\nK-fold meta-learning training and evaluation complete!")
        return all_results

    except Exception as e:
        # Top-level boundary: report the failure instead of propagating it
        print(f"Error in main function: {e}")
        traceback.print_exc()
        return None

# Run the main function when executed
if __name__ == "__main__":
    # Keep the returned mapping (None if main() caught an error) for later inspection
    results = main()