In [14]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import random
import os
print(os.getcwd())
# CONFIGURATION
# Training hyper-parameters.
SEED = 42
BATCH_SIZE = 8
EPOCHS = 100
LR = 1e-3
# Contrastive-space settings.
PROJ_DIM = 64     # Dimension of contrastive space
TEMPERATURE = 0.5
# Input workbooks; relative paths resolve against the working directory
# printed above.
GENE_FILE = '../../gene1_count.xlsx'
CHROM_FILE = '../../chrom1_count.xlsx'

# Seed every RNG in use (Python, NumPy, PyTorch) so runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
device = torch.device('cpu')

# DATA LOADING
# Each workbook holds one sheet per cancer type; only cancers with a sheet in
# both workbooks are used. Gene sheets whose name contains 'Frequently' are
# skipped (presumably summary sheets rather than per-cancer data — TODO confirm).
# Every numeric cell of a sheet is flattened (row-major) into that cancer's
# feature vector, so all sheets of one workbook must share the same layout.
# Row i of gene_X / chrom_X corresponds to cancers[i] (alphabetical order).


def _stack_view(feats, names, view):
    """Stack one view's per-cancer vectors into an (n_cancers, n_features) array.

    Args:
        feats: list of 1-D arrays, one per cancer.
        names: cancer (sheet) names, aligned with ``feats``.
        view: label used in error messages (e.g. ``'gene'``).

    Raises:
        ValueError: naming the offending sheets if the vectors differ in
            length (``np.stack`` would otherwise fail with an anonymous shape
            error), if there are no numeric cells at all, or if any value is
            NaN/inf (e.g. from blank cells), which would turn the loss into NaN.
    """
    lengths = {name: vec.size for name, vec in zip(names, feats)}
    if len(set(lengths.values())) != 1:
        raise ValueError(
            f"{view} sheets differ in number of numeric cells: {lengths}"
        )
    X = np.stack(feats)
    if X.shape[1] == 0:
        raise ValueError(f"{view} sheets contain no numeric columns")
    non_finite = [name for name, row in zip(names, X) if not np.isfinite(row).all()]
    if non_finite:
        raise ValueError(f"{view} sheets contain NaN/inf values: {non_finite}")
    return X


gene_xl = pd.ExcelFile(GENE_FILE)
chrom_xl = pd.ExcelFile(CHROM_FILE)
gene_sheets = [s for s in gene_xl.sheet_names if 'Frequently' not in s]
chrom_sheets = chrom_xl.sheet_names
cancers = sorted(set(gene_sheets) & set(chrom_sheets))
if not cancers:
    # Fail early with a readable message instead of np.stack's
    # "need at least one array to stack".
    raise ValueError(
        f"No sheet names shared by {GENE_FILE} and {CHROM_FILE}; nothing to load"
    )

gene_feats, chrom_feats = [], []
for c in cancers:
    df_g = gene_xl.parse(c).select_dtypes(include=np.number)
    df_c = chrom_xl.parse(c).select_dtypes(include=np.number)
    gene_feats.append(df_g.values.flatten())
    chrom_feats.append(df_c.values.flatten())

gene_X = _stack_view(gene_feats, cancers, 'gene')
chrom_X = _stack_view(chrom_feats, cancers, 'chrom')

class MultiViewCancerDataset(Dataset):
    """Paired two-view dataset: item i is ``(gene[i], chrom[i])`` as float32 tensors.

    Args:
        gene: array-like of gene-view features; the first axis indexes samples.
        chrom: array-like of chromosome-view features; the first axis indexes
            samples and must have the same length as ``gene``.

    Raises:
        ValueError: if the two views have different numbers of samples.
    """

    def __init__(self, gene, chrom):
        self.gene = torch.tensor(gene, dtype=torch.float32)
        self.chrom = torch.tensor(chrom, dtype=torch.float32)
        # Row i of each view must describe the same sample. A mismatch would
        # otherwise raise IndexError mid-epoch (chrom shorter) or silently drop
        # chrom rows (chrom longer), because __len__ is based on gene only.
        if len(self.gene) != len(self.chrom):
            raise ValueError(
                "gene and chrom must have the same number of samples, "
                f"got {len(self.gene)} and {len(self.chrom)}"
            )

    def __len__(self):
        return len(self.gene)

    def __getitem__(self, idx):
        return self.gene[idx], self.chrom[idx]

# Shuffled mini-batches of paired (gene, chrom) rows; the shuffle order follows
# the global torch RNG seeded with torch.manual_seed(SEED) in this cell.
dataloader = DataLoader(
    dataset=MultiViewCancerDataset(gene_X, chrom_X),
    shuffle=True,
    batch_size=BATCH_SIZE,
)

# Projection Heads operating directly on raw data
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping inputs to the contrastive space.

    Args:
        in_dim: number of input features.
        out_dim: size of the output projection.
        hidden_dim: width of the hidden layer. Defaults to ``in_dim`` (the
            original behaviour); note the first layer then holds
            ``in_dim * in_dim`` weights, which grows quickly for wide inputs.
    """

    def __init__(self, in_dim, out_dim=64, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = in_dim
        self.proj = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Project ``x`` of shape (..., in_dim) to shape (..., out_dim)."""
        return self.proj(x)

# One projection head per view, each sized to that view's flattened feature
# count and mapping into the shared PROJ_DIM-dimensional space. Built gene-first
# so parameter initialisation draws from the RNG in the same order as before.
proj_gene, proj_chrom = (
    ProjectionHead(view.shape[1], PROJ_DIM).to(device) for view in (gene_X, chrom_X)
)

