In [3]:
import read_folders as rf
import numpy as np
import utility_functions as uf
import random
import read_specific as rs
import torch
import cox_loss as cl
import torch.nn as nn
from torch.utils.data import Dataset, IterableDataset
from torch.utils.data import DataLoader
import torch
from torch_geometric.nn import GraphConv, global_add_pool
from torch.nn import Linear
from torch import optim
from cox_loss import cox_loss_effron
import os
import re
import pandas as pd
import download_study as ds
from torch_geometric.data import Data, Batch
import sys

In [4]:
path, sources, urls = ds.download_study(name = 'msk_immuno_2019')

Using existing file tmb_mskcc_2018.tar.gz
Extracting temp/tmb_mskcc_2018.tar.gz
Extracting temp/tmb_mskcc_2018.tar.gz


In [5]:
data_dict = rf.read_files(path[0])

0 Samples not in mutation data
Genes missing from length dict: []
0 Samples not in SV data


In [6]:
def print_shapes(data, prefix=""):
    for key, value in data.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, list):
            continue
        if isinstance(value, dict):
            print_shapes(value, prefix=full_key)  # recursive call
        elif isinstance(value, np.ndarray):
            print(f"{full_key} shape: {value.shape}")
        else:
            print(f"{full_key} shape: Not a NumPy array, type is {type(value).__name__}")

print_shapes(data_dict)

max_muts shape: Not a NumPy array, type is int64
patient_array shape: (1661, 4)
os_array shape: (1661, 2)
mutations.mut_pos shape: (1661, 554, 12)
mutations.var_type shape: (1661, 554, 5, 12)
mutations.aa_sub shape: (1661, 554, 48, 12)
mutations.ns shape: (1661, 554, 12)
sv.chrom shape: (1661, 554, 46, 12)
sv.var_class shape: (1661, 554, 6, 12)


In [7]:
mut_pos = data_dict['mutations']['mut_pos']      # Shape (1661, 554, 12) -> Needs reshape
var_type = data_dict['mutations']['var_type']   # Shape (1661, 554, 5, 12)
aa_sub = data_dict['mutations']['aa_sub']      # Shape (1661, 554, 48, 12) 24 aa (ref and alt)
ns = data_dict['mutations']['ns']              # Shape (1661, 554, 12) -> Needs reshape
chrom = data_dict['sv']['chrom']              # Shape (1661, 554, 46, 12)  46 = 23 * 2 (ref and alt chromosome)
var_class = data_dict['sv']['var_class']       # Shape (1661, 554, 6, 12)

In [8]:
def merge_last_two_dims(x):
    # x.shape == (1661, 554, D, 12)
    n0, n1, D, C = x.shape
    return x.reshape(n0, n1, D*C)

aa_sub_flat    = merge_last_two_dims(aa_sub)     # → (1661, 554, 48*12  = 576)
chrom_flat     = merge_last_two_dims(chrom)      # → (1661, 554, 46*12  = 552)
var_class_flat = merge_last_two_dims(var_class)  # → (1661, 554,  6*12  =  72)
var_type_flat = merge_last_two_dims(var_type)

In [9]:
chrom_flat.shape

(1661, 554, 552)

In [10]:
clinical_data = data_dict['patient_array']
gene_list = data_dict['gene_list']
osurv_data = data_dict['os_array']

In [11]:
#Create a list of arrays in the desired concatenation order
arrays_to_concat = [
    mut_pos,  # Size 1 along axis 2
    var_type_flat,    # Size 5 along axis 2
    aa_sub_flat,      # Size 48 along axis 2
    ns,       # Size 1 along axis 2
    chrom_flat,       # Size 46 along axis 2
    var_class_flat    # Size 6 along axis 2
]
omics = np.concatenate(arrays_to_concat, axis=2)

In [13]:
print(omics.shape)

(1661, 554, 1284)


In [14]:
genes_to_keep_mask = np.any(omics != 0, axis=(0, 2))
print(len(genes_to_keep_mask))

554


In [15]:
# We will set a seed and split into training (80%), validation (10%) and testing (10%)
sample_list = data_dict['sample_list'] # get samples
random.seed(3)
sample_index = [i for i in range(len(sample_list))]
random.shuffle(sample_index)

In [16]:
ntrain = int(0.8*len(sample_list))
nval = int(0.1*len(sample_list))

train_set = sample_index[0:ntrain]
val_set = sample_index[ntrain:(ntrain+nval)]
test_set = sample_index[(ntrain+nval):]

In [17]:
omics_train = omics[train_set]
omics_test = omics[test_set]
omics_val = omics[val_set]

In [18]:
clin_train = clinical_data[train_set] 
clin_test = clinical_data[test_set]
clin_val = clinical_data[val_set]

