In [31]:
import torch
import numpy as np
import pandas as pd
import os
import sys
import json
import warnings

sys.path.append(os.path.join(sys.path[0], '../'))
from models import manager_for_sagittarius
from evaluation import initialize_experiment
from TCGA import utils, filter_censored_patients, compute_non_stationary_genes
from config import TCGA_DATA_LOC

In [16]:
device = 'cuda:1'

# Load the data

In [17]:
initialize_experiment.initialize_random_seed(0)

highly_mutated = True  # check the gene mask that we should use
ct_vec, muts, ts, mask, _, ct_mapping, gmask, censoring = utils.load_all_mutation_data(
    remove_censored_data=False, restrict_to_highly_variable=highly_mutated)
N, T, M = muts.shape
C = len(ct_mapping)

ct_vec = ct_vec[:, 0].to(device)
muts = muts.to(device)
ts = ts.to(device)
mask = mask.to(device)
gmask = gmask.to(device)
censoring = censoring.to(device)

cleaner = filter_censored_patients.filter_cancer_type_time_series(
    muts, ts, censoring, mask, ct_vec, ct_mapping, load_from_file=True).to(device)
mask = mask * cleaner  # filter censored patients
maxT = torch.max(torch.masked_select(ts, mask.bool())).item()

0it [00:00, ?it/s]

...ACC had 0 patients
...BLCA had 391 patients


2it [00:00,  7.34it/s]

...BRCA had 973 patients


4it [00:01,  3.01it/s]

...CESC had 194 patients
...CHOL had 32 patients
...COAD had 366 patients


6it [00:01,  3.48it/s]

...COADREAD had 0 patients
...DLBC had 40 patients
...GBM had 277 patients


9it [00:02,  5.38it/s]

...GBMLGG had 0 patients
...HNSC had 509 patients


11it [00:02,  5.41it/s]

...KICH had 64 patients
...KIPAN had 0 patients
...KIRC had 436 patients


14it [00:02,  6.87it/s]

...KIRP had 0 patients
...LGG had 510 patients


17it [00:03,  6.49it/s]

...LIHC had 198 patients
...LUAD had 471 patients


18it [00:03,  5.15it/s]

...LUSC had 173 patients
...OV had 459 patients


20it [00:03,  5.11it/s]

...PAAD had 145 patients
...PCPG had 0 patients
...PRAD had 330 patients


24it [00:04,  6.88it/s]

...READ had 121 patients
...SARC had 244 patients
...STES had 285 patients


26it [00:04,  6.50it/s]

...TGCT had 130 patients
...THCA had 401 patients


31it [00:05,  6.10it/s]


...UCEC had 248 patients
...UCS had 0 patients
...UVM had 68 patients


# Load the model

In [18]:
def load_config_file():
    with open('model_config_files/Sagittarius_config.json', 'r') as f:
        return json.load(f)

In [19]:
initialize_experiment.initialize_random_seed(0)

sagittairus_manager = manager_for_sagittarius.Sagittarius_Manager(
    M, 1, [C], **load_config_file(), minT=0, maxT=maxT, device=device, train_transfer=False, 
    rec_loss='bce', batch_size=2)

sagittairus_manager.train_model(
    muts, ts, [ct_vec], mask, reload=True, mfile='trained_models/full_TCGA_model.pth')

# Start the extrapolation

In [20]:
survival_times_to_simulate = torch.tensor(np.arange(100, 300, 10)).to(device)
gen_k = 10

In [24]:
sim_cancer_types = []
sim_survival_times = []
sim_expr = []

initialize_experiment.initialize_random_seed(0)

for i in range(N):
    sim_cancer_types.extend([ct_mapping[ct_vec[i].item()] for _ in range(len(survival_times))])
    sim_survival_times.append(survival_times_to_simulate.detach().cpu().numpy())
    
    gen, _, _ = sagittairus_manager.model.generate(
        muts[i].unsqueeze(0).float(), ts[i].unsqueeze(0).float(), survival_times_to_simulate.unsqueeze(0).float(),
        [torch.stack([ct_vec[i] for _ in range(T)]).unsqueeze(0)],
        [torch.stack([ct_vec[i] for _ in range(len(survival_times_to_simulate))]).unsqueeze(0)],
        mask[i].unsqueeze(0), k=gen_k)
    sim_expr.append(gen[0])
    
sim_survival_times = np.concatenate(sim_survival_times)
sim_expr = torch.stack(sim_expr).view(-1, 1000).detach().cpu().numpy()

  input = module(input)


# Create .h5ad file for results

In [41]:
x = sim_expr
obsm = {'cancer type': np.asarray([sim_cancer_types]),
        'survival time': np.asarray([sim_survival_times])}

with open(TCGA_DATA_LOC + 'geneNum_to_idx_mapping.txt', 'r') as f:
    name2idx = json.load(f)
k = 0
for base in name2idx:
    if k > 5:
        k += 1
        break
    k += 1
idx2name = {idx: name for name, idx in name2idx.items()}

mapping = {}
with open(TCGA_DATA_LOC + 'name_mapping.txt', 'r') as f:
    for idx, line in enumerate(f.readlines()):
        if idx == 0:
            continue  # header line!
        line_parts = line.split('\t')
        if len(line_parts) < 5:
            continue
        # line_parts[0] = approved symbol
        # line_parts[3] = NCBI gene id
        if line_parts[3] == '':  # didn't have it for this gene
            continue
        mapping[line_parts[3]] = line_parts[0]
        
remaining_gene_by_idx = {}  # gene -> idx
gene_listing = []
for m in range(len(gmask)):
    if gmask[m] == 0:
        continue  # we didn't keep it!
    ncbi_id = idx2name[m]
    if ncbi_id not in mapping:
        warnings.warn('\tNo NCBI entry for {}; depricated id?'.format(ncbi_id))
        gene_name = ncbi_id
    else:
        gene_name = mapping[ncbi_id]
    remaining_gene_by_idx[gene_name] = np.count_nonzero(gmask[:m].detach().cpu().numpy())
    gene_listing.append(gene_name)
idx_to_gene_mapping = {remaining_gene_by_idx[gene]: gene for gene in remaining_gene_by_idx}

genes_included = [idx_to_gene_mapping[g] for g in range(1000)]
var = pd.DataFrame.from_dict({'gene': genes_included})



In [42]:
if not os.path.exists('../simulated_datasets/'):
    os.makedirs('../simulated_datasets/')

In [43]:
import anndata

adata = anndata.AnnData(X=x, obsm=obsm, var=var)
adata.write('../simulated_datasets/simulated_TCGA.h5ad')



ValueError: Value passed for key 'cancer type' is of incorrect shape. Values of obsm must match dimensions (0,) of parent. Value had shape (1, 480) while it should have had (480,).