In [4]:
# ============================================================
# 0. Imports & Config
# ============================================================

import os
import numpy as np
import pandas as pd

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report

from xgboost import XGBClassifier

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Adjust to your folder
DATA_DIR = "/kaggle/input/tcga-data/data/processed"

RNA_FILE   = "rnaseq_aligned.csv"          # rows = samples, cols = genes
METH_FILE  = "methylation_aligned.csv"  # rows = samples, cols = CpGs
CLIN_FILE  = "clinical_aligned.csv"     # rows = samples, cols includes stage_filled or stage_best

Using device: cpu


In [5]:
# ============================================================
# 1. Load data
# ============================================================

rna   = pd.read_csv(os.path.join(DATA_DIR, RNA_FILE), index_col=0)
meth  = pd.read_csv(os.path.join(DATA_DIR, METH_FILE), index_col=0)
clin  = pd.read_csv(os.path.join(DATA_DIR, CLIN_FILE), index_col=0)

print("RNA shape:       ", rna.shape)
print("Methyl shape:    ", meth.shape)
print("Clinical shape:  ", clin.shape)

RNA shape:        (1092, 60660)
Methyl shape:     (1092, 14165)
Clinical shape:   (1092, 7)


In [6]:
# ============================================================
# 2. Normalize tumor stage labels (I / II / III / IV)
# ============================================================

def normalize_stage(s):
    if pd.isna(s):
        return np.nan
    s = str(s).lower()
    if "iv" in s:
        return "IV"
    if "iii" in s:
        return "III"
    if "ii" in s:
        return "II"
    if "i" in s:
        return "I"
    return np.nan

if "stage_filled" in clin.columns:
    raw_stage = clin["stage_filled"]
elif "stage_best" in clin.columns:
    raw_stage = clin["stage_best"]
else:
    raise ValueError("No stage_filled or stage_best column in clinical_aligned.csv")

labels = raw_stage.apply(normalize_stage)
print("Stage distribution BEFORE cleaning:")
print(labels.value_counts(dropna=False))

Stage distribution BEFORE cleaning:
stage_best
II     620
III    249
I      179
NaN     24
IV      20
Name: count, dtype: int64


In [7]:
# ============================================================
# 3. Align samples & drop NaN stages
# ============================================================

common_samples = set(rna.index) & set(meth.index) & set(labels.index)
common_samples = sorted(common_samples)

rna    = rna.loc[common_samples]
meth   = meth.loc[common_samples]
labels = labels.loc[common_samples]

mask = labels.notna()
rna    = rna.loc[mask]
meth   = meth.loc[mask]
labels = labels.loc[mask]

print("\nAfter alignment & dropping NaN stages:")
print("RNA:", rna.shape, "Methyl:", meth.shape, "Labels:", labels.shape)
print(labels.value_counts())


After alignment & dropping NaN stages:
RNA: (1068, 60660) Methyl: (1068, 14165) Labels: (1068,)
stage_best
II     620
III    249
I      179
IV      20
Name: count, dtype: int64


In [8]:
# ============================================================
# 4. Preprocess RNA & methylation (impute + standardize)
# ============================================================

# Impute per-feature median
rna_imputed  = rna.apply(lambda col: col.fillna(col.median()), axis=0)
meth_imputed = meth.apply(lambda col: col.fillna(col.median()), axis=0)

# Standardize (per modality)
rna_scaler  = StandardScaler()
meth_scaler = StandardScaler()

X_rna  = rna_scaler.fit_transform(rna_imputed.values)   # (n_samples, n_genes)
X_meth = meth_scaler.fit_transform(meth_imputed.values) # (n_samples, n_cpgs)

print("\nScaled shapes:", X_rna.shape, X_meth.shape)


Scaled shapes: (1068, 60660) (1068, 14165)


In [9]:
# ============================================================
# 5. Contrastive dataset
# ============================================================

class PairedOmicsDataset(Dataset):
    """
    Holds paired RNA + methylation samples in the same order.
    """
    def __init__(self, X_rna, X_meth):
        assert X_rna.shape[0] == X_meth.shape[0]
        self.X_rna = X_rna.astype(np.float32)
        self.X_meth = X_meth.astype(np.float32)

    def __len__(self):
        return self.X_rna.shape[0]

    def __getitem__(self, idx):
        return self.X_rna[idx], self.X_meth[idx]

dataset = PairedOmicsDataset(X_rna, X_meth)
loader  = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)

In [10]:
# ============================================================
# 6. Encoders & contrastive loss (InfoNCE / NT-Xent style)
# ============================================================

class MLPEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, proj_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, proj_dim),
        )

    def forward(self, x):
        z = self.net(x)
        z = F.normalize(z, dim=1)  # L2 normalize embeddings
        return z