In [19]:
osurv_train = osurv_data[train_set]
osurv_test = osurv_data[test_set]
osurv_val = osurv_data[val_set]

In [22]:
# Get the graph with the data. 
graph = rs.read_reactome(gene_list)
tokens = torch.tensor(np.arange(0,len(gene_list)))

edges_all = graph['edges_index_grn_act'] + graph['edges_index_grn_rep'] + graph['edges_index_ppi_act'] + graph['edges_index_ppi_inh'] + graph['edges_index_ppi_bin']
edges_index_torch = torch.tensor(edges_all)
edges_index_torch =  edges_index_torch.transpose(1,0)
tokens = torch.tensor(np.arange(0,len(gene_list)))

Using existing file FIsInGene_122921_with_annotations.txt.zip
Extracting temp/FIsInGene_122921_with_annotations.txt.zip


In [87]:
class gin_omics(nn.Module):

    def __init__(self, feats_in, feats_out, edge_index, max_tokens):
        super(gin_omics, self).__init__()
        """
        feats_in: Number of gene-level input features
        feats_out: Number of output features per gene
        edge_index: (2, num_edges) tensor defining gene-gene interactions
        max_tokens: Total number of genes
        """
        # Build adjacency matrix for gene aggregation
        self.adj = torch.zeros((max_tokens, max_tokens), dtype=torch.float)
        self.adj[edge_index[0], edge_index[1]] = 1.0
        self.adj.fill_diagonal_(1.0) # maintain self loop
            
        self.lin_1 = Linear(feats_in, feats_out, bias=False)
        self.lin_2 = Linear(feats_out, feats_out, bias=False)
        self.gene_weights = nn.Parameter(torch.randn(max_tokens, feats_in))
        self.act1 = nn.LeakyReLU()
        self.act2 = nn.LeakyReLU()

    def forward(self, x):
        # x: (B, G, F)

        # 1) per-gene, per-feature-type weights of shape (G, F):
        w = self.gene_weights            
        w = w.unsqueeze(0)  # (1, G, F)
            
        # 2) apply weights to each occurrence:
        x = x * w  # stays (1, G, F)                     

        x2 = torch.matmul(x.transpose(1, 2), self.adj)  # dot product
        x = x + x2.transpose(1, 2)                     

        x = self.lin_1(x)     
        x = self.act1(x)
        x = self.lin_2(x)
        x = self.act2(x)
        
        return x
        

In [86]:
class Net_omics(torch.nn.Module):
    def __init__(self, features_omics, features_clin, dim, max_tokens, edge_index, output = 21):
        super(Net_omics, self).__init__()
        self.gin1 = gin_omics(features_omics, dim, edge_index, max_tokens)
        self.gin2 = gin_omics(dim, dim, edge_index, max_tokens)
        self.gin3 = gin_omics(dim, 1, edge_index, max_tokens)
        self.linout = Linear(max_tokens, output, bias = False)
        self.max_tokens = max_tokens
        self.features_omics = features_omics
        self.tokens = torch.tensor(np.arange(0,max_tokens))
        self.linclin = Linear(features_clin, 1)
        self.lin3 = Linear(output, 5)
        
    def forward(self, omics, clin):
        # get the weights for the connections through kd.
        x = self.gin1(omics)
        x = self.gin2(x)
        x = self.gin3(x)
        x = torch.flatten(x, 1)
        x = self.linout(x)
        x2 = self.linclin(clin)
        x1 = self.lin3(x) + x2
        return x1

In [98]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net_omics(features_omics=omics_train.shape[2], features_clin=clin_train.shape[-1], dim=50, max_tokens=len(gene_list), edge_index=edges_index_torch, output=2).to(device)
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)

In [1]:
# Now we can make a training function!
def train_block(model, data):
    model.train()
    loss_all = 0
    c_index = 0
    j = 0
    for batch in data:

        batch = uf.move_batch_to_device(batch=batch, device=device) # wrapper for moving batch to gpu
        j += 1

        # set the gradients in the optimizer to zero.
        optimizer.zero_grad()
        
        # run the model
        pred = model(batch.omics, batch.clin)
        
        # calculate Cox' partial likelihood and get the autograd output.
        loss = cox_loss_effron(batch.osurv, pred)
        loss.backward()
        
        # update the loss_all object
        loss_all += loss.item()
        
        # update parameters in model.
        optimizer.step()
        
        # calculate concordance index
        c_index += cl.concordance_index(batch.osurv, pred)

    return loss_all/j, c_index/j

