In [2]:
import torch
import numpy as np
import pandas as pd
import os
import sys
import json

sys.path.append(os.path.join(sys.path[0], '../'))
from models import manager_for_sagittarius
from evaluation import initialize_experiment
from EvoDevo import utils
from config import EVO_DEVO_DATA_LOC

In [3]:
# Target device for the whole notebook: the loaded data, the gene mask, the
# model manager and the simulated time grid below are all placed on it.
device = 'cpu'

# Load the data

In [4]:
# Fix the random seed before any data handling (presumably so that the
# shuffle below is reproducible -- confirm against initialize_experiment).
initialize_experiment.initialize_random_seed(0)

# Category vocabularies in alphabetical order. Integer labels are decoded
# later by indexing into these lists.
species = sorted(['Chicken', 'Rat', 'Mouse', 'Rabbit', 'Opossum', 'RhesusMacaque', 'Human'])
organs = sorted(['Brain', 'Cerebellum', 'Liver', 'Heart', 'Kidney', 'Ovary', 'Testis'])
S, O = len(species), len(organs)

spec_vec_long, org_vec_long, expr_vec, ts_vec, mask_vec = utils.load_all_data(
    device=device, verbose=False)
M = expr_vec.shape[2]  # size of the last (gene) axis before filtering

# shuffle the data!
expr_vec, spec_vec_long, org_vec_long, ts_vec, mask_vec = utils.shuffle_data(
    expr_vec, spec_vec_long, org_vec_long, ts_vec, mask_vec)

# Gene mask read from a text file; only genes with a non-zero entry are kept.
gene_mask_values = np.loadtxt('trained_models/Sagittarius_paper_genemask.txt')
non_stationary_mask = torch.tensor(gene_mask_values).to(device)
keep_gene = non_stationary_mask.view(-1).bool()
expr_vec = expr_vec[:, :, keep_gene]  # filter along the last axis only
N, T, M_new = expr_vec.shape

# Keep only the first column of each label tensor (one label per sample),
# then make sure everything lives on the target device.
spec_vec_long = spec_vec_long[:, 0].to(device)  # N
org_vec_long = org_vec_long[:, 0].to(device)  # N
expr_vec = expr_vec.to(device)
ts_vec = ts_vec.to(device)
mask_vec = mask_vec.to(device)

FileNotFoundError: [Errno 2] No such file or directory: 'EvoDevo/dataset/attribute.txt'

# Load the model

