In [None]:
!pip install --quiet torch torchvision --disable-pip-version-check


In [None]:
import os, time, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity

# Root folder for every artifact this notebook writes. Defaults to the Google
# Drive location used on Colab; set the FMRI_WORKDIR environment variable to run
# somewhere else without editing code.
WORKDIR = os.environ.get('FMRI_WORKDIR', '/content/drive/MyDrive/fmri_fingerprint')
RESULTS_DIR = os.path.join(WORKDIR, 'results')
os.makedirs(RESULTS_DIR, exist_ok=True)

# Use the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


In [None]:
# --- Training hyperparameters --------------------------------------------------
batch_size, epochs, lr = 16, 120, 1e-3
embedding_dim = 64   # width of the encoder's (L2-normalised) output embedding
temperature = 0.1    # NT-Xent softmax temperature
topK = 500           # keep this many highest-ICC edges; None -> use all edges
seed = 42

# Seed the RNGs used downstream: numpy, and torch (weight init, shuffling, dropout).
np.random.seed(seed)
torch.manual_seed(seed)


In [None]:
# Data prep: obtain the two split-half feature matrices v1 / v2 (one row per
# subject, one column per edge). Reuse them if an earlier cell defined them,
# otherwise build them from Z1_arr / Z2_arr and the index pair iu.
try:
    v1, v2
except NameError:
    # Only a *missing name* means "not computed yet"; any other error while
    # building v1/v2 is a real bug and must surface as-is.
    try:
        Z1_arr, Z2_arr, iu
    except NameError as e:
        raise RuntimeError("Could not find v1/v2 nor Z1_arr/Z2_arr+iu. Run split-half cells first.") from e
    print("Constructing v1/v2 from Z1_arr/Z2_arr and iu.")
    # Z*_arr are indexed as (subject, row, col); iu is presumably the output of
    # np.triu_indices (upper-triangle edges) — confirm against the upstream cell.
    v1 = Z1_arr[:, iu[0], iu[1]].astype(np.float32)
    v2 = Z2_arr[:, iu[0], iu[1]].astype(np.float32)
else:
    print("Using existing v1/v2 variables.")

# Both views must describe the same subjects and the same edges.
if v1.ndim != 2 or v1.shape != v2.shape:
    raise ValueError(f"v1 and v2 must be 2-D arrays of equal shape, got {v1.shape} and {v2.shape}")
n_subj, n_edges = v1.shape
print(f"n_subj={n_subj}, n_edges={n_edges}")


In [None]:
# Optionally restrict both views to the topK edges with the highest score in
# `icc_vals` (one value per edge from an earlier cell, presumably intraclass
# correlation — confirm). Falls back to all edges if that fails.
# After this cell: v1_sel / v2_sel are the (possibly column-subset) views and
# use_idx holds the selected edge indices, or None when all edges are used.
use_idx = None
v1_sel, v2_sel = v1, v2
if 'icc_vals' in globals() and topK is not None:
    try:
        icc = np.asarray(icc_vals, dtype=np.float64).ravel()
        if icc.size != n_edges:
            raise ValueError(f"icc_vals has {icc.size} entries but v1/v2 have {n_edges} edges")
        # np.argsort sorts NaN to the end, so after reversing for descending order
        # NaN-ICC edges (e.g. zero-variance edges) would be picked *first*.
        # Rank them below every real score instead.
        icc_rank = np.where(np.isnan(icc), -np.inf, icc)
        order = np.argsort(icc_rank)[::-1]
        selected_idx = order[:min(topK, order.size)]
        v1_sel = v1[:, selected_idx]
        v2_sel = v2[:, selected_idx]
        use_idx = selected_idx
        print(f"Using top-{len(selected_idx)} ICC edges (selected by icc_vals).")
    except Exception as e:
        # Deliberate best-effort: selection is optional, so report and continue.
        print("Failed to use icc_vals selection; falling back to all edges.", e)
        use_idx = None
        v1_sel, v2_sel = v1, v2
else:
    print("Using all edges (no ICC-based selection).")


In [None]:
# Standardise each edge with ONE scaler fit on both views stacked together
# (2N x F: view-1 rows first, then view-2 rows), so the two views share the same
# per-edge mean/std. StandardScaler is imported in the top imports cell.
scaler = StandardScaler()
X_all_views = scaler.fit_transform(np.vstack([v1_sel, v2_sel]))
v1_norm = X_all_views[:n_subj]
v2_norm = X_all_views[n_subj:]


In [None]:
class SplitHalfPairs(Dataset):
    """Paired dataset over two views of the same subjects.

    Item ``i`` is ``(view1[i], view2[i], i)``; both views are stored as float32
    tensors. The two input arrays must have identical shapes.
    """

    def __init__(self, v1_array, v2_array):
        assert v1_array.shape == v2_array.shape
        self.v1, self.v2 = (torch.tensor(arr, dtype=torch.float32)
                            for arr in (v1_array, v2_array))
        self.N = self.v1.size(0)

    def __getitem__(self, idx):
        # Return the index too, so a batch can be traced back to its subjects.
        return self.v1[idx], self.v2[idx], idx

    def __len__(self):
        return self.N