# We will make a testing function too. 
def evaluate_model(model, data):
    """
    Evaluate the model on a validation or test set.

    Args:
        model (nn.Module): The trained model.
        loader (DataLoader): DataLoader for validation or test data.

    Returns:
        Tuple (avg_loss, avg_c_index): Mean Cox loss and concordance index over batches.
    """
    model.eval()
    loss_all = 0
    c_index = 0
    num_batches = 0

    with torch.no_grad():
        for batch in data:
            batch = uf.move_batch_to_device(batch, device=device)
            num_batches += 1
            pred = model(batch.omics, batch.clin)
            loss = cox_loss_effron(batch.osurv, pred)
            loss_all += loss.item()
            c_index += cl.concordance_index(batch.osurv, pred)

    return loss_all / num_batches, c_index / num_batches


In [92]:
class CoxBatchDataset(IterableDataset):
    def __init__(self, osurv, clin, omics, batch_size=10, shuffle=True):
        # Convert to tensors if needed
        self.osurv = osurv if isinstance(osurv, torch.Tensor) else torch.tensor(osurv, dtype=torch.float)
        self.clin = clin if isinstance(clin, torch.Tensor) else torch.tensor(clin, dtype=torch.float)
        self.omics = omics if isinstance(omics, torch.Tensor) else torch.tensor(omics, dtype=torch.float)

        if self.osurv.shape[0] < batch_size:
            raise ValueError("Dataset size smaller than batch size")

        self.deads = torch.where(self.osurv[:, 1] == 1)[0].tolist()
        self.censored = torch.where(self.osurv[:, 1] == 0)[0].tolist()
        if not self.deads:
            raise ValueError("No uncensored events found")

        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        indices = self.deads + self.censored
        if self.shuffle:
            random.shuffle(self.deads)
            random.shuffle(self.censored)
            indices = self.deads + self.censored
            random.shuffle(indices)

        ptr = 0
        n = len(indices)

        while ptr < n:
            remaining = n - ptr
            current_batch_size = min(self.batch_size, remaining)

            remaining_deads = sum(1 for i in indices[ptr:] if i in self.deads)
            ratio = remaining_deads / remaining if remaining > 0 else 0

            j_d = np.random.binomial(current_batch_size, ratio)
            j_d = max(1, j_d)

            batch_indices = []
            dead_count = 0
            censored_count = 0

            for idx in indices[ptr:]:
                if idx in self.deads and dead_count < j_d:
                    batch_indices.append(idx)
                    dead_count += 1
                elif idx in self.censored and censored_count < (current_batch_size - j_d):
                    batch_indices.append(idx)
                    censored_count += 1
                if len(batch_indices) == current_batch_size:
                    break

            ptr += len(batch_indices)

            batch_data = [
                Data(
                    osurv=self.osurv[i].unsqueeze(0),
                    clin=self.clin[i].unsqueeze(0),
                    omics=self.omics[i].unsqueeze(0)
                ) for i in batch_indices
            ]

            if batch_data:  # only yield if the list isn't empty
                yield Batch.from_data_list(batch_data)

In [93]:
train_data = CoxBatchDataset(osurv_train, clin_train, omics_train, batch_size=10, shuffle=True)
val_data = CoxBatchDataset(osurv_val, clin_val, omics_val, batch_size=10, shuffle=True)

In [95]:
from tqdm import tqdm  # optional, for a nice progress bar

# Training parameters
epochs = 500

# Metrics tracking
ci_val = []
ci_train = []
loss_val = []
loss_train = []

# Initial evaluation before training
vloss, vci = evaluate_model(model, val_data)
tloss, tci = evaluate_model(model, train_data)

ci_val.append(float(vci))
ci_train.append(float(tci))
loss_val.append(float(vloss))
loss_train.append(float(tloss))

print(f"[Init] Train CI: {tci:.4f}, Loss: {tloss:.4f} | Val CI: {vci:.4f}, Loss: {vloss:.4f}")

# Training loop
for epoch in tqdm(range(1, epochs + 1), desc="Training"):
    # Training step
    tloss, tci = train_block(model, train_data)
    ci_train.append(tci)
    loss_train.append(tloss)

    # Validation step
    vloss, vci = evaluate_model(model, val_data)
    ci_val.append(vci)
    loss_val.append(vloss)

    # Print progress
    print(f"Epoch {epoch:03d} | Train CI: {tci:.4f}, Loss: {tloss:.4f} | Val CI: {vci:.4f}, Loss: {vloss:.4f}")

    # Optional: save best model
    # if vci == max(ci_val):
    #     torch.save(model.state_dict(), "best_model.pt")


[Init] Train CI: nan, Loss: nan | Val CI: 0.0000, Loss: nan


Training:   0%|          | 0/500 [00:19<?, ?it/s]


KeyboardInterrupt: 

In [68]:
x = torch.rand(len(gene_list), omics_train.shape[1]) 
print(x.shape)

torch.Size([554, 554])