In [4]:
def load_config_file(config_path='model_config_files/Sagittarius_config.json'):
    """Read a JSON model-configuration file and return its parsed contents.

    Args:
        config_path: Path of the JSON file to read. Defaults to the
            Sagittarius config used by this notebook, so existing
            no-argument calls behave exactly as before.

    Returns:
        The object produced by ``json.load`` for that file.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)

In [5]:
# Extrapolation experiment: start from a fixed seed.
initialize_experiment.initialize_random_seed(0)

# Build the model manager from the JSON hyper-parameters. It is given the
# filtered gene count, two categorical variables (species and organ) with
# their vocabulary sizes, and the time range [0, T].
model_config = load_config_file()
sagittarius_manager = manager_for_sagittarius.Sagittarius_Manager(
    M_new, 2, [S, O], **model_config, minT=0, maxT=T,
    device=device, train_transfer=False)

# NOTE(review): with reload=True and an `mfile`, this presumably restores the
# saved weights instead of training from scratch -- confirm in train_model.
category_labels = [spec_vec_long, org_vec_long]
sagittarius_manager.train_model(
    expr_vec, ts_vec, category_labels, mask_vec,
    reload=True, mfile='trained_models/Sagittarius_full_dataset_model.pth')

# Start the extrapolation

In [6]:
# Time points to generate at: 0 up to (but excluding) 30 in steps of 0.05.
simulation_grid = np.arange(0, 30, 0.05)
ts_to_simulate = torch.tensor(simulation_grid).to(device)

# Number of latent-space samples drawn for each generation.
gen_k = 10

In [7]:
def _repeat_along_time(labels, n_steps):
    """Stack `n_steps` copies of `labels` along a new axis 1."""
    return torch.stack([labels] * n_steps, dim=1)


# Labels repeated once per observed time step (T columns).
stacked_species = _repeat_along_time(spec_vec_long, T)
stacked_organs = _repeat_along_time(org_vec_long, T)

# Labels repeated once per simulated time point.
n_simulated = len(ts_to_simulate)
gen_species = _repeat_along_time(spec_vec_long, n_simulated)
gen_organs = _repeat_along_time(org_vec_long, n_simulated)

In [8]:
sim_labels = []
sim_expr = []

initialize_experiment.initialize_random_seed(0)

# The time grid is identical for every sample, so convert it to Python floats
# once instead of calling `.item()` on every element for every sample.
sim_times = ts_to_simulate.tolist()

# Generation is inference only. Without no_grad() every generated tensor keeps
# its autograd graph alive, which costs memory for no benefit.
with torch.no_grad():
    for i in range(N):
        # One label row (species, organ, time) per simulated time point.
        # NOTE(review): mixing str and float in one numpy array turns the time
        # into a string as well -- confirm downstream readers expect that.
        species_name = species[spec_vec_long[i].item()]
        organ_name = organs[org_vec_long[i].item()]
        sim_labels.extend(
            np.asarray([species_name, organ_name, t]) for t in sim_times)

        # Every input is given a leading batch axis of size 1; the first
        # element of the first returned value is kept for this sample.
        gen, _, _ = sagittarius_manager.model.generate(
            expr_vec[i].unsqueeze(0), ts_vec[i].unsqueeze(0), ts_to_simulate.unsqueeze(0),
            [stacked_species[i].unsqueeze(0), stacked_organs[i].unsqueeze(0)],
            [gen_species[i].unsqueeze(0), gen_organs[i].unsqueeze(0)],
            mask_vec[i].unsqueeze(0), k=gen_k)
        sim_expr.append(gen[0])

  input = module(input)


In [9]:
# Collapse the per-sample lists into single arrays along a new leading axis.
sim_expr = torch.stack(sim_expr, dim=0)
sim_labels = np.stack(sim_labels, axis=0)

# Display both shapes as a sanity check.
(sim_labels.shape, sim_expr.shape)

((28800, 3), torch.Size([48, 600, 4533]))

# Create .h5ad result file

In [10]:
# Flatten the generated tensor so that each (sample, simulated time point)
# pair becomes one row with M_new columns.
x = sim_expr.view(-1, M_new).detach().cpu().numpy()  # NT x M_new

# One label row per expression row. Column slices replace three Python-level
# loops over every row; `.copy()` gives each entry its own contiguous array.
assert len(sim_labels) == len(x), 'label rows and expression rows are misaligned'
obsm = {'species': sim_labels[:, 0].copy(),
        'organs': sim_labels[:, 1].copy(),
        'time': sim_labels[:, 2].copy()}

# Gene index -> gene id table; the first file column is used as the index.
gene_mapping = pd.read_csv(
    EVO_DEVO_DATA_LOC + 'gene_orderings/idx_to_gene.txt',
    names=['gene'], index_col=0)

# NOTE(review): absolute, machine-specific path -- this cell fails anywhere
# else. Consider moving it into `config` next to EVO_DEVO_DATA_LOC.
name_mapping_file = '/data/addiewc/scRNA_disentangle/unseen_attribute_generation/lit_search/MOD/dated_models/figure_processing/name_mapping.txt'

# Map the id in the fifth tab-separated column to the symbol in the first.
name_mapping = {}
with open(name_mapping_file, 'r') as f:
    next(f, None)  # header row
    for line in f:  # stream the file instead of loading every line at once
        names = line.strip().split('\t')
        if len(names) < 5:  # missing ensembl id
            continue
        name_mapping[names[4]] = names[0]

# Symbols of the genes kept by the mask, in index order. The mask is converted
# to Python booleans once so the loop does no per-gene tensor indexing.
gene_is_kept = non_stationary_mask.view(-1).bool().tolist()
genes_included = [
    name_mapping[gene_mapping.loc[idx, 'gene']]
    for idx in range(M)
    if gene_is_kept[idx]
]
var = pd.DataFrame.from_dict({'gene': genes_included})

In [11]:
# Create the output directory if it is missing. `exist_ok=True` makes this a
# no-op when it already exists and avoids the check-then-create race.
os.makedirs('../simulated_datasets/', exist_ok=True)

In [12]:
import anndata

# Bundle the simulated matrix, its per-row labels and the gene table into one
# AnnData object and save it to disk.
# NOTE(review): the per-row labels are 1-D and are passed through `obsm`;
# AnnData conventionally keeps 1-D annotations in `obs` -- confirm that
# downstream readers rely on the current layout before changing it.
h5ad_path = '../simulated_datasets/simulated_EvoDevo.h5ad'
adata = anndata.AnnData(X=x, var=var, obsm=obsm)
adata.write(h5ad_path)

