In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import scipy.io

###############################################################################
# 1) LOAD DATA
###############################################################################
# Location of the recording; point this at wherever the .mat file lives.
file_path = "L23_neuron_20210228_Y54_Z320_test.mat"
mat_data = scipy.io.loadmat(file_path)

# Pull the four trial matrices out of the MATLAB file as NumPy arrays.
# Shapes as noted by the original author (not validated here):
#   Eigenface_0_trials_evoked (500, 1000, 4), Eigenface_0_trials_isi (500, 1000),
#   dFF0_trials_evoked        (229, 1000, 4), dFF0_trials_isi        (229, 1000).
eigenface_evoked, eigenface_isi, dff_evoked, dff_isi = (
    np.array(mat_data[key])
    for key in (
        "Eigenface_0_trials_evoked",
        "Eigenface_0_trials_isi",
        "dFF0_trials_evoked",
        "dFF0_trials_isi",
    )
)


###############################################################################
# 2) DATASET WITH PER-NEURON NORMALIZATION
###############################################################################
class NeuralDataset(Dataset):
    """
    Per-sample dataset pairing face/stimulus inputs with neural responses.

    Every column (trial/time point) of the input arrays becomes one sample:
      x = [face features, one-hot stimulus]  -> length face_dim + n_stim
      y = neural response                    -> length n_neurons
    Evoked samples come first (stimulus 0, then 1, ...), each in column
    order, followed by the ISI samples whose stimulus code is all zeros.

    Args:
        eigenface_evoked: array (face_dim, T, n_stim) of face features.
        dff_evoked: array (n_neurons, T, n_stim) of evoked responses.
        eigenface_isi: array (face_dim, T_isi) of face features during ISI.
        dff_isi: array (n_neurons, T_isi) of ISI responses.
        apply_norm: if True, z-score each neuron over all samples.

    Attributes:
        samples_x: float32 tensor (N, face_dim + n_stim).
        samples_y: float32 tensor (N, n_neurons), normalized if requested.
        means, stds: per-neuron statistics of shape (n_neurons,); they are
            zeros / ones when ``apply_norm`` is False.

    Raises:
        ValueError: if face and neural arrays disagree on the number of
            trials or stimuli.
    """

    def __init__(self, eigenface_evoked, dff_evoked,
                 eigenface_isi, dff_isi,
                 apply_norm=True):
        super().__init__()

        eigenface_evoked = np.asarray(eigenface_evoked)
        dff_evoked = np.asarray(dff_evoked)
        eigenface_isi = np.asarray(eigenface_isi)
        dff_isi = np.asarray(dff_isi)

        # Fail loudly instead of silently truncating or mis-pairing samples.
        if eigenface_evoked.shape[1:] != dff_evoked.shape[1:]:
            raise ValueError(
                "evoked face and neural arrays must share (trials, stimuli); "
                f"got {eigenface_evoked.shape[1:]} vs {dff_evoked.shape[1:]}"
            )
        if eigenface_isi.shape[1] != dff_isi.shape[1]:
            raise ValueError(
                "ISI face and neural arrays must share the trial count; "
                f"got {eigenface_isi.shape[1]} vs {dff_isi.shape[1]}"
            )

        # Derive the stimulus count from the data rather than assuming 4.
        n_stim = eigenface_evoked.shape[2]

        x_blocks = []
        y_blocks = []

        # Evoked: one block of samples per stimulus, tagged with its one-hot.
        for c in range(n_stim):
            face_block = eigenface_evoked[:, :, c].T.astype(np.float32)  # (T, face_dim)
            neural_block = dff_evoked[:, :, c].T.astype(np.float32)      # (T, n_neurons)
            stim_onehot = np.zeros((face_block.shape[0], n_stim), dtype=np.float32)
            stim_onehot[:, c] = 1.0
            x_blocks.append(np.concatenate([face_block, stim_onehot], axis=1))
            y_blocks.append(neural_block)

        # ISI => zero stim code.
        face_isi_block = eigenface_isi.T.astype(np.float32)
        zero_stim = np.zeros((face_isi_block.shape[0], n_stim), dtype=np.float32)
        x_blocks.append(np.concatenate([face_isi_block, zero_stim], axis=1))
        y_blocks.append(dff_isi.T.astype(np.float32))

        self.samples_x = torch.from_numpy(np.concatenate(x_blocks, axis=0))  # (N, face_dim+n_stim)
        self.samples_y = torch.from_numpy(np.concatenate(y_blocks, axis=0))  # (N, n_neurons)

        n_neurons = self.samples_y.shape[1]
        if apply_norm:
            self.means = self.samples_y.mean(dim=0)
            stds = self.samples_y.std(dim=0)
            # Constant neurons (or an undefined std) fall back to 1 so the
            # division below never produces inf/NaN.
            self.stds = torch.where(stds >= 1e-9, stds, torch.ones_like(stds))
            self.samples_y = (self.samples_y - self.means) / self.stds
        else:
            self.means = torch.zeros((n_neurons,), dtype=torch.float32)
            self.stds = torch.ones((n_neurons,), dtype=torch.float32)

    def __len__(self):
        return len(self.samples_x)

    def __getitem__(self, idx):
        return self.samples_x[idx], self.samples_y[idx]


###############################################################################
# 3) MULTI-BRANCH MLP
###############################################################################
class MultiBranchMLP(nn.Module):
    """
    Two-branch MLP: a face branch and a stimulus branch whose outputs are
    concatenated and passed through a fusion MLP ending in a linear layer.

    Args:
        face_dim: number of leading input features routed to the face branch.
        stim_dim: number of trailing input features routed to the stim branch.
        hidden_face: hidden layer widths of the face branch.
        hidden_stim: hidden layer widths of the stim branch.
        hidden_fuse: hidden layer widths of the fusion network.
        output_dim: size of the prediction vector.

    The hidden-size arguments accept any sequence (lists still work); the
    defaults are tuples so no mutable object is shared between instances.
    An empty sequence turns the corresponding branch into an identity map.
    """

    def __init__(self,
                 face_dim=500, stim_dim=4,
                 hidden_face=(1024, 512, 256),
                 hidden_stim=(128, 64),
                 hidden_fuse=(256, 128),
                 output_dim=229):
        super().__init__()

        # Remember the split point so forward() honours a non-default face_dim.
        self.face_dim = face_dim
        self.stim_dim = stim_dim

        # Branches are created in the same order as before (face, stim, fuse)
        # so parameter names and seeded initialisation are unchanged.
        self.face_net, face_out = self._linear_relu_stack(face_dim, hidden_face)
        self.stim_net, stim_out = self._linear_relu_stack(stim_dim, hidden_stim)

        fuse_hidden, fuse_out = self._linear_relu_stack(face_out + stim_out, hidden_fuse)
        self.fuse = nn.Sequential(*fuse_hidden, nn.Linear(fuse_out, output_dim))

    @staticmethod
    def _linear_relu_stack(in_dim, hidden_dims):
        """Build Linear+ReLU layers; return (nn.Sequential, output width)."""
        layers = []
        for width in hidden_dims:
            layers.append(nn.Linear(in_dim, width))
            layers.append(nn.ReLU())
            in_dim = width
        return nn.Sequential(*layers), in_dim

    def forward(self, x):
        """Map a batch x of shape (B, face_dim + stim_dim) to (B, output_dim)."""
        face_part = x[:, :self.face_dim]
        stim_part = x[:, self.face_dim:]
        face_out = self.face_net(face_part)
        stim_out = self.stim_net(stim_part)
        combined = torch.cat([face_out, stim_out], dim=-1)
        return self.fuse(combined)


