In [14]:
import scipy.io
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import os

###############################################################################
# 1) LOAD DATA
###############################################################################
file_path = "L23_neuron_20210228_Y54_Z320_test.mat"  # Update with the correct path
mat_data = scipy.io.loadmat(file_path)

# Pull the four trial arrays out of the loaded .mat dictionary as NumPy arrays.
# Shapes as noted by the original author for this recording:
#   eigenface_evoked : (500, 1000, 4)   face features x trials x stimuli
#   eigenface_isi    : (500, 1000)      face features x trials
#   dff_evoked       : (229, 1000, 4)   neurons x trials x stimuli
#   dff_isi          : (229, 1000)      neurons x trials
eigenface_evoked, eigenface_isi, dff_evoked, dff_isi = (
    np.array(mat_data[key])
    for key in (
        "Eigenface_0_trials_evoked",
        "Eigenface_0_trials_isi",
        "dFF0_trials_evoked",
        "dFF0_trials_isi",
    )
)

###############################################################################
# 2) DATASET (PER-NEURON NORMALIZATION)
###############################################################################
class NeuralDataset(Dataset):
    """
    Paired (face features + stimulus code) -> neural response samples.

    Each sample:
      - x: face features concatenated with a one-hot stimulus code,
           shape (n_face + n_stim,)   (500 + 4 = 504 for the data in this file)
      - y: neural targets, shape (n_neurons,)   (229 for the data in this file)

    Sample order: every evoked trial of stimulus 0, then stimulus 1, ...,
    followed by the ISI trials, whose stimulus code is all zeros.

    PER-NEURON normalization (z-scoring) is optionally applied to the
    targets (y). Inputs (x) are left untouched.

    Args:
        eigenface_evoked: array-like, (n_face, n_trials, n_stim).
        dff_evoked:       array-like, (n_neurons, n_trials, n_stim).
        eigenface_isi:    array-like, (n_face, n_isi_trials).
        dff_isi:          array-like, (n_neurons, n_isi_trials).
        apply_norm:       if True, z-score each neuron over all samples.

    Attributes:
        samples_x: float32 tensor (N, n_face + n_stim).
        samples_y: float32 tensor (N, n_neurons), normalized if apply_norm.
        means, stds: float32 tensors (n_neurons,) used for the normalization
            (zeros / ones when apply_norm is False).

    Raises:
        ValueError: if the evoked arrays are not 3-D or the face and neural
            arrays disagree on the number of trials / stimuli.
    """

    def __init__(self, eigenface_evoked, dff_evoked,
                 eigenface_isi, dff_isi,
                 apply_norm=True):
        super().__init__()

        # Convert once up front instead of once per column.
        eigenface_evoked = np.asarray(eigenface_evoked, dtype=np.float32)
        dff_evoked = np.asarray(dff_evoked, dtype=np.float32)
        eigenface_isi = np.asarray(eigenface_isi, dtype=np.float32)
        dff_isi = np.asarray(dff_isi, dtype=np.float32)

        if eigenface_evoked.ndim != 3 or dff_evoked.ndim != 3:
            raise ValueError(
                "evoked arrays must be 3-D (features, trials, stimuli); got "
                f"{eigenface_evoked.shape} and {dff_evoked.shape}"
            )
        if eigenface_evoked.shape[1:] != dff_evoked.shape[1:]:
            raise ValueError(
                "evoked face/neural arrays disagree on (trials, stimuli): "
                f"{eigenface_evoked.shape[1:]} vs {dff_evoked.shape[1:]}"
            )
        if eigenface_isi.shape[1] != dff_isi.shape[1]:
            raise ValueError(
                "ISI face/neural arrays disagree on the number of trials: "
                f"{eigenface_isi.shape[1]} vs {dff_isi.shape[1]}"
            )

        # Number of stimulus conditions comes from the data, not a constant.
        n_stim = eigenface_evoked.shape[2]

        x_parts = []
        y_parts = []

        # ---------- Evoked -------------
        for c in range(n_stim):
            face_trials = eigenface_evoked[:, :, c].T      # (n_trials, n_face)
            neural_trials = dff_evoked[:, :, c].T          # (n_trials, n_neurons)
            stim_code = np.zeros((face_trials.shape[0], n_stim), dtype=np.float32)
            stim_code[:, c] = 1.0
            x_parts.append(np.concatenate([face_trials, stim_code], axis=1))
            y_parts.append(neural_trials)

        # ---------- ISI (no stim) -------------
        face_trials = eigenface_isi.T                      # (n_isi_trials, n_face)
        no_stim = np.zeros((face_trials.shape[0], n_stim), dtype=np.float32)
        x_parts.append(np.concatenate([face_trials, no_stim], axis=1))
        y_parts.append(dff_isi.T)

        # Convert to Tensors
        self.samples_x = torch.tensor(np.concatenate(x_parts, axis=0))  # (N, n_face + n_stim)
        self.samples_y = torch.tensor(np.concatenate(y_parts, axis=0))  # (N, n_neurons)

        n_neurons = self.samples_y.shape[1]
        if apply_norm:
            # Per-neuron statistics over all samples.
            self.means = self.samples_y.mean(dim=0)    # (n_neurons,)
            stds = self.samples_y.std(dim=0)           # (n_neurons,)
            # Avoid division by zero. Written as ">= keeps" so that a NaN std
            # (e.g. from a single-sample dataset) also falls back to 1.
            self.stds = torch.where(stds >= 1e-9, stds, torch.ones_like(stds))

            # Normalize each neuron individually
            self.samples_y = (self.samples_y - self.means) / self.stds
        else:
            # Identity statistics, sized to the actual number of neurons.
            self.means = torch.zeros((n_neurons,), dtype=torch.float32)
            self.stds = torch.ones((n_neurons,), dtype=torch.float32)

    def __len__(self):
        """Return the total number of samples (evoked + ISI)."""
        return self.samples_x.shape[0]

    def __getitem__(self, idx):
        """Return the (x, y) pair at position idx."""
        return self.samples_x[idx], self.samples_y[idx]


