# Cell 1: Install & Imports

In [None]:
# Install requirements if needed (uncomment if running in a fresh environment)
# !pip install torch numpy matplotlib scipy

from ArV_NeuroSynth import ArV_NeuroSynth
from bci_mat_loader import load_bci_mat_all_runs
import torch
import numpy as np
import matplotlib.pyplot as plt

# Cell 2: Data Loading

In [None]:
# Read every run stored in the BCI .mat file; the loader hands back a DataLoader
# (batch_size=32 segments) together with the full EEG tensor it was built from.
dataloader, eeg_tensor = load_bci_mat_all_runs(
    "A01T.mat",
    num_channels=4,
    segment_length=256,
    batch_size=32,
)
print("EEG tensor shape:", eeg_tensor.shape)

# Cell 3: Model Initialization

In [None]:
# Seed torch's global RNG before the model is built so weight initialisation and any
# later torch-driven randomness (DataLoader shuffling, if enabled; VAE sampling) repeat
# across runs. torch.manual_seed covers the CPU and all CUDA devices; bit-exact GPU
# runs additionally need torch.use_deterministic_algorithms(True).
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ArV_NeuroSynth(input_channels=4, seq_length=256, latent_dim=32, noise_dim=50).to(device)
print("Model initialized on device:", device)

# Cell 4: Training Loop for ArV_NeuroSynth

In [None]:
# Hyperparameters
epochs = 20
learning_rate = 1e-3
beta = 0.01  # weight of the KL term relative to reconstruction; adjust as needed

# Optimizer (Adam is common for VAE/GANs)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Reconstruction loss (MSE). The KL term is added in the loop; adversarial terms could be added similarly.
recon_loss_fn = torch.nn.MSELoss()

loss_history = []

model.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    # The trailing comma unpacks 1-tuples: each batch is assumed to hold a single tensor
    # (e.g. a TensorDataset over the EEG segments) -- TODO confirm in bci_mat_loader.
    for x_batch, in dataloader:
        x_batch = x_batch.to(device)
        # Forward pass (VAE mode); the 4-tuple is unpacked as
        # (reconstruction, mean, log-variance, latent sample) -- latent sample unused here.
        x_recon, mu, logvar, _ = model(x_batch, mode='vae')
        recon_loss = recon_loss_fn(x_recon, x_batch)
        # Gaussian KL to N(0, I), averaged (not summed) over all elements of mu/logvar
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        loss = recon_loss + beta * kl_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch figure is a per-sample mean even when the last batch is short
        epoch_loss += loss.item() * x_batch.size(0)

    avg_loss = epoch_loss / len(dataloader.dataset)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.6f}")

# Save the final weights and the per-epoch loss history. This uses torch.save directly:
# the save_training_results helper is defined in a later cell, so calling it here raised
# NameError on a fresh Restart & Run All.
torch.save(model.state_dict(), "arv_neurosynth_trained.pth")
print("Model saved as arv_neurosynth_trained.pth")
np.save("arv_neurosynth_loss_history.npy", np.asarray(loss_history))
print("Loss history saved as arv_neurosynth_loss_history.npy")

# Plot loss curve
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(loss_history, label="Training Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss Curve")
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()

# Cell 5: Save & Load Utilities

In [None]:
def save_training_results(model, x_recon=None, plot=False, prefix="arv_neurosynth"):
    """Save model weights and, optionally, a reconstruction array and its plot.

    Files written (relative to the working directory unless ``prefix`` includes a path):
      * ``{prefix}_trained.pth`` -- ``model.state_dict()`` via ``torch.save``.
      * ``{prefix}_reconstructed_eeg.npy`` -- ``x_recon`` as a NumPy array, if given.
      * ``{prefix}_reconstruction_plot.png`` -- one trace per channel of ``x_recon[0]``,
        written only when ``plot`` is True and ``x_recon`` is given.

    Args:
        model: ``torch.nn.Module`` whose ``state_dict`` is saved.
        x_recon: Optional reconstruction tensor. For plotting it is indexed as
            ``[sample, channel, time]`` (assumed layout -- TODO confirm against the model).
            It may live on any device and may require grad.
        plot: Whether to also save a PNG of the first sample.
        prefix: Filename prefix shared by every output file.
    """
    weights_path = f"{prefix}_trained.pth"
    torch.save(model.state_dict(), weights_path)
    print(f"Model saved as {weights_path}")

    if x_recon is None:
        return

    # detach() first: Tensor.numpy() raises on tensors that require grad, which is what
    # callers get when they run the model outside torch.no_grad().
    recon_np = x_recon.detach().cpu().numpy()
    recon_path = f"{prefix}_reconstructed_eeg.npy"
    np.save(recon_path, recon_np)
    print(f"Reconstructed EEG saved as {recon_path}")

    if plot:
        fig, ax = plt.subplots(figsize=(12, 4))
        for ch in range(recon_np.shape[1]):
            ax.plot(recon_np[0, ch], label=f"Ch{ch}")
        ax.set_xlabel("Sample index")
        ax.legend()
        ax.set_title("Reconstructed EEG")
        ax.grid(True)
        fig.tight_layout()
        plot_path = f"{prefix}_reconstruction_plot.png"
        fig.savefig(plot_path)
        # Close this figure explicitly so repeated calls don't accumulate open figures
        plt.close(fig)
        print(f"Plot saved as {plot_path}")

def load_trained_model(model_class, prefix="arv_neurosynth", *, map_location="cpu", **model_kwargs):
    """Rebuild a model and load weights previously written to ``{prefix}_trained.pth``.

    Args:
        model_class: Callable (typically an ``nn.Module`` subclass) used to construct the model.
        prefix: Filename prefix; weights are read from ``{prefix}_trained.pth``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a checkpoint
            saved from a CUDA model also loads on a machine without a GPU.
        **model_kwargs: Constructor arguments; must match those used when training.

    Returns:
        The model in eval mode. It stays wherever ``model_class(**model_kwargs)`` placed it
        (the CPU for a standard ``nn.Module``); call ``.to(device)`` to move it.
    """
    model = model_class(**model_kwargs)
    weights_path = f"{prefix}_trained.pth"
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    state_dict = torch.load(weights_path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Model loaded from {weights_path}")
    return model

# Cell 6: Save Model and Results

In [None]:
# Example: Save after training.
# Reconstruct one real segment and save it, with a plot, alongside the weights.
# eval() turns off training-only behaviour (dropout, batch-norm statistic updates, if the
# model has any), so the saved reconstruction reflects inference. The training cell calls
# model.train() again before it trains.
model.eval()
with torch.no_grad():
    x_real = next(iter(dataloader))[0][:1].to(device)
    x_recon, _, _, _ = model(x_real, mode='vae')
save_training_results(model, x_recon, plot=True)

# Cell 7: Load Model Later

In [None]:
# Example: Load the trained model for inference or further training.
# load_trained_model builds the model with ArV_NeuroSynth(**kwargs), which for a standard
# nn.Module leaves it on the CPU. Move it to `device` so it can consume tensors such as
# x_real, which were moved to `device`. The constructor arguments must match those used
# in training.
model_loaded = load_trained_model(
    ArV_NeuroSynth,
    input_channels=4, seq_length=256, latent_dim=32, noise_dim=50,
).to(device)

# Cell 8: Visualize Results

In [None]:
# Visualize real vs reconstructed EEG for the single segment reconstructed earlier.
# Each channel's real (solid) and reconstructed (dashed) traces share one colour so the
# pairs are easy to match. Tensors are converted to NumPy once, outside the loop.
x_real_np = x_real.detach().cpu().numpy()
x_recon_np = x_recon.detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(12, 4))
for ch in range(x_real_np.shape[1]):
    (real_line,) = ax.plot(x_real_np[0, ch], label=f"Real Ch{ch}")
    ax.plot(x_recon_np[0, ch], '--', color=real_line.get_color(), label=f"Reconstructed Ch{ch}")
ax.set_xlabel("Sample index")
ax.set_ylabel("Amplitude")
ax.legend()
ax.set_title("Real vs Reconstructed EEG")
ax.grid(True)
fig.tight_layout()
plt.show()