###############################################################################
# 4) COMBINED LOSS
###############################################################################
def correlation_loss(y_pred, y_true, eps=1e-8):
    """
    Return 1 - Pearson correlation between the flattened inputs.

    Both tensors are flattened to 1-D, so the correlation is taken over all
    elements jointly (not per neuron). ``eps`` is added to each sum of
    squares to keep the division finite for constant inputs.
    """
    # reshape (not view) so non-contiguous tensors, e.g. transposes, work too.
    pred_flat = y_pred.reshape(-1)
    true_flat = y_true.reshape(-1)

    pred_centered = pred_flat - pred_flat.mean()
    true_centered = true_flat - true_flat.mean()

    cov = (pred_centered * true_centered).sum()
    var_pred = (pred_centered ** 2).sum() + eps
    var_true = (true_centered ** 2).sum() + eps
    corr = cov / (var_pred.sqrt() * var_true.sqrt())
    return 1.0 - corr

def combined_loss(y_pred, y_true, alpha=0.5):
    """
    Blend mean-squared error with the correlation loss.

    Returns alpha * MSE + (1 - alpha) * correlation_loss(y_pred, y_true).
    """
    squared_error = (y_pred - y_true) ** 2
    mse_term = squared_error.mean()
    corr_term = correlation_loss(y_pred, y_true)
    return alpha * mse_term + (1 - alpha) * corr_term


###############################################################################
# 5) R^2 HELPER
###############################################################################
def compute_r2(model, loader, ds, device='cuda'):
    """
    Average per-neuron R^2 of the model over every batch in ``loader``.

    Predictions and targets are mapped back to the raw domain with
    ``ds.means`` / ``ds.stds`` before scoring. Neurons whose targets have
    (numerically) zero variance contribute an R^2 of 0.

    Args:
        model: module mapping an input batch to normalized predictions.
        loader: iterable of (x_batch, y_batch_norm) pairs.
        ds: object exposing ``means`` and ``stds`` tensors, one per neuron.
        device: device used for the forward passes.

    Returns:
        The mean R^2 across neurons.

    Raises:
        ValueError: if ``loader`` yields no batches.

    Note: the model is switched to eval mode and left that way.
    """
    model.eval()

    # The statistics do not change between batches; move them once.
    means = ds.means.to(device)
    stds = ds.stds.to(device)

    pred_chunks = []
    true_chunks = []
    with torch.no_grad():
        for x_batch, y_batch_norm in loader:
            x_batch = x_batch.to(device)
            y_batch_norm = y_batch_norm.to(device)
            preds_norm = model(x_batch)
            pred_chunks.append((preds_norm * stds + means).cpu().numpy())
            true_chunks.append((y_batch_norm * stds + means).cpu().numpy())

    if not pred_chunks:
        raise ValueError("compute_r2 received an empty loader")

    # Accumulate the sums of squares in float64 to limit rounding error
    # when many samples are summed.
    all_preds = np.concatenate(pred_chunks, axis=0).astype(np.float64)  # (N, n_neurons)
    all_true = np.concatenate(true_chunks, axis=0).astype(np.float64)

    ss_res = np.sum((all_true - all_preds) ** 2, axis=0)
    ss_tot = np.sum((all_true - all_true.mean(axis=0)) ** 2, axis=0)

    r2 = np.zeros_like(ss_tot)
    valid = ss_tot > 1e-12
    r2[valid] = 1.0 - ss_res[valid] / ss_tot[valid]
    return np.mean(r2)


###############################################################################
# 6) TRAIN FUNCTION
###############################################################################
def train_model(eigenface_evoked, dff_evoked,
                eigenface_isi, dff_isi,
                epochs=100, batch_size=64, lr=1e-3,
                alpha=0.5,
                device='cuda'):
    """
    Train a MultiBranchMLP on the face/stimulus -> neural dataset.

    Builds a normalized NeuralDataset, holds out 20% of the samples for
    validation, and optimises ``combined_loss`` with Adam. The learning
    rate is halved after 5 epochs without validation improvement and
    training stops after 50 such epochs. The weights with the lowest
    validation loss are restored before returning.

    Args:
        eigenface_evoked, dff_evoked, eigenface_isi, dff_isi: raw arrays
            forwarded unchanged to NeuralDataset.
        epochs: maximum number of epochs.
        batch_size: mini-batch size for both loaders.
        lr: initial Adam learning rate.
        alpha: weight of the MSE term in ``combined_loss``.
        device: device on which the model is trained.

    Returns:
        (model, ds_full): the trained model and the full dataset, which
        carries the normalization statistics.

    Raises:
        ValueError: if the dataset is too small to hold out a validation set.

    NOTE(review): normalization statistics are computed on the full dataset
    before the split, so validation targets influence the training targets'
    scaling. Confirm this is acceptable for the intended evaluation.
    """
    ds_full = NeuralDataset(eigenface_evoked, dff_evoked, eigenface_isi, dff_isi,
                            apply_norm=True)
    N = len(ds_full)
    val_size = int(0.2 * N)
    train_size = N - val_size
    if val_size == 0:
        # An empty validation loader would otherwise fail later with a
        # division by zero when averaging the validation loss.
        raise ValueError(f"Dataset with {N} samples is too small for a 20% validation split.")
    train_ds, val_ds = random_split(ds_full, [train_size, val_size])
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    model = MultiBranchMLP(
        face_dim=500, stim_dim=4,
        hidden_face=[1024, 512, 256],
        hidden_stim=[128, 64],
        hidden_fuse=[256, 128],
        output_dim=229
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # `verbose` is deprecated in recent PyTorch releases; learning-rate
    # reductions are reported manually after scheduler.step() instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.5, patience=5)

    best_val_loss = float('inf')
    best_state = None
    patience = 0
    max_patience = 50

    for epoch in range(epochs):
        # 1) Train
        model.train()
        total_loss = 0.0
        for x_batch, y_batch in train_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            pred = model(x_batch)
            loss = combined_loss(pred, y_batch, alpha=alpha)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_loss = total_loss / len(train_loader)

        # 2) Validation
        model.eval()
        val_loss_sum = 0.0
        with torch.no_grad():
            for x_val, y_val in val_loader:
                x_val = x_val.to(device)
                y_val = y_val.to(device)
                preds_val = model(x_val)
                lv = combined_loss(preds_val, y_val, alpha=alpha)
                val_loss_sum += lv.item()
        val_loss = val_loss_sum / len(val_loader)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Epoch {epoch+1}: reducing learning rate to {lr_after:.4e}.")

        # 3) Report R^2 only every 20 epochs and on the last epoch, since it
        #    needs a full extra pass over both loaders.
        do_print = (epoch % 20 == 0 or epoch == epochs - 1)
        if do_print:
            train_r2 = compute_r2(model, train_loader, ds_full, device=device)
            val_r2 = compute_r2(model, val_loader, ds_full, device=device)
            print(f"Epoch {epoch+1}/{epochs}, train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, "
                  f"train_R^2={train_r2:.4f}, val_R^2={val_r2:.4f}")

        # 4) Early stopping
        if val_loss < best_val_loss - 1e-9:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameters, which
            # later optimizer steps overwrite in place; clone to keep a real
            # snapshot of the best weights.
            best_state = {name: tensor.detach().clone()
                          for name, tensor in model.state_dict().items()}
            patience = 0
        else:
            patience += 1
            if patience >= max_patience:
                print("Early stopping triggered.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    print(f"Training done. Best val loss= {best_val_loss:.5f}")

    # Return the model + the dataset
    return model, ds_full



###############################################################################
# 8) MAIN
###############################################################################
if __name__ == "__main__":
    # Prefer the GPU when one is visible, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Train on the arrays loaded at the top of the notebook; `model`,
    # `ds_full` and `device` stay available as module-level names.
    model, ds_full = train_model(
        eigenface_evoked,
        dff_evoked,
        eigenface_isi,
        dff_isi,
        epochs=150,
        batch_size=64,
        lr=1e-3,
        alpha=0.5,
        device=device,
    )


FileNotFoundError: [Errno 2] No such file or directory: 'L23_neuron_20210228_Y54_Z320_test.mat'

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def invert_normalization(pred_norm, ds):
    """
    Map normalized predictions back to the raw domain.

    Computes pred_norm * ds.stds + ds.means, after moving both statistics
    to the device that pred_norm lives on.
    """
    target_device = pred_norm.device
    scale = ds.stds.to(target_device)
    offset = ds.means.to(target_device)
    return pred_norm * scale + offset

def _masked_model_responses(model, ds_full, idxs, zero_slice, y_base_norm, device):
    """Predict for the chosen samples with ``zero_slice`` of the input zeroed,
    subtract the baseline, and return the result via invert_normalization
    as a NumPy array of shape (len(idxs), n_outputs)."""
    x_rows = []
    for i in idxs:
        x_raw, _ = ds_full[i]
        x_mod = x_raw.clone()
        x_mod[zero_slice] = 0.0
        x_rows.append(x_mod)
    x_all = torch.stack(x_rows, dim=0).to(device)
    with torch.no_grad():
        y_norm = model(x_all)
    y_norm = y_norm - y_base_norm
    # The means added back here are a constant offset shared by every row;
    # PCA centres its input, so the offset does not change the subspaces.
    return invert_normalization(y_norm, ds_full).cpu().numpy()


def analyze_face_stim_orthogonality(model, ds_full, 
                                    M=100, K=10,
                                    sub_pcs=3,
                                    device='cuda',
                                    random_seed=None,
                                    null_hypothesis_mode=False):
    """
    Compare the output subspaces driven by face input and by stimulus input.

    Steps:
      1) Baseline: the model's response to an all-zero input.
      2) Draw M + K distinct samples from ds_full. The first M keep their
         face features with the stimulus code zeroed ("face-driven").
      3) The remaining K keep their stimulus code with the face zeroed
         ("stim-driven"); in null_hypothesis_mode they are face-driven too,
         giving a face-vs-face control.
      4) Fit PCA to each response set, print the squared Frobenius norm of
         U_face^T U_stim (top ``sub_pcs`` PCs) and the angle between the two
         leading PCs.
      5) Scatter both sets in a joint 2-D PCA projection.

    Args:
        model: trained network; switched to eval mode.
        ds_full: dataset yielding (x, y) and exposing ``means`` / ``stds``.
        M, K: sizes of the first and second sample sets.
        sub_pcs: number of principal components per subspace.
        device: device for the forward passes.
        random_seed: if given, seeds the global NumPy and torch RNGs.
        null_hypothesis_mode: run the face-vs-face control instead.

    The reported angle is the angle between the two PC *axes*, in [0, 90]
    degrees: a principal component's sign is arbitrary, so the absolute
    value of the cosine is used, clipped to guard arccos against rounding.
    """
    from sklearn.decomposition import PCA

    model.eval()
    if random_seed is not None:
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

    # Split point between face features and stimulus code; use the model's
    # own value when it exposes one, else the notebook's 500.
    face_dim = getattr(model, "face_dim", 500)
    face_slice = slice(None, face_dim)
    stim_slice = slice(face_dim, None)

    # Baseline
    in_dim = ds_full[0][0].shape[-1]
    x_base = torch.zeros(1, in_dim, dtype=torch.float32).to(device)
    with torch.no_grad():
        y_base_norm = model(x_base)[0]

    N = len(ds_full)
    idxs = np.random.choice(N, M + K, replace=False)
    first_idxs = idxs[:M]
    second_idxs = idxs[M:]

    # The first set is always face-driven (stimulus zeroed).
    face_np = _masked_model_responses(model, ds_full, first_idxs, stim_slice,
                                      y_base_norm, device)

    if not null_hypothesis_mode:
        # normal scenario => second set is stim-driven (face zeroed)
        stim_np = _masked_model_responses(model, ds_full, second_idxs, face_slice,
                                          y_base_norm, device)
        label_face = 'Face-driven'
        label_stim = 'Stim-driven'
    else:
        # Null hypothesis => second set is face-driven as well, showing what
        # overlap/angle arise when the two sets do not actually differ.
        stim_np = _masked_model_responses(model, ds_full, second_idxs, stim_slice,
                                          y_base_norm, device)
        label_face = 'Set A (face/fake)'
        label_stim = 'Set B (face/fake)'

    # measure subspace overlap
    pca_face = PCA(n_components=sub_pcs).fit(face_np)
    U_face = pca_face.components_.T
    pca_stim = PCA(n_components=sub_pcs).fit(stim_np)
    U_stim = pca_stim.components_.T

    overlap = np.linalg.norm(U_face.T @ U_stim, 'fro') ** 2
    print(f"Subspace overlap (top {sub_pcs} PCs) = {overlap:.4f}")

    # angle top1
    face_pc1 = U_face[:, 0]
    stim_pc1 = U_stim[:, 0]
    denom = np.linalg.norm(face_pc1) * np.linalg.norm(stim_pc1) + 1e-12
    cos_angle = np.clip(abs(np.dot(face_pc1, stim_pc1)) / denom, 0.0, 1.0)
    angle_deg = np.degrees(np.arccos(cos_angle))
    print(f"Angle between top1 PC: {angle_deg:.2f} deg")

    # 2D PCA
    all_data = np.concatenate([face_np, stim_np], axis=0)
    all_2d = PCA(n_components=2).fit_transform(all_data)
    n_first = face_np.shape[0]
    plt.figure(figsize=(6, 5))
    plt.scatter(all_2d[:n_first, 0], all_2d[:n_first, 1], c='blue', alpha=0.6, label=label_face)
    plt.scatter(all_2d[n_first:, 0], all_2d[n_first:, 1], c='red', alpha=0.6, label=label_stim)
    plt.title("Face vs. Stim subspace" if not null_hypothesis_mode else "Null Hypothesis (face vs. face)")
    plt.legend()
    plt.show()


In [None]:

# Real face vs. stim: compare face-driven and stim-driven model responses.
# NOTE: analyze_face_stim_orthogonality reseeds NumPy and torch from
# random_seed, so both calls below draw their samples under seed 111.
torch.manual_seed(0)
analyze_face_stim_orthogonality(model, ds_full, M=200, K=300,
                                device=device,
                                random_seed=111,
                                null_hypothesis_mode=False)

# Null hypothesis => face vs. face: both sets are face-driven, giving a
# reference for the overlap/angle seen when the two sets do not differ.
analyze_face_stim_orthogonality(model, ds_full, M=200, K=300,
                                device=device,
                                random_seed=111,
                                null_hypothesis_mode=True)



In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

###############################################################################
# 1) Helper: invert MLP’s per-neuron normalization to raw neural space
###############################################################################
def invert_normalization(y_norm, means, stds=None):
    """
    Map normalized values back to the raw domain: y_norm * stds + means.

    Two call forms are accepted:
      invert_normalization(y_norm, means, stds)
          means, stds: per-neuron statistics broadcastable against y_norm.
      invert_normalization(y_norm, ds)
          ds: object exposing ``means`` and ``stds`` tensors; they are moved
          to y_norm's device first.

    The second form matches the two-argument helper defined in other cells
    of this notebook. Without it, running this cell would replace that
    helper and break functions that call it as (y_norm, ds).
    """
    if stds is None:
        # Dataset form: the second positional argument is the dataset.
        ds = means
        device = y_norm.device
        means = ds.means.to(device)
        stds = ds.stds.to(device)
    return y_norm * stds + means

###############################################################################
# 2) Build Stimulus Subspace R
###############################################################################
def build_stimulus_subspace(model, dataset, device='cuda', n_stim=4, face_dim=500):
    """
    Build the matrix R of mean predicted responses, one column per stimulus.

    Every sample in ``dataset`` is passed through the model, mapped back to
    the raw domain with ``dataset.means`` / ``dataset.stds``, and grouped by
    the argmax of its stimulus code (the input features from ``face_dim``
    onward). Samples whose stimulus code is all zeros (ISI) are skipped.

    Args:
        model: trained network; switched to eval mode.
        dataset: dataset of (x, y) pairs exposing ``means`` and ``stds``.
        device: device for the forward passes.
        n_stim: number of stimuli, i.e. the width of the stimulus code.
        face_dim: index at which the stimulus code starts in each input.

    Returns:
        Array of shape (n_outputs, n_stim). A stimulus with no samples gets
        an all-zero column.

    Raises:
        ValueError: if the stimulus code is not ``n_stim`` wide.
    """
    model.eval()

    # Move the statistics once rather than rebuilding tensors every batch.
    means = dataset.means.to(device)
    stds = dataset.stds.to(device)
    n_out = means.shape[0]

    preds_per_stim = [[] for _ in range(n_stim)]

    loader = DataLoader(dataset, batch_size=128, shuffle=False)
    with torch.no_grad():
        for x_batch, _ in loader:
            x_batch = x_batch.to(device)
            preds_raw = (model(x_batch) * stds + means).cpu().numpy()  # (B, n_out)
            stim_block = x_batch[:, face_dim:].cpu().numpy()           # (B, n_stim)
            if stim_block.shape[1] != n_stim:
                raise ValueError(
                    f"expected a stimulus code of width {n_stim}, "
                    f"got {stim_block.shape[1]}"
                )

            # Same tolerance as np.allclose(stim_vec, 0): every entry <= 1e-8.
            is_isi = np.all(np.abs(stim_block) <= 1e-8, axis=1)
            stim_idx = np.argmax(stim_block, axis=1)
            for c in range(n_stim):
                selected = (~is_isi) & (stim_idx == c)
                if selected.any():
                    preds_per_stim[c].append(preds_raw[selected])

    # Average for each stimulus; chunks are concatenated in dataset order.
    columns = []
    for chunks in preds_per_stim:
        if chunks:
            columns.append(np.concatenate(chunks, axis=0).mean(axis=0))
        else:
            # no samples => zero vector sized to the model output
            columns.append(np.zeros((n_out,), dtype=np.float32))

    return np.stack(columns, axis=1)  # (n_out, n_stim)


###############################################################################
# 3) Reduced-Rank Regression => from Face(PC) to MLP's predicted neural activity
###############################################################################
def do_RRR_face_to_neural(model, dataset, rank=32, device='cuda'):
    """
    Estimate a low-rank "behavior subspace" from face features to the
    model's predicted neural activity.

    Steps:
      1) Collect the face features (first 500 input dims) and the model's
         raw-domain predictions for every sample in ``dataset``.
      2) Reduce the face features to at most 32 principal components.
      3) Least-squares fit: predictions ~ face PCs, giving a coefficient
         matrix A of shape (n_outputs, n_pcs).
      4) Take the SVD of A and keep the leading ``rank`` left singular vectors.

    Args:
        model: trained network; switched to eval mode.
        dataset: dataset of (x, y) pairs exposing ``means`` and ``stds``.
        rank: maximum number of basis vectors to return.
        device: device for the forward passes.

    Returns:
        E_B: orthonormal basis of shape (n_outputs, r), r = min(rank, rank of the SVD).

    NOTE(review): this truncates the SVD of the coefficient matrix, which
    is a simplification of classical reduced-rank regression (that
    truncates the fitted values instead). The two coincide when ``rank``
    is not below the number of face PCs; confirm the simplification is
    intended for smaller ranks. The regression target is the model's
    prediction, not the recorded activity.
    """
    from sklearn.decomposition import PCA

    face_dim = 500
    max_face_pcs = 32

    means = dataset.means.to(device)
    stds = dataset.stds.to(device)

    loader = DataLoader(dataset, batch_size=128, shuffle=False)
    face_chunks = []
    pred_chunks = []
    model.eval()
    with torch.no_grad():
        for x_batch, _ in loader:
            x_batch = x_batch.to(device)
            face_chunks.append(x_batch[:, :face_dim].cpu().numpy())
            pred_chunks.append((model(x_batch) * stds + means).cpu().numpy())

    face_all = np.concatenate(face_chunks, axis=0)    # (n_samples, face_dim)
    neural_all = np.concatenate(pred_chunks, axis=0)  # (n_samples, n_outputs)

    # PCA cannot return more components than min(n_samples, n_features),
    # so clamp instead of failing on small datasets.
    keep_f = min(max_face_pcs, face_all.shape[0], face_all.shape[1])
    face_all_pcs = PCA(n_components=keep_f).fit_transform(face_all)  # (n_samples, keep_f)

    # Least-squares solution of face_all_pcs @ coef ~= neural_all. lstsq
    # avoids explicitly forming and inverting X^T X. PCA scores are
    # zero-mean, so omitting an intercept does not change the slopes.
    coef, _, _, _ = np.linalg.lstsq(face_all_pcs, neural_all, rcond=None)  # (keep_f, n_outputs)
    A_ = coef.T  # (n_outputs, keep_f)

    Ua, _, _ = np.linalg.svd(A_, full_matrices=False)
    r_ = min(rank, Ua.shape[1])
    return Ua[:, :r_]


###############################################################################
# 4) SVD approach to find shared dimension, then “stim-only,” “beh-only,” etc.
###############################################################################
def _orthonormal_basis(mat, rtol=1e-10):
    """Return an orthonormal basis (as columns) of the column space of
    ``mat``, dropping directions whose singular value is below ``rtol``
    times the largest one."""
    mat = np.asarray(mat, dtype=np.float64)
    if mat.ndim == 1:
        mat = mat.reshape(-1, 1)
    if mat.size == 0:
        return np.zeros((mat.shape[0], 0))
    U, s, _ = np.linalg.svd(mat, full_matrices=False)
    if s[0] <= 0.0:
        return U[:, :0]
    return U[:, s > rtol * s[0]]


def find_shared_and_measure_variance(R, E_B, test_data, sub_dims=32):
    """
    Find the direction shared by two subspaces and measure how much of
    ``test_data``'s second moment falls in the shared / exclusive parts.

    Args:
        R: (n, S) matrix whose columns span the stimulus subspace.
        E_B: (n, r) matrix whose columns span the behavior subspace.
        test_data: (n, T) data matrix, one column per sample.
        sub_dims: unused; kept so existing calls keep working.

    Returns:
        (var_shared, var_stim_only, var_beh_only, total_var, shared_dim)
        Each variance is the sum of squared projections divided by T (the
        data are not mean-centred here). ``shared_dim`` is a unit vector of
        length n lying in the behavior subspace; its sign is arbitrary.

    Raises:
        ValueError: if R or E_B has no non-zero direction.

    Bases are computed with a rank-revealing SVD rather than QR: after the
    shared direction is projected out of a basis that contains it, the
    matrix loses one rank, and QR would still return a full set of columns
    including one arbitrary direction that inflates the measured variance.
    """
    # 1) Orthonormal bases of both subspaces (rank-deficient inputs, e.g. an
    #    all-zero column in R, simply yield fewer basis vectors).
    Qr = _orthonormal_basis(R)
    Qb = _orthonormal_basis(E_B)
    if Qr.shape[1] == 0 or Qb.shape[1] == 0:
        raise ValueError("R and E_B must each span at least one dimension")

    # 2) The leading left singular vector of Qb^T Qr gives, in Qb
    #    coordinates, the behavior direction best aligned with the stimulus
    #    subspace.
    M = Qb.T @ Qr
    U_, _, _ = np.linalg.svd(M, full_matrices=False)
    shared_dim = Qb @ U_[:, 0]
    norm_sh = np.linalg.norm(shared_dim)
    if norm_sh > 1e-12:
        shared_dim = shared_dim / norm_sh

    def project_out(subspace, vec):
        column = vec.reshape(-1, 1)
        return subspace - column @ (column.T @ subspace)

    # 3) stim-only / 4) beh-only: remove the shared direction, then rebuild
    #    a basis that keeps only the directions that survive.
    Qr_wo = _orthonormal_basis(project_out(Qr, shared_dim))
    Qb_wo = _orthonormal_basis(project_out(Qb, shared_dim))

    test_data = np.asarray(test_data, dtype=np.float64)
    n_samples = test_data.shape[1]

    def project_var(data, basis):
        coords = basis.T @ data  # (d, T)
        return np.sum(coords ** 2) / n_samples

    var_shared = project_var(test_data, shared_dim.reshape(-1, 1))
    var_stim_only = project_var(test_data, Qr_wo)
    var_beh_only = project_var(test_data, Qb_wo)
    total_var = np.sum(test_data ** 2) / n_samples

    return var_shared, var_stim_only, var_beh_only, total_var, shared_dim


###############################################################################
# 5) Putting it All Together
###############################################################################
def full_subspace_analysis(model, dataset, device='cuda'):
    """
    Run the stimulus/behavior subspace comparison end to end and print it.

    1) R_stim: per-stimulus mean predictions (build_stimulus_subspace).
    2) E_B: behavior subspace from the face -> prediction regression
       (do_RRR_face_to_neural, rank 32).
    3) Collect the model's raw-domain predictions for the whole dataset as
       the data matrix, one column per sample.
    4) Print the variance captured by the shared direction and by the
       stim-only / beh-only subspaces, in absolute terms and as fractions.
    """
    # 1) Stim subspace
    R_stim = build_stimulus_subspace(model, dataset, device=device)
    print("R_stim shape:", R_stim.shape)

    # 2) Behavior subspace
    E_B = do_RRR_face_to_neural(model, dataset, rank=32, device=device)
    print("E_B shape:", E_B.shape)

    # 3) Data matrix: raw-domain predictions over the entire dataset.
    offset = dataset.means.to(device)
    scale = dataset.stds.to(device)
    batches = DataLoader(dataset, batch_size=128, shuffle=False)
    raw_chunks = []
    with torch.no_grad():
        for inputs, _ in batches:
            outputs = model(inputs.to(device))
            raw_chunks.append((outputs * scale + offset).cpu().numpy())
    test_data = np.concatenate(raw_chunks, axis=0).T  # one column per sample
    print("test_data shape:", test_data.shape)

    # 4) Shared direction and variance breakdown
    results = find_shared_and_measure_variance(R_stim, E_B, test_data, sub_dims=32)
    var_shared, var_stim, var_beh, total_var, shared_dim = results
    print(f"Variance in shared dimension: {var_shared:.4f}")
    print(f"Variance in stim-only subspace: {var_stim:.4f}")
    print(f"Variance in beh-only subspace: {var_beh:.4f}")
    print(f"Total variance in test data: {total_var:.4f}")
    # Same quantities as fractions of the total.
    print("Fraction of total variance in shared:", var_shared/ total_var)
    print("Fraction of total in stim-only:", var_stim/ total_var)
    print("Fraction of total in beh-only:", var_beh/ total_var)

###############################################################################
# USAGE EXAMPLE
###############################################################################
def main_analysis(model, dataset, device='cuda'):
    """
    Entry point for the subspace analysis; call it after training the MLP.

    Forwards ``model``, ``dataset`` and ``device`` unchanged to
    full_subspace_analysis and returns nothing itself.
    """
    full_subspace_analysis(model, dataset, device=device)


In [None]:
# Run the subspace analysis on the trained model and the full dataset.
main_analysis(model, ds_full, device=device)

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

###############################################################################
# 0) Helper: invert MLP's normalized outputs => raw domain
###############################################################################
def invert_normalization(y_norm, ds):
    """
    Undo the per-neuron normalization of model outputs.

    y_norm: tensor whose last dimension matches ds.means / ds.stds.
    Returns y_norm * ds.stds + ds.means, with both statistics first moved
    to y_norm's device.
    """
    where = y_norm.device
    return y_norm * ds.stds.to(where) + ds.means.to(where)

###############################################################################
# 1) Build Face Subspace from MLP
###############################################################################
def build_face_subspace(model, ds, 
                        face_samples=500, 
                        subspace_dim=10, 
                        device='cuda'):
    """
    Estimate the output subspace driven by face input alone.

    Steps:
      1) Draw ``face_samples`` distinct samples from ds (global NumPy RNG).
      2) Zero the stimulus part of each input (features from index 500 on).
      3) Predict all of them in one batch and map the outputs back to the
         raw domain with invert_normalization.
      4) Fit PCA and return the top ``subspace_dim`` components.

    Args:
        model: trained network; switched to eval mode.
        ds: dataset of (x, y) pairs exposing ``means`` and ``stds``.
        face_samples: number of samples to draw; must not exceed len(ds).
        subspace_dim: number of principal components to keep.
        device: device for the forward pass.

    Returns:
        Array of shape (n_outputs, subspace_dim) with the PCs as columns.
    """
    model.eval()
    N = len(ds)
    idxs = np.random.choice(N, face_samples, replace=False)

    x_rows = []
    for i in idxs:
        x_raw, _ = ds[i]
        x_mod = x_raw.clone()
        # zero out stimulus
        x_mod[500:] = 0.0
        x_rows.append(x_mod)

    # One batched forward pass instead of one pass per sample.
    x_batch = torch.stack(x_rows, dim=0).to(device)
    with torch.no_grad():
        y_pred_norm = model(x_batch)
    face_arr = invert_normalization(y_pred_norm, ds).cpu().numpy()  # (face_samples, n_outputs)

    pca = PCA(n_components=subspace_dim)
    pca.fit(face_arr)
    face_subspace = pca.components_.T
    return face_subspace


###############################################################################
# 2) Build Stimulus Subspace from MLP
###############################################################################
def build_stim_subspace(model, ds, 
                        stim_samples=500,
                        subspace_dim=10,
                        device='cuda',
                        n_face_inputs=500):
    """
    Estimate the neural subspace driven by the stimulus inputs alone.

    Randomly picks samples from ``ds``, zeroes the first ``n_face_inputs``
    input entries (the face part), runs the model, converts the predictions
    back to the raw domain with ``invert_normalization`` and fits a PCA.

    NOTE(review): with the face part zeroed, inputs differ only in their
    stimulus entries. If those are one-hot codes over a few conditions there
    are only a handful of distinct predictions, so principal components
    beyond that count carry ~zero variance and point in arbitrary
    directions. Confirm that ``subspace_dim`` is not larger than the number
    of distinct stimulus codes.

    Args:
        model: network mapping a (B, n_inputs) batch to (B, n_neurons).
        ds: dataset; ``ds[i]`` must return ``(x, y)`` with ``x`` a 1-D tensor.
        stim_samples: number of samples to draw (without replacement). Values
            larger than ``len(ds)`` are clamped to ``len(ds)``.
        subspace_dim: number of principal components to keep.
        device: device on which the model is evaluated.
        n_face_inputs: number of leading input entries that hold the face
            vector; these are the entries that get zeroed.

    Returns:
        ndarray of shape (n_neurons, subspace_dim) whose columns are the top
        principal directions.

    Raises:
        ValueError: if no sample can be drawn (empty dataset or
            ``stim_samples <= 0``).
    """
    model.eval()
    n_available = len(ds)
    # Clamp so a small dataset does not make np.random.choice fail.
    n_draw = min(stim_samples, n_available)
    if n_draw <= 0:
        raise ValueError("build_stim_subspace needs at least one sample")
    # Global NumPy RNG on purpose: callers seed it for reproducibility.
    idxs = np.random.choice(n_available, n_draw, replace=False)

    stim_inputs = []
    for i in idxs:
        x_raw, _ = ds[i]
        x_mod = x_raw.clone()
        # zero out face
        x_mod[:n_face_inputs] = 0.0
        stim_inputs.append(x_mod)

    # One batched forward pass instead of one call per sample.
    x_batch = torch.stack(stim_inputs, dim=0).to(device)
    with torch.no_grad():
        y_pred_norm = model(x_batch)  # (n_draw, n_neurons)
    y_pred_raw = invert_normalization(y_pred_norm, ds)
    stim_arr = y_pred_raw.cpu().numpy()  # (n_draw, n_neurons)

    pca = PCA(n_components=subspace_dim)
    pca.fit(stim_arr)
    # components_ is (subspace_dim, n_neurons); transpose to column form.
    stim_subspace = pca.components_.T
    return stim_subspace


###############################################################################
# 3) Overlap / Shared Dimension + Subspace Variance
###############################################################################
def orthonormalize_subspace(subspace):
    """
    Return an orthonormal basis for the column span of ``subspace``.

    subspace: shape (n, d).

    For full-column-rank input the result is the reduced-QR ``Q`` factor,
    shape (n, d). For rank-deficient input (e.g. after a direction lying in
    the span has been projected out) plain QR would still hand back d
    columns, the extra ones being arbitrary directions outside the true
    span. In that case the basis is taken from the SVD instead and has only
    ``rank`` columns, so the returned span always equals the input span.

    Returns:
        ndarray of shape (n, r) with orthonormal columns, r <= min(n, d).
    """
    Q, _ = np.linalg.qr(subspace)
    n_cols = Q.shape[1]
    if n_cols == 0:
        return Q

    U, s, _ = np.linalg.svd(subspace, full_matrices=False)
    # Same tolerance convention as np.linalg.matrix_rank.
    tol = s.max() * max(subspace.shape) * np.finfo(s.dtype).eps
    rank = int(np.sum(s > tol))
    if rank == n_cols:
        return Q
    # Keep only the left singular vectors that span the data.
    return U[:, :rank]

def project_out_subspace(X, Y):
    """
    Strip from the columns of X everything that lies in span(Y).

    Y is expected to have orthonormal columns, so Y Y^T is the orthogonal
    projector onto its span and the result is X - Y (Y^T X).

    X: shape (n, dX); Y: shape (n, dY). Returns an array shaped like X.
    """
    coeffs_in_y = Y.T @ X          # coordinates of X's columns in the Y basis
    component_in_y = Y @ coeffs_in_y
    return X - component_in_y

def measure_variance_in_data(data, basis):
    """
    Mean squared projection of ``data`` onto the columns of ``basis``.

    data: shape (n, T), one sample per column; basis: shape (n, d).
    The data are not mean-centred here, so this is the average energy per
    sample captured by the basis: sum((basis^T data)^2) / T.
    """
    n_samples = data.shape[1]
    projected = basis.T @ data     # (d, T) coordinates in the basis
    return np.square(projected).sum() / n_samples

def find_shared_dimension(face_subspace, stim_subspace):
    """
    Find the single direction of the face subspace that overlaps most with
    the stimulus subspace.

    face_subspace: shape (n, df); stim_subspace: shape (n, ds).
    Both are orthonormalized first. The SVD of the (df, ds) cross-product
    face_basis^T stim_basis gives, as its leading left singular vector, the
    face-subspace coordinates of the best-aligned direction; that vector is
    mapped back to the n-dimensional space and scaled to unit length.

    Returns an array of shape (n, 1).
    """
    face_basis = orthonormalize_subspace(face_subspace)
    stim_basis = orthonormalize_subspace(stim_subspace)

    cross = face_basis.T @ stim_basis  # (df, ds)
    left_vecs, _, _ = np.linalg.svd(cross, full_matrices=False)

    # Leading left singular vector, expressed in the ambient space.
    direction = face_basis @ left_vecs[:, 0]
    length = np.linalg.norm(direction)
    if length > 1e-12:
        direction = direction / length
    return direction.reshape(-1, 1)

def analyze_subspaces(model, ds, device='cuda',
                      face_dim=10, stim_dim=10,
                      face_samples=500, stim_samples=500,
                      test_samples=None,
                      do_null_test=True,
                      random_seed=42):
    """
    Split model-predicted activity into shared / face-only / stim-only parts.

    Steps:
      1) build the face and stimulus subspaces from model predictions
      2) find the one direction they share most strongly
      3) remove that direction from each subspace to get face-only / stim-only
      4) measure how much of the test data's energy each part captures
      5) optionally print the overlap of two random subspaces as a reference

    Args:
        model: network mapping (B, n_inputs) to (B, n_neurons).
        ds: dataset exposing ``means`` / ``stds`` tensors and ``ds[i] -> (x, y)``.
        device: device used for model evaluation.
        face_dim, stim_dim: number of PCs kept for each subspace.
        face_samples, stim_samples: samples drawn to build each subspace.
        test_samples: optional array of shape (n_neurons, T). When None, the
            model's raw-domain predictions on all of ``ds`` are used.
        do_null_test: print the overlap of two random subspaces.
        random_seed: if not None, seeds the *global* NumPy and torch RNGs.

    Returns:
        dict with keys 'face_sub', 'stim_sub', 'shared_dim', 'face_only',
        'stim_only', 'var_shared', 'var_faceonly', 'var_stimonly', 'total_var'.
    """
    if random_seed is not None:
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

    face_sub = build_face_subspace(model, ds, face_samples, face_dim, device=device)
    stim_sub = build_stim_subspace(model, ds, stim_samples, stim_dim, device=device)
    shared_dim = find_shared_dimension(face_sub, stim_sub)  # (n_neurons, 1)

    # Take the ambient dimension from the data instead of assuming 229.
    n_neurons = face_sub.shape[0]

    # face-only => face_sub with the shared direction removed
    face_sub_ortho = orthonormalize_subspace(face_sub)
    face_only = project_out_subspace(face_sub_ortho, shared_dim)
    face_only = orthonormalize_subspace(face_only)

    # stim-only => stim_sub with the shared direction removed
    stim_sub_ortho = orthonormalize_subspace(stim_sub)
    stim_only = project_out_subspace(stim_sub_ortho, shared_dim)
    stim_only = orthonormalize_subspace(stim_only)

    print(f"face_sub: {face_sub.shape}, stim_sub: {stim_sub.shape}, shared_dim => shape{shared_dim.shape}")

    # 4) measure variance in test_data => shape (n_neurons, T)
    if test_samples is None:
        # Predict on the whole dataset and map back to the raw domain.
        loader = DataLoader(ds, batch_size=128, shuffle=False)
        all_preds = []
        means = ds.means.to(device)
        stds = ds.stds.to(device)
        model.eval()
        with torch.no_grad():
            for x_batch, _ in loader:
                x_batch = x_batch.to(device)
                preds_norm = model(x_batch)
                preds_raw = preds_norm * stds + means
                all_preds.append(preds_raw.cpu().numpy())
        test_data_np = np.concatenate(all_preds, axis=0).T  # (n_neurons, T)
    else:
        # user-supplied data, already shaped (n_neurons, T)
        test_data_np = test_samples

    total_var = np.sum(test_data_np**2) / test_data_np.shape[1]
    var_shared = measure_variance_in_data(test_data_np, shared_dim)
    var_faceonly = measure_variance_in_data(test_data_np, face_only)
    var_stimonly = measure_variance_in_data(test_data_np, stim_only)

    # face_only and stim_only are each orthogonal to shared_dim but not to
    # one another, so the three parts can overlap and the leftover printed
    # below may be negative.
    sum_ = var_shared + var_faceonly + var_stimonly

    def _pct(part):
        # Avoid dividing by zero when the test data carry no energy.
        return 100 * part / total_var if total_var > 0 else float('nan')

    print(f"Total variance= {total_var:.4f}")
    print(f"Shared variance= {var_shared:.4f} => {_pct(var_shared):.2f}%")
    print(f"Face-only variance= {var_faceonly:.4f} => {_pct(var_faceonly):.2f}%")
    print(f"Stim-only variance= {var_stimonly:.4f} => {_pct(var_stimonly):.2f}%")
    print(f"Sum of subspace var= {sum_:.4f} => leftover= {total_var- sum_:.4f}")

    # 5) optional null test: overlap of two random subspaces of the same sizes
    if do_null_test:
        def random_subspace(dim):
            mat = np.random.randn(n_neurons, dim)
            Q, _ = np.linalg.qr(mat)
            return Q

        rand_face = random_subspace(face_dim)
        rand_stim = random_subspace(stim_dim)
        # overlap measure: ||rand_face^T rand_stim||_F^2
        overlap_null = np.linalg.norm(rand_face.T @ rand_stim, 'fro')**2
        print(f"Null test: random subspace overlap= {overlap_null:.4f} (dims= {face_dim}, {stim_dim})")
        # A single draw only; repeat to obtain a distribution if needed.

    return {
        'face_sub': face_sub,
        'stim_sub': stim_sub,
        'shared_dim': shared_dim,
        'face_only': face_only,
        'stim_only': stim_only,
        'var_shared': var_shared,
        'var_faceonly': var_faceonly,
        'var_stimonly': var_stimonly,
        'total_var': total_var
    }


In [None]:
# Run the shared / face-only / stim-only decomposition with default settings.
# `model`, `ds_full` and `device` must already exist from earlier cells.
# The function prints its summary and returns a dict of subspaces and variances.
analyze_subspaces(model, ds_full, device=device)

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import scipy.io

def invert_normalization(pred_norm, ds):
    """
    Undo the per-neuron normalization of a prediction tensor.

    ``ds.stds`` and ``ds.means`` are moved to the device of ``pred_norm`` and
    applied as ``pred_norm * stds + means`` (broadcast over the batch axis).
    """
    target_device = pred_norm.device
    scale = ds.stds.to(target_device)
    offset = ds.means.to(target_device)
    return pred_norm * scale + offset

def analyze_face_stim_orthogonality(model, ds_full, 
                                    M=100, K=10,
                                    sub_pcs=3,
                                    device='cuda',
                                    random_seed=None,
                                    null_hypothesis_mode=False,
                                    n_face_inputs=500,
                                    n_stim_inputs=4):
    """
    Compare the response subspaces driven by face inputs and by stimulus inputs.

    Steps:
      1) Baseline: all-zero input => model => y_base.
      2) First set (M samples): keep the face part, zero the stimulus part.
      3) Second set (K samples):
           - normal mode: keep the stimulus part and overwrite the face part
             with that sample's own scalar mean over the face entries;
           - null mode: same treatment as the first set (face vs. face).
      4) Subtract the baseline, invert the normalization, save both sets to
         .npy and .mat files in the working directory.
      5) Print the overlap ||U_a^T U_b||_F^2 of the top ``sub_pcs`` PCs and
         the angle between the two leading PCs.
      6) Show a 2D PCA scatter of both sets.

    Args:
        model: network mapping (B, n_inputs) to (B, n_neurons).
        ds_full: dataset exposing ``means`` / ``stds`` and ``ds[i] -> (x, y)``.
        M, K: sizes of the first and second sample sets (disjoint draws).
        sub_pcs: number of PCs per set used for the overlap measure.
        device: device used for model evaluation.
        random_seed: if not None, seeds the global NumPy and torch RNGs.
        null_hypothesis_mode: run the face-vs-face control instead.
        n_face_inputs: number of leading input entries holding the face.
        n_stim_inputs: number of trailing input entries holding the stimulus.

    Returns:
        None. Results are printed, plotted and written to disk.
    """
    model.eval()
    if random_seed is not None:
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

    # Baseline response to an all-zero input.
    x_base = torch.zeros(1, n_face_inputs + n_stim_inputs, dtype=torch.float32).to(device)
    with torch.no_grad():
        y_base_norm = model(x_base)[0]  # (n_neurons,)

    def zero_stim(x):
        # keep the face, remove the stimulus
        x[n_face_inputs:] = 0.0
        return x

    def flatten_face(x):
        # keep the stimulus, replace the face by its own scalar mean
        x[:n_face_inputs] = x[:n_face_inputs].mean()
        return x

    N = len(ds_full)
    idxs = np.random.choice(N, M + K, replace=False)
    first_idxs, second_idxs = idxs[:M], idxs[M:]

    face_np = _baseline_subtracted_response(
        model, ds_full, first_idxs, zero_stim, y_base_norm, device)

    if not null_hypothesis_mode:
        stim_np = _baseline_subtracted_response(
            model, ds_full, second_idxs, flatten_face, y_base_norm, device)
        label_face, label_stim = 'Face-driven', 'Stim-driven'
        face_stem, stim_stem = 'face_predictions', 'stim_predictions'
    else:
        stim_np = _baseline_subtracted_response(
            model, ds_full, second_idxs, zero_stim, y_base_norm, device)
        label_face, label_stim = 'Set A (face/fake)', 'Set B (face/fake)'
        face_stem, stim_stem = 'setA_null_predictions', 'setB_null_predictions'

    # Save to .npy and .mat; the .mat variable name equals the file stem.
    for stem, arr in ((face_stem, face_np), (stim_stem, stim_np)):
        np.save(f'{stem}.npy', arr)
        scipy.io.savemat(f'{stem}.mat', {stem: arr})

    # measure subspace overlap
    U_face = PCA(n_components=sub_pcs).fit(face_np).components_.T
    U_stim = PCA(n_components=sub_pcs).fit(stim_np).components_.T

    overlap = np.linalg.norm(U_face.T @ U_stim, 'fro') ** 2
    print(f"Subspace overlap (top {sub_pcs} PCs) = {overlap:.4f}")

    # Angle between the leading PCs. A PC's sign is arbitrary, so use |cos|
    # to report the angle between the two axes (0..90 deg), and clip so that
    # rounding can never push arccos out of its domain.
    face_pc1 = U_face[:, 0]
    stim_pc1 = U_stim[:, 0]
    denom = np.linalg.norm(face_pc1) * np.linalg.norm(stim_pc1) + 1e-12
    cos_angle = abs(np.dot(face_pc1, stim_pc1)) / denom
    angle_deg = np.degrees(np.arccos(np.clip(cos_angle, 0.0, 1.0)))
    print(f"Angle between top1 PC: {angle_deg:.2f} deg")

    # 2D PCA of both sets together
    all_data = np.concatenate([face_np, stim_np], axis=0)
    all_2d = PCA(n_components=2).fit_transform(all_data)
    M_ = face_np.shape[0]

    plt.figure(figsize=(6,5))
    plt.scatter(all_2d[:M_, 0], all_2d[:M_, 1], alpha=0.6, label=label_face)
    plt.scatter(all_2d[M_:, 0], all_2d[M_:, 1], alpha=0.6, label=label_stim)
    plt.title("Face vs. Stim Subspace" if not null_hypothesis_mode else "Null Hypothesis (Face vs. Face)")
    plt.legend()
    plt.show()


def _baseline_subtracted_response(model, ds, sample_idxs, modify_input, y_base_norm, device):
    """
    Predict on modified copies of the chosen samples, relative to a baseline.

    Each ``ds[i][0]`` is cloned and passed through ``modify_input``; the model
    output minus ``y_base_norm`` is then sent through ``invert_normalization``
    and returned as a NumPy array of shape (len(sample_idxs), n_neurons).

    NOTE(review): ``invert_normalization`` adds ``ds.means`` back after the
    baseline subtraction, so the result is "scaled difference + mean" rather
    than a pure raw-domain difference. PCA mean-centres its input, so the
    subspace measures are unaffected, but the saved arrays carry this offset.
    """
    inputs = []
    for idx in sample_idxs:
        x_raw, _ = ds[idx]
        inputs.append(modify_input(x_raw.clone()))
    x_all = torch.stack(inputs, dim=0).to(device)

    with torch.no_grad():
        y_norm = model(x_all)
    y_norm = y_norm - y_base_norm
    return invert_normalization(y_norm, ds).cpu().numpy()

In [None]:
# Face-vs-stimulus comparison: 1000 face-driven and 500 stim-driven samples,
# overlap measured on the top 3 PCs of each set. `model`, `ds_full` and
# `device` must already exist from earlier cells. The call prints the overlap
# and angle, shows the scatter plot and writes the prediction files to disk.
analyze_face_stim_orthogonality(model, ds_full, 
                                    M=1000, K=500,
                                    sub_pcs=3,
                                    device=device)

In [None]:
import numpy as np
import torch

def invert_normalization(pred_norm, ds):
    """
    Map normalized predictions back to the raw domain.

    Uses the statistics stored on the dataset: ``raw = pred_norm * ds.stds +
    ds.means``, with both statistics first moved to ``pred_norm``'s device
    and broadcast over the leading (batch) axis.
    """
    where = pred_norm.device
    per_neuron_std = ds.stds.to(where)
    per_neuron_mean = ds.means.to(where)
    return pred_norm * per_neuron_std + per_neuron_mean


def compute_average_face(ds_full):
    """
    Average the face part of the input over every sample in ``ds_full``.

    ``ds_full[i]`` returns ``(x_in, y_in)``; the first 500 entries of ``x_in``
    are taken as the face vector. Returns a tensor of shape (500,).
    """
    n_samples = len(ds_full)
    # Index explicitly: the dataset is only required to support len / [].
    face_rows = [ds_full[k][0][:500] for k in range(n_samples)]
    return torch.stack(face_rows, dim=0).mean(dim=0)


def predict_stim_no_face_motion(model, ds_full, device='cuda'):
    """
    Predict f(stimulus, no-face-motion) for each of the 4 stimulus conditions.

    "No face motion" is the dataset-wide average face from
    ``compute_average_face``; the stimulus part is a one-hot code, one row
    per condition. Predictions are converted to the raw domain and returned
    as a NumPy array of shape (n_neurons, 4), one column per condition.
    """
    model.eval()
    mean_face = compute_average_face(ds_full).to(device)  # (500,)

    # Rows: [mean face | one-hot stimulus], i.e. a (4, 504) batch.
    batch = torch.zeros(4, 504, dtype=torch.float32, device=device)
    batch[:, :500] = mean_face
    for cond in range(4):
        batch[cond, 500 + cond] = 1.0

    with torch.no_grad():
        normalized = model(batch)                          # (4, n_neurons)
        raw = invert_normalization(normalized, ds_full)    # (4, n_neurons)

    # Columns index the stimulus condition.
    return raw.cpu().numpy().T


def predict_no_stim_face_motion(model, eigenface_evoked, ds_full, device='cuda'):
    """
    Predict f(no-stimulus, face-motion) for every evoked face column.

    Args:
        model: network mapping (B, 504) to (B, n_neurons).
        eigenface_evoked: array of shape (500, n_cols, n_conditions), e.g.
            (500, 1000, 4). Every column of every condition is used; the
            counts are read from the array rather than assumed.
        ds_full: dataset whose ``means`` / ``stds`` are used to undo the
            normalization.
        device: device used for model evaluation.

    Returns:
        ndarray of shape (n_neurons, n_conditions * n_cols). Columns are
        ordered condition by condition, then by column within a condition.

    Raises:
        ValueError: if ``eigenface_evoked`` is not 3-dimensional.
    """
    model.eval()

    face_dim, n_stim = 500, 4  # input layout: [face(500), stim(4)]

    faces = np.asarray(eigenface_evoked)
    if faces.ndim != 3:
        raise ValueError(
            f"eigenface_evoked must be 3-D (face, column, condition), got shape {faces.shape}"
        )
    n_face, n_cols, n_conditions = faces.shape

    # (face, col, cond) -> (cond, col, face) -> one face vector per row,
    # in the same order a "for cond: for col:" loop would produce.
    face_rows = np.ascontiguousarray(
        faces.transpose(2, 1, 0).reshape(n_conditions * n_cols, n_face),
        dtype=np.float32,
    )

    # Stimulus entries stay zero => "no stimulus".
    face_motion_tensor = torch.zeros(face_rows.shape[0], face_dim + n_stim, dtype=torch.float32)
    face_motion_tensor[:, :face_dim] = torch.from_numpy(face_rows)
    face_motion_tensor = face_motion_tensor.to(device)

    with torch.no_grad():
        pred_norm = model(face_motion_tensor)                 # (rows, n_neurons)
        pred_raw = invert_normalization(pred_norm, ds_full)   # (rows, n_neurons)

    # transpose => (n_neurons, rows)
    pred_raw = pred_raw.cpu().numpy().T
    return pred_raw


def main(model, ds_full, eigenface_evoked, device='cuda'):
    """
    Run both counterfactual predictions, report their shapes and save them.

    Writes ``pred_stim_no_face_motion.npy`` and
    ``pred_no_stim_face_motion.npy`` to the working directory.
    """
    # f(stimulus, no-face-motion): expected shape (229, 4)
    stim_pred = predict_stim_no_face_motion(model, ds_full, device=device)
    print("stim_no_face shape:", stim_pred.shape)
    np.save("pred_stim_no_face_motion.npy", stim_pred)

    # f(no-stimulus, face-motion): expected shape (229, 4000)
    face_pred = predict_no_stim_face_motion(
        model, eigenface_evoked, ds_full, device=device
    )
    print("no_stim_face shape:", face_pred.shape)
    np.save("pred_no_stim_face_motion.npy", face_pred)


if __name__ == "__main__":
    # Entry point. `model`, `ds_full` and `eigenface_evoked` are not created
    # here; they must already be defined (training code / .mat loading).
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Produce and save both prediction sets.
    main(model, ds_full, eigenface_evoked, device=device)