def contrastive_loss(z1, z2, temperature=0.1):
    """
    Symmetric InfoNCE loss for two views (z1, z2) of same batch.
    z1, z2: (batch_size, dim)
    """
    batch_size = z1.size(0)
    # Cosine similarity matrix: (2B, 2B)
    z = torch.cat([z1, z2], dim=0)               # (2B, d)
    sim = F.cosine_similarity(z.unsqueeze(1),    # (2B,1,d)
                              z.unsqueeze(0),    # (1,2B,d)
                              dim=2)             # (2B,2B)

    # Mask out self-similarity
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    sim.masked_fill_(mask, -9e15)

    # Similarities scaled by temperature
    sim = sim / temperature

    # For each i in [0..2B), the positive is:
    # i <-> i + B  (for first view) or i <-> i - B (for second view)
    targets = torch.cat([torch.arange(batch_size, 2 * batch_size),
                         torch.arange(0, batch_size)]).to(z.device)

    loss = F.cross_entropy(sim, targets)
    return loss

In [11]:
# ============================================================
# 7. Initialize encoders & train contrastive model
# ============================================================

input_dim_rna  = X_rna.shape[1]
input_dim_meth = X_meth.shape[1]
proj_dim       = 64

encoder_rna  = MLPEncoder(input_dim_rna,  hidden_dim=256, proj_dim=proj_dim).to(device)
encoder_meth = MLPEncoder(input_dim_meth, hidden_dim=256, proj_dim=proj_dim).to(device)

params = list(encoder_rna.parameters()) + list(encoder_meth.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=1e-5)

n_epochs = 100
temperature = 0.1

encoder_rna.train()
encoder_meth.train()

for epoch in range(1, n_epochs + 1):
    epoch_loss = 0.0
    for batch_rna, batch_meth in loader:
        batch_rna  = batch_rna.to(device)
        batch_meth = batch_meth.to(device)

        z_rna  = encoder_rna(batch_rna)
        z_meth = encoder_meth(batch_meth)

        loss = contrastive_loss(z_rna, z_meth, temperature=temperature)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * batch_rna.size(0)

    epoch_loss /= len(dataset)
    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch {epoch:3d}/{n_epochs}  Contrastive loss: {epoch_loss:.4f}")

Epoch   1/100  Contrastive loss: 2.9937
Epoch  10/100  Contrastive loss: 0.0642
Epoch  20/100  Contrastive loss: 0.0321
Epoch  30/100  Contrastive loss: 0.0310
Epoch  40/100  Contrastive loss: 0.0113
Epoch  50/100  Contrastive loss: 0.0457
Epoch  60/100  Contrastive loss: 0.0189
Epoch  70/100  Contrastive loss: 0.0135
Epoch  80/100  Contrastive loss: 0.0121
Epoch  90/100  Contrastive loss: 0.0711
Epoch 100/100  Contrastive loss: 0.0221


In [12]:
# ============================================================
# 8. Extract contrastive embeddings for all samples
# ============================================================

encoder_rna.eval()
encoder_meth.eval()

with torch.no_grad():
    X_rna_tensor  = torch.tensor(X_rna,  dtype=torch.float32).to(device)
    X_meth_tensor = torch.tensor(X_meth, dtype=torch.float32).to(device)

    z_rna_all  = encoder_rna(X_rna_tensor)   # (n_samples, proj_dim)
    z_meth_all = encoder_meth(X_meth_tensor) # (n_samples, proj_dim)

    z_rna_all  = z_rna_all.cpu().numpy()
    z_meth_all = z_meth_all.cpu().numpy()

# Simple fusion: average RNA + methyl embeddings
z_contrastive = (z_rna_all + z_meth_all) / 2.0
z_contrastive = np.nan_to_num(z_contrastive)

print("Contrastive embedding shape:", z_contrastive.shape)

Contrastive embedding shape: (1068, 64)


In [13]:
# ============================================================
# 9. XGBoost classifier on contrastive embeddings (tumor stage)
# ============================================================

def evaluate_xgb(X, labels, name):
    mask = labels.notna()
    X = X[mask]
    y = labels[mask]

    le = LabelEncoder()
    y_enc = le.fit_transform(y)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y_enc,
        test_size=0.2,
        random_state=SEED,
        stratify=y_enc
    )

    clf = XGBClassifier(
        n_estimators=400,
        max_depth=4,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_lambda=2.0,
        objective="multi:softprob",
        eval_metric="mlogloss",
        tree_method="hist",
        random_state=SEED
    )

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    f1 = f1_score(y_test, y_pred, average="weighted")
    print(f"\n=== {name} ===")
    print("Weighted F1:", round(f1, 4))

    test_classes = np.unique(y_test)
    print(
        classification_report(
            y_test,
            y_pred,
            labels=test_classes,
            target_names=le.inverse_transform(test_classes),
            zero_division=0
        )
    )

    return clf, f1

xgb_model, f1 = evaluate_xgb(z_contrastive, labels, "Contrastive RNA+Meth embedding (XGB)")

print("\nFinal contrastive model F1:", round(f1, 4))


=== Contrastive RNA+Meth embedding (XGB) ===
Weighted F1: 0.4372
              precision    recall  f1-score   support

           I       0.12      0.03      0.05        36
          II       0.57      0.83      0.68       124
         III       0.23      0.12      0.16        50
          IV       0.00      0.00      0.00         4

    accuracy                           0.51       214
   macro avg       0.23      0.24      0.22       214
weighted avg       0.41      0.51      0.44       214


Final contrastive model F1: 0.4372
