In [None]:
# demo.ipynb

# Import necessary libraries
import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
from scripts.train import train_model  # Import the train_model function from train.py
from scripts.train_cyclemorph import train_model as train_cyclemorph  # Import the train_model function from train_cyclemorph.py

# Define a function to simulate the main.py logic
def run_training(t1_dir, dwi_dir, model_label, epochs, img_size, lr, batch_size, cont_training,
                 *, gpu_id=0, seed=None):
    """
    Simulate the main.py training entry point and plot synthetic training curves.

    No real training is performed. The data and hyper-parameter arguments are
    accepted only so the signature mirrors main.py's command-line interface.
    The curves are generated from simple exponential trends plus Gaussian noise.

    Args:
        t1_dir: T1 image directory (unused by the simulation).
        dwi_dir: DWI image directory (unused by the simulation).
        model_label: Model name. "cyclemorph" (case-insensitive) reports the
            CycleMorph training routine; anything else reports the generic one.
            Used here only for log messages and the plot title.
        epochs: Number of epochs to simulate.
        img_size, lr, batch_size, cont_training: Accepted for interface parity
            with main.py; unused by the simulation.
        gpu_id: CUDA device index to select when a GPU is available.
        seed: Optional seed for reproducible curves. When None, NumPy's global
            random state is used, so a prior ``np.random.seed`` still applies.

    Raises:
        ValueError: If a GPU is available but ``gpu_id`` is not a valid device index.
    """
    # GPU configuration: report the visible devices, but fall back to CPU instead
    # of crashing when CUDA is absent (torch.cuda.set_device raises without a GPU).
    gpu_count = torch.cuda.device_count()
    print('Number of GPU:', gpu_count)
    for gpu_idx in range(gpu_count):
        print('     GPU #' + str(gpu_idx) + ': ' + torch.cuda.get_device_name(gpu_idx))
    gpu_available = torch.cuda.is_available()
    if gpu_available:
        if not 0 <= gpu_id < gpu_count:
            raise ValueError(f"gpu_id={gpu_id} is out of range for {gpu_count} visible GPU(s)")
        torch.cuda.set_device(gpu_id)
        print('Currently using:', torch.cuda.get_device_name(gpu_id))
    else:
        print('Currently using: CPU')
    print('If the GPU is available?', gpu_available)

    # Report which training routine main.py would dispatch to for this model.
    if model_label.lower() == "cyclemorph":
        print("Using train_cyclemorph for CycleMorph model.")
    else:
        print(f"Using train_model for {model_label} model.")

    # A seeded Generator gives reproducible curves. Without a seed, use the
    # np.random module itself, which exposes the same normal(loc, scale) call
    # on the global state (preserving the previous behaviour).
    rng = np.random if seed is None else np.random.default_rng(seed)

    print("\nStarting training...")
    train_losses = []
    val_dice_scores = []

    for epoch in range(epochs):
        # Synthetic training loss: exponential decay plus noise.
        train_loss = np.exp(-0.1 * epoch) + rng.normal(0, 0.02)
        train_losses.append(train_loss)

        # Synthetic validation Dice: saturating rise plus noise. A Dice
        # coefficient is bounded to [0, 1], so clip the noisy value into range.
        val_dice = np.clip(1 - np.exp(-0.05 * epoch) + rng.normal(0, 0.02), 0.0, 1.0)
        val_dice_scores.append(val_dice)

        print(f"Epoch {epoch + 1}/{epochs}: Train Loss = {train_loss:.4f}, Val Dice = {val_dice:.4f}")

    # Plot both curves against 1-based epoch numbers on a shared axis.
    epoch_axis = range(1, epochs + 1)
    plt.figure(figsize=(10, 5))
    plt.plot(epoch_axis, train_losses, label="Training Loss", color="blue")
    plt.plot(epoch_axis, val_dice_scores, label="Validation Dice Score", color="orange")
    plt.xlabel("Epoch")
    plt.ylabel("Value")
    plt.title(f"Training Curves for {model_label}")
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
# Mimic the command-line arguments main.py would receive, packed into an argparse.Namespace.
args = argparse.Namespace(
    t1_dir="./data/T1/",        # Replace with your T1 directory
    dwi_dir="./data/DWI/",      # Replace with your DWI directory
    model_label="CycleMorph",   # Try "VoxelMorph" or "CycleMorph"
    epochs=50,                  # Number of epochs to simulate
    img_size="128,128,128",
    lr=0.0002,
    batch_size=8,
    cont_training=True,
)

# Every Namespace attribute is named exactly like a run_training parameter,
# so the whole namespace can be forwarded as keyword arguments.
run_training(**vars(args))