def nt_xent_loss(z1, z2, temp=TEMPERATURE):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss for two views.

    Row i of ``z1`` and row i of ``z2`` form a positive pair; every other row of
    either view in the batch acts as a negative for them.

    Args:
        z1: projections of view 1, shape (batch, dim).
        z2: projections of view 2, same shape as ``z1``; row i must come from
            the same sample as row i of ``z1``.
        temp: softmax temperature dividing the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all 2 * batch anchors.

    Raises:
        ValueError: if ``z1``/``z2`` are not 2-D tensors of identical shape.
    """
    if z1.dim() != 2 or z1.shape != z2.shape:
        raise ValueError(
            f"z1 and z2 must be 2-D with identical shapes, got "
            f"{tuple(z1.shape)} and {tuple(z2.shape)}"
        )
    N = z1.size(0)
    # Row-wise L2 normalisation makes the dot product below exactly the cosine
    # similarity, without materialising the (2N, 2N, dim) broadcast that
    # F.cosine_similarity on unsqueezed inputs would need.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, dim)
    sim = (z @ z.T) / temp  # (2N, 2N)
    # A sample must never be scored against itself.
    mask = torch.eye(2 * N, device=z.device).bool()
    sim = sim.masked_fill(mask, -9e15)
    # Anchor i (view 1) is positive with row i + N (view 2), and vice versa.
    idx = torch.arange(N, device=z.device)
    labels = torch.cat([idx + N, idx], dim=0)
    return F.cross_entropy(sim, labels)

# A single Adam optimiser updates both projection heads jointly
# (gene-head parameters first, then chrom-head parameters).
optimizer = torch.optim.Adam(
    [*proj_gene.parameters(), *proj_chrom.parameters()],
    lr=LR,
)

# TRAINING LOOP
# Each step projects both views of a batch and pulls matching rows together
# with the NT-Xent loss. The printed epoch loss is the unweighted mean of the
# per-batch losses.
for epoch in range(1, EPOCHS + 1):
    for head in (proj_gene, proj_chrom):
        head.train()
    total_loss = 0.0
    for gv, cv in dataloader:
        gv, cv = gv.to(device), cv.to(device)

        # Project raw features straight into the contrastive space.
        z_g, z_c = proj_gene(gv), proj_chrom(cv)
        loss = nt_xent_loss(z_g, z_c)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    mean_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch}/{EPOCHS} Loss: {mean_loss:.4f}")

# SAVE EMBEDDINGS (in the aligned space)
# One embedding per cancer: the element-wise mean of its gene-view and
# chromosome-view projections, saved as an (n_cancers, PROJ_DIM) array.
# Switch the heads to eval mode for inference. This is a no-op for the current
# Linear/ReLU heads, but keeps the saved embeddings deterministic if dropout or
# batch-norm is ever added.
proj_gene.eval()
proj_chrom.eval()
with torch.no_grad():
    gene_all = torch.tensor(gene_X, dtype=torch.float32).to(device)
    chrom_all = torch.tensor(chrom_X, dtype=torch.float32).to(device)

    z_g = proj_gene(gene_all)
    z_c = proj_chrom(chrom_all)

    # NOTE(review): the NT-Xent loss only aligns the L2-normalised projections,
    # yet the raw (unnormalised) outputs are averaged here, so the view with the
    # larger output norm dominates the embedding. Consider normalising z_g/z_c
    # before averaging if that matches the intended "aligned space".
    embeds = ((z_g + z_c) / 2).cpu().numpy()
    np.save("cancer_embeddings_raw_proj.npy", embeds)

print("Saved embeddings to cancer_embeddings_raw_proj.npy")


/Users/yifandou/Library/Mobile Documents/com~apple~CloudDocs/Desktop/25Fall/Github code and data/Ablation_Study/No_TabNet
Epoch 1/100 Loss: 2.7993
Epoch 2/100 Loss: 2.5637
Epoch 3/100 Loss: 2.5361
Epoch 4/100 Loss: 2.5295
Epoch 5/100 Loss: 2.5278
Epoch 6/100 Loss: 2.5267
Epoch 7/100 Loss: 2.5266
Epoch 8/100 Loss: 2.5261
Epoch 9/100 Loss: 2.5252
Epoch 10/100 Loss: 2.5250
Epoch 11/100 Loss: 2.5251
Epoch 12/100 Loss: 2.5255
Epoch 13/100 Loss: 2.5251
Epoch 14/100 Loss: 2.5248
Epoch 15/100 Loss: 2.5250
Epoch 16/100 Loss: 2.5255
Epoch 17/100 Loss: 2.5248
Epoch 18/100 Loss: 2.5249
Epoch 19/100 Loss: 2.5263
Epoch 20/100 Loss: 2.5251
Epoch 21/100 Loss: 2.5250
Epoch 22/100 Loss: 2.5251
Epoch 23/100 Loss: 2.5256
Epoch 24/100 Loss: 2.5249
Epoch 25/100 Loss: 2.5250
Epoch 26/100 Loss: 2.5251
Epoch 27/100 Loss: 2.5248
Epoch 28/100 Loss: 2.5249
Epoch 29/100 Loss: 2.5251
Epoch 30/100 Loss: 2.5254
Epoch 31/100 Loss: 2.5257
Epoch 32/100 Loss: 2.5254
Epoch 33/100 Loss: 2.5252
Epoch 34/100 Loss: 2.5251
Epo