###############################################################################
# 3) SINGLE-BRANCH BIG MLP
###############################################################################
class BigMLP(nn.Module):
    """
    A large fully-connected network (Linear + ReLU stack) meant to overfit
    the training data.

    Input:  input_dim features   (default 504 = 500 face + 4 stim)
    Output: output_dim values    (default 229 = normalized neural response)

    Args:
        input_dim:   size of each input vector.
        hidden_dims: iterable of hidden-layer widths, in order. An empty
            iterable yields a single Linear(input_dim, output_dim).
        output_dim:  size of each output vector.
    """

    # The default for hidden_dims is a tuple rather than a list so the default
    # object can never be mutated and shared between instances.
    def __init__(self, input_dim=504, hidden_dims=(2048, 1024, 512), output_dim=229):
        super().__init__()

        layers = []
        prev_dim = input_dim
        for hdim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hdim))
            layers.append(nn.ReLU())
            prev_dim = hdim

        # Final projection has no activation: the targets are real-valued.
        layers.append(nn.Linear(prev_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run x of shape (batch, input_dim) through the stack."""
        return self.model(x)


###############################################################################
# 4) HELPER FUNCTIONS
###############################################################################
def invert_normalization(y_normed, means, stds):
    """
    Map normalized neural outputs back to their original scale.

    y_normed: normalized values, e.g. (batch, n_neurons)
    means, stds: per-neuron statistics, e.g. (n_neurons,), broadcast over
                 the leading dimension(s) of y_normed

    Returns y_normed * stds + means.
    """
    rescaled = y_normed * stds
    return rescaled + means

def compute_r2(all_preds, all_true):
    """
    Compute a single pooled R^2 over all samples and all neurons.

    Both tensors are flattened, so the score is 1 - SS_res / SS_tot with
    SS_tot taken around the global mean of all_true (not per neuron).

    Args:
        all_preds: tensor of predictions, any shape / device.
        all_true:  tensor of targets with the same number of elements.

    Returns:
        Python float. 0.0 when the input is empty or the targets are
        (numerically) constant, since R^2 is undefined there.
    """
    # detach(): tolerate tensors that still carry autograd history.
    # reshape(): unlike view(), also works on non-contiguous tensors.
    # float64: keeps the sums of squares accurate for large N.
    y_true = all_true.detach().reshape(-1).cpu().numpy().astype(np.float64)
    y_pred = all_preds.detach().reshape(-1).cpu().numpy().astype(np.float64)

    if y_true.size == 0:
        return 0.0

    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    if ss_tot < 1e-12:
        return 0.0
    return float(1.0 - ss_res / ss_tot)


###############################################################################
# 5) TRAINING LOOP (FOCUSED ON TRAINING ONLY)
###############################################################################
def train_for_overfitting(eigenface_evoked, dff_evoked,
                          eigenface_isi, dff_isi,
                          epochs=300,
                          batch_size=64,
                          lr=1e-3,
                          device='cuda'):
    """
    Train a BigMLP on the full dataset (no validation/test split).

    Args:
        eigenface_evoked, dff_evoked, eigenface_isi, dff_isi: raw arrays
            passed straight to NeuralDataset.
        epochs:     number of passes over the data.
        batch_size: mini-batch size for training and for the R^2 pass.
        lr:         Adam learning rate.
        device:     torch device string the model and batches are moved to.

    Returns:
        (model, (means, stds), dataset) - the trained model, the per-neuron
        target statistics used for normalization, and the dataset itself.

    Prints one line per epoch with the training loss and the pooled
    training R^2 (both on normalized targets).
    """
    # NOTE(review): the run log captured with this notebook plateaus at
    # loss ~1.0 / R^2 ~0, i.e. the net only predicts the per-neuron mean.
    # Only the targets are normalized here; the scale of the input features
    # may be worth checking - TODO confirm.

    # Create the full dataset. We do NOT split into val/test
    dataset = NeuralDataset(eigenface_evoked, dff_evoked,
                            eigenface_isi, dff_isi,
                            apply_norm=True)
    means, stds = dataset.means, dataset.stds
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    # Separate unshuffled loader for the per-epoch R^2 pass: ordering is
    # irrelevant for the pooled score, so there is no reason to reshuffle.
    eval_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Size the model from the data instead of hard-coding 504 / 229.
    input_dim = dataset.samples_x.shape[1]
    output_dim = dataset.samples_y.shape[1]
    model = BigMLP(input_dim=input_dim, hidden_dims=[2048, 1024, 512],
                   output_dim=output_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    n_samples = len(dataset)

    # Training loop
    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)  # normalized
            pred = model(batch_x)
            loss = criterion(pred, batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight each batch mean by its size so a short final batch
            # does not count as much as a full one.
            loss_sum += loss.item() * batch_x.shape[0]

        train_loss = loss_sum / n_samples

        # Compute R^2 on the entire training set (optional, but nice to see progress)
        model.eval()
        all_preds, all_true = [], []
        with torch.no_grad():
            for bx, by in eval_loader:
                bx = bx.to(device)
                by = by.to(device)
                all_preds.append(model(bx))
                all_true.append(by)

        r2_train = compute_r2(torch.cat(all_preds, dim=0), torch.cat(all_true, dim=0))

        print(f"[Epoch {epoch+1}/{epochs}] Train Loss={train_loss:.6f}, Train R^2={r2_train:.4f}")

    # Final model + dataset stats
    return model, (means, stds), dataset


###############################################################################
# 6) INFERENCE EXAMPLE
###############################################################################
def predict_entire_dataset(model, dataset, means, stds, device='cuda', batch_size=64):
    """
    Return unnormalized predictions for the entire dataset.

    Args:
        model:      trained network producing normalized outputs; it is
                    expected to already live on `device`.
        dataset:    dataset yielding (x, y) pairs; only x is used.
        means, stds: per-neuron tensors used to undo the normalization.
        device:     torch device string the batches are moved to.
        batch_size: inference batch size.

    Returns:
        NumPy array of predictions in dataset order (the loader is not
        shuffled), one row per sample.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model.eval()

    # Move the statistics to the device once rather than on every batch.
    means_dev = means.to(device)
    stds_dev = stds.to(device)

    all_preds = []
    with torch.no_grad():
        for bx, _ in loader:
            bx = bx.to(device)
            py_norm = model(bx)  # normalized
            # invert normalization
            py_real = invert_normalization(py_norm, means_dev, stds_dev)
            all_preds.append(py_real.cpu())
    return torch.cat(all_preds, dim=0).numpy()


###############################################################################
# 7) MAIN
###############################################################################
# Script entry point: train on the full dataset, then produce unnormalized
# predictions for every training sample.
if __name__ == "__main__":
    # Fall back to CPU when no CUDA device is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Train only on the entire dataset (no validation).
    # Returns the trained model, the per-neuron (means, stds) used to
    # normalize the targets, and the dataset that was trained on.
    model, (means, stds), train_dataset = train_for_overfitting(
        eigenface_evoked, dff_evoked,
        eigenface_isi, dff_isi,
        epochs=300,    # Increase if you still want to push training R^2 even higher
        batch_size=64,
        lr=1e-3,
        device=device
    )

    # After training, get final predictions on the training set, mapped back
    # to the original (unnormalized) target scale.
    train_preds_unnorm = predict_entire_dataset(model, train_dataset, means, stds, device=device)
    print("Shape of train_preds_unnorm:", train_preds_unnorm.shape)  # (N, 229)
    
    # ...and do whatever analysis/plots you like on train_preds_unnorm.
    # Since we don't care about test/validation, we won't show that here.


[Epoch 1/300] Train Loss=1.054726, Train R^2=-0.0004
[Epoch 2/300] Train Loss=0.999412, Train R^2=-0.0002
[Epoch 3/300] Train Loss=1.000072, Train R^2=-0.0001
[Epoch 4/300] Train Loss=0.999426, Train R^2=-0.0000
[Epoch 5/300] Train Loss=0.999456, Train R^2=-0.0000
[Epoch 6/300] Train Loss=0.999902, Train R^2=-0.0000
[Epoch 7/300] Train Loss=1.001040, Train R^2=-0.0000
[Epoch 8/300] Train Loss=1.000236, Train R^2=-0.0000
[Epoch 9/300] Train Loss=1.000502, Train R^2=-0.0000
[Epoch 10/300] Train Loss=1.000273, Train R^2=-0.0000
[Epoch 11/300] Train Loss=1.000115, Train R^2=-0.0000
[Epoch 12/300] Train Loss=0.999089, Train R^2=-0.0000
[Epoch 13/300] Train Loss=1.001155, Train R^2=-0.0000
[Epoch 14/300] Train Loss=1.000389, Train R^2=-0.0000
[Epoch 15/300] Train Loss=1.000396, Train R^2=-0.0000
[Epoch 16/300] Train Loss=1.000389, Train R^2=-0.0000
[Epoch 17/300] Train Loss=1.000117, Train R^2=-0.0000
[Epoch 18/300] Train Loss=0.999266, Train R^2=-0.0000
[Epoch 19/300] Train Loss=0.999763, T