dataset = SplitHalfPairs(v1_norm, v2_norm)
# drop_last=True avoids a trailing batch of size 1, which BatchNorm1d cannot
# normalise in train mode. But if the cohort is smaller than batch_size,
# drop_last would drop *every* batch and training would silently do nothing,
# so clamp the batch size to the dataset size.
if len(dataset) < 2:
    raise ValueError(f"Need at least 2 subjects to train (BatchNorm + contrastive loss), got {len(dataset)}")
effective_batch_size = min(batch_size, len(dataset))
if effective_batch_size != batch_size:
    print(f"batch_size={batch_size} > n_subj={len(dataset)}; using batch_size={effective_batch_size}.")
loader = DataLoader(dataset, batch_size=effective_batch_size, shuffle=True, drop_last=True)


In [None]:
class MLPEncoder(nn.Module):
    """MLP that maps a feature vector to an L2-normalised embedding.

    With the default arguments the network is::

        Linear(input_dim, 1024) -> BatchNorm1d -> ReLU -> Dropout(0.2)
        -> Linear(1024, 256) -> BatchNorm1d -> ReLU -> Dropout(0.2)
        -> Linear(256, emb_dim)

    followed by row-wise L2 normalisation, so dot products between outputs are
    cosine similarities.

    Args:
        input_dim: number of input features per sample.
        emb_dim: size of the output embedding.
        hidden_dims: widths of the hidden Linear/BN/ReLU/Dropout stages.
        dropout: dropout probability after each hidden stage.
    """

    def __init__(self, input_dim, emb_dim=64, hidden_dims=(1024, 256), dropout=0.2):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_features, width), nn.BatchNorm1d(width),
                       nn.ReLU(), nn.Dropout(dropout)]
            in_features = width
        layers.append(nn.Linear(in_features, emb_dim))
        # A single nn.Sequential named `net`, in the same layer order as before, keeps
        # state_dict keys (net.0.weight, net.1.running_mean, ...) checkpoint-compatible.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, input_dim) tensor to (batch, emb_dim) unit-norm rows."""
        return F.normalize(self.net(x), p=2, dim=1)

# Encoder input width = number of (possibly ICC-selected) edge features per view.
input_dim = v1_norm.shape[1]
model = MLPEncoder(input_dim, emb_dim=embedding_dim)
model = model.to(device)
# AdamW with decoupled weight decay over all encoder parameters.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr, weight_decay=1e-4)


In [None]:
def nt_xent_loss(z1, z2, temp=0.1):
    """NT-Xent (SimCLR-style) contrastive loss for a batch of paired embeddings.

    Rows ``z1[i]`` and ``z2[i]`` form the positive pair. Every other row of the
    2B-row concatenation ``[z1; z2]`` is a negative. Similarities are plain dot
    products divided by ``temp``, so rows are expected to be L2-normalised
    already (the encoder's forward pass does this).

    Args:
        z1: (B, D) embeddings of view 1.
        z2: (B, D) embeddings of view 2, row-aligned with ``z1``.
        temp: softmax temperature.

    Returns:
        Scalar tensor: mean negative log-probability of the positive, over all
        2B anchor rows.

    Raises:
        ValueError: if ``z1`` and ``z2`` differ in shape.
    """
    if z1.shape != z2.shape:
        raise ValueError(f"z1 and z2 must have the same shape, got {tuple(z1.shape)} and {tuple(z2.shape)}")
    B = z1.shape[0]
    z = torch.cat([z1, z2], dim=0)
    logits = (z @ z.T) / temp
    # Exclude self-similarity from the softmax. Use the dtype's most negative
    # finite value instead of a hard-coded -9e15, which overflows float16.
    self_mask = torch.eye(2 * B, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, torch.finfo(logits.dtype).min)
    # Anchor i (view 1) has its positive at i + B (view 2), and vice versa.
    targets = torch.cat([torch.arange(B, 2 * B, device=z.device),
                         torch.arange(0, B, device=z.device)])
    return F.cross_entropy(logits, targets)


In [None]:
# Contrastive training. Checkpoint whenever the epoch's mean training loss
# improves, and stop after `patience` epochs without improvement.
# NOTE(review): model selection uses *training* loss; there is no held-out set here.
best_loss = float('inf')
patience = 20
wait = 0
best_ckpt_path = os.path.join(RESULTS_DIR, 'siamese_encoder_best.pt')
start_time = time.time()
for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0.0
    n_seen = 0  # samples actually processed; drop_last can make this < len(dataset)
    for v1_batch, v2_batch, _ in loader:
        v1_batch = v1_batch.to(device)
        v2_batch = v2_batch.to(device)
        z1 = model(v1_batch)
        z2 = model(v2_batch)
        loss = nt_xent_loss(z1, z2, temp=temperature)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_n = v1_batch.shape[0]
        total_loss += loss.item() * batch_n
        n_seen += batch_n
    if n_seen == 0:
        # Without this guard avg_loss would be 0.0, which would be "best" and an
        # untrained model would be checkpointed and evaluated.
        raise RuntimeError("DataLoader yielded no batches (dataset smaller than batch_size "
                           "with drop_last=True?); nothing was trained.")
    avg_loss = total_loss / n_seen
    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch {epoch:03d}  avg_loss={avg_loss:.4f}  time_elapsed={(time.time()-start_time):.1f}s")
    if avg_loss < best_loss:
        best_loss = avg_loss
        wait = 0
        torch.save(model.state_dict(), best_ckpt_path)
    else:
        wait += 1
    if wait >= patience:
        print("Early stopping (patience reached).")
        break


In [None]:
# Reload the best checkpoint and embed every subject's two views.
# map_location lets a checkpoint written on GPU be reloaded on a CPU-only runtime.
model.load_state_dict(torch.load(os.path.join(RESULTS_DIR, 'siamese_encoder_best.pt'),
                                 map_location=device))
model.eval()  # BatchNorm uses running statistics; Dropout is disabled
with torch.no_grad():
    v1_tensor = torch.tensor(v1_norm, dtype=torch.float32).to(device)
    v2_tensor = torch.tensor(v2_norm, dtype=torch.float32).to(device)
    emb1 = model(v1_tensor).cpu().numpy()
    emb2 = model(v2_tensor).cpu().numpy()

# sim_mat[i, j] = cosine similarity of subject i (view 1) and subject j (view 2).
# A subject is identified when its most similar counterpart is itself.
sim_mat = cosine_similarity(emb1, emb2)
preds = sim_mat.argmax(axis=1)
accuracy = (preds == np.arange(n_subj)).mean()
# Reverse direction (view 2 -> view 1); fingerprinting is usually reported both ways.
accuracy_rev = (sim_mat.argmax(axis=0) == np.arange(n_subj)).mean()
# Round rather than truncate: int(0.57 * 100) == 56.
print(f"Siamese NN identification accuracy (embedding NN): {accuracy:.3f} ({accuracy * 100:.0f}%)")
print(f"  reverse direction (view2 -> view1): {accuracy_rev:.3f} ({accuracy_rev * 100:.0f}%)")

# Persist embeddings, the view1-by-view2 similarity matrix and a short text report.
np.save(os.path.join(RESULTS_DIR, 'siamese_emb_view1.npy'), emb1)
np.save(os.path.join(RESULTS_DIR, 'siamese_emb_view2.npy'), emb2)
np.save(os.path.join(RESULTS_DIR, 'siamese_sim_matrix.npy'), sim_mat)
report_lines = [
    f"device: {device}",
    f"n_subj: {n_subj}",
    f"n_edges_used: {v1_sel.shape[1]}",
    f"embedding_dim: {embedding_dim}",
    # Report how many edges were actually kept: selection clips topK to the
    # number of available edges, so topK itself can overstate it.
    f"topK_used (if any): {len(use_idx) if use_idx is not None else 'all'}",
    f"accuracy: {accuracy}",
]
with open(os.path.join(RESULTS_DIR, 'siamese_report.txt'), 'w') as f:
    f.write("\n".join(report_lines) + "\n")
print("Saved siamese artifacts to:", RESULTS_DIR)


In [None]:
# Optional 2-D PCA view of both embedding sets; a line joins each subject's two views.
# Only a missing plotting/PCA dependency is tolerated. Any other error is a real bug
# and must not be hidden.
try:
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA
except ImportError as e:
    print("Skipping PCA viz (matplotlib/PCA issue):", e)
else:
    # Fit one PCA on both views stacked together so they share the same axes.
    pca = PCA(n_components=2)
    Z12 = np.vstack([emb1, emb2])
    Zp = pca.fit_transform(Z12)
    Zp1, Zp2 = Zp[:n_subj], Zp[n_subj:]
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(Zp1[:, 0], Zp1[:, 1], label='view1', alpha=0.9)
    ax.scatter(Zp2[:, 0], Zp2[:, 1], label='view2', marker='x', alpha=0.9)
    for p1, p2 in zip(Zp1, Zp2):
        ax.plot([p1[0], p2[0]], [p1[1], p2[1]], 'k-', alpha=0.3)
    ax.set(title=f"Siamese embeddings (PCA 2D); NN acc={accuracy:.3f}", xlabel='PC1', ylabel='PC2')
    ax.legend()
    fig.savefig(os.path.join(RESULTS_DIR, 'siamese_embeddings_pca2d.png'), dpi=150, bbox_inches='tight')
    plt.show()
