In [1]:
import os
import anndata as ad
import pandas as pd
import numpy as np
from omicsdgd import DGD

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
random_seed = 0
fraction_unpaired = 0.1

In [3]:
###
# load data
###
data_name = "human_bonemarrow"
adata = ad.read_h5ad("../../data/" + data_name + ".h5ad")
adata.X = adata.layers["counts"] # I seem to have to do it again

# train-validation-test split for reproducibility
# best provided as list [[train_indices], [validation_indices]]
train_val_split = [
    list(np.where(adata.obs["train_val_test"] == "train")[0]),
    list(np.where(adata.obs["train_val_test"] == "validation")[0]),
]

valset = adata[adata.obs["train_val_test"] == "validation"].copy()
valset.obs["modality"] = "paired"
testset = adata[adata.obs["train_val_test"] == "test"].copy()
testset.obs["modality"] = "paired"
#adata = None

In [4]:
testset

AnnData object with n_obs × n_vars = 6925 × 129921
    obs: 'GEX_pct_counts_mt', 'GEX_n_counts', 'GEX_n_genes', 'GEX_size_factors', 'GEX_phase', 'ATAC_nCount_peaks', 'ATAC_atac_fragments', 'ATAC_reads_in_peaks_frac', 'ATAC_blacklist_fraction', 'ATAC_nucleosome_signal', 'cell_type', 'batch', 'ATAC_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker', 'train_val_test', 'observable', 'covariate_Site', 'modality'
    var: 'feature_types', 'gene_id', 'modality'
    uns: 'ATAC_gene_activity_var_names', 'dataset_id', 'genome', 'organism'
    obsm: 'ATAC_gene_activity', 'ATAC_lsi_full', 'ATAC_lsi_red', 'ATAC_umap', 'GEX_X_pca', 'GEX_X_umap'
    layers: 'counts'

In [4]:
df_unpaired = pd.read_csv('../../data/'+data_name+'_unpairing.csv')

In [5]:
mod_1_indices = df_unpaired[
    (df_unpaired["fraction_unpaired"] == fraction_unpaired) & (df_unpaired["modality"] == "rna")
]["sample_idx"].values
mod_2_indices = df_unpaired[
    (df_unpaired["fraction_unpaired"] == fraction_unpaired) & (df_unpaired["modality"] == "atac")
]["sample_idx"].values
remaining_indices = df_unpaired[
    (df_unpaired["fraction_unpaired"] == fraction_unpaired) & (df_unpaired["modality"] == "paired")
]["sample_idx"].values

var_before = adata.var.copy()

In [6]:
#adata_unpaired = ad.read("../../data/"+data_name+"_unpaired-"+str(fraction_unpaired)+".h5ad")

adata_rna = adata[mod_1_indices, adata.var["feature_types"] == "GEX"].copy()
adata_rna.obs["modality"] = "GEX"
print("copied rna")
adata_atac = adata[mod_2_indices, adata.var["feature_types"] == "ATAC"].copy()
adata_atac.obs["modality"] = "ATAC"
print("copied atac")
adata_multi = adata[remaining_indices, :].copy()
adata_multi.obs["modality"] = "paired"
print("copied rest")
adata = None
print("freed some memory")
adata_unpaired = ad.concat([adata_multi, adata_rna, adata_atac], join="outer")
print("organized data")

adata_unpaired

copied rna
copied atac
copied rest
freed some memory
organized data


AnnData object with n_obs × n_vars = 56714 × 129921
    obs: 'GEX_pct_counts_mt', 'GEX_n_counts', 'GEX_n_genes', 'GEX_size_factors', 'GEX_phase', 'ATAC_nCount_peaks', 'ATAC_atac_fragments', 'ATAC_reads_in_peaks_frac', 'ATAC_blacklist_fraction', 'ATAC_nucleosome_signal', 'cell_type', 'batch', 'ATAC_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker', 'train_val_test', 'observable', 'covariate_Site', 'modality'
    obsm: 'ATAC_gene_activity', 'ATAC_lsi_full', 'ATAC_lsi_red', 'ATAC_umap', 'GEX_X_pca', 'GEX_X_umap'
    layers: 'counts'

In [10]:
adata_rna, adata_atac, adata_multi = None, None, None
adata_unpaired.var = var_before

In [7]:
adata_rna, adata_atac, adata_multi = None, None, None
print("freed memory")
#adata = adata_unpaired.concatenate(valset)
adata = ad.concat([adata_unpaired, valset], join="inner")
print("finished data")
adata.var = var_before

freed memory


: 

In [9]:
train_val_split = [
    list(np.where(adata.obs["train_val_test"] == "train")[0]),
    list(np.where(adata.obs["train_val_test"] == "validation")[0]),
]

AttributeError: 'NoneType' object has no attribute 'obs'

In [11]:
model = DGD.load(
    #data=adata[train_val_split[0]], 
    data=adata_unpaired,
    save_dir="../results/trained_models/" + data_name + "/", 
    model_name=data_name + "_l20_h2-3_rs" + str(random_seed)+"_mosaic"+str(fraction_unpaired)+"percent"
)

  [AnnData(sparse.csr_matrix(a.shape), obs=a.obs) for a in all_adatas],
  [AnnData(sparse.csr_matrix(a.shape), obs=a.obs) for a in all_adatas],



        Gaussian_mix_compture:
            Dimensionality: 2
            Number of components: 4
        
#######################
Training status
#######################
True


In [12]:
import torch
import numpy as np
from omicsdgd.latent import RepresentationLayer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from omicsdgd.functions._predict import prepare_potential_reps, find_new_component, reshape_scaling_factor

def learn_new_representations(
    gmm,
    decoder,
    data_loader,
    n_samples_new,
    correction_model=None,
    n_epochs=10,
    lrs=[0.01, 0.01],
    resampling_type="mean",
    resampling_samples=1,
    include_correction_error=True,
    indices_of_new_distribution=None,
    start_from_zero=[False,False],
    init_covariate_supervised=None,
    supervised=False,
    cov_beta=1
):
    """
    this function creates samples from the trained GMM for each new point,
    returns the best representation for each sample,
    and trains the representations with all remaining parameters fixed

    gmm: the trained GMM
    decoder: the trained decoder
    data_loader: the data loader for the data to be predicted
    n_samples_new: the number of samples in the new data
    correction_model: the trained correction model (if available)
    n_epochs: the number of epochs to train the representations
    lrs: the learning rates for the representations
    resampling_type: the type of resampling to use for the GMM (can be mean or sample)
    resampling_samples: the number of samples to draw from the GMM for each new point (is ignored for resampling_type='mean')
    """

    # check if there are unknown distributions in the data
    #if gmm.n_mix_comp < len(data_loader.dataset.meta.unique()):
    #    print("WARNING: there are unknown distributions in the data\nWill learn extra components in the GMM")
    #    print(gmm.n_mix_comp, len(data_loader.dataset.meta.unique())) # apparently site1 had ha whole cell type
    correction_hook = False
    #"""
    if correction_model is not None:
        if indices_of_new_distribution is not None:
            if correction_model.n_mix_comp < data_loader.dataset.correction_classes:
                print("WARNING: there are unknown distributions in the data\nWill learn extra components in the batch GMM")
                n_correction_classes_old = correction_model.n_mix_comp
                correction_model = find_new_component(
                    data_loader,
                    decoder,
                    correction_model,
                    indices_of_new_distribution,
                    other_gmm=gmm)
                correction_hook = True
    #"""

    # make temporary representations with samples from each component per data point
    if correction_model is not None:
        if not start_from_zero[1]:
            if init_covariate_supervised is not None:
                potential_reps = prepare_potential_reps(
                    [gmm.sample_new_points(resampling_type, resampling_samples), torch.zeros((resampling_samples, correction_model.dim))]
                )
            else:
                potential_reps = prepare_potential_reps(
                    [
                        gmm.sample_new_points(resampling_type, resampling_samples),
                        correction_model.sample_new_points(resampling_type, resampling_samples),
                    ]
                )
        else:
            potential_reps = prepare_potential_reps(
                [gmm.sample_new_points(resampling_type, resampling_samples), torch.zeros((resampling_samples, correction_model.dim))]
            )
    else:
        potential_reps = prepare_potential_reps(
            [gmm.sample_new_points(resampling_type, resampling_samples)]
        )

    #print("making potential reps")
    print("   all potential reps: ", potential_reps.shape)

    decoder.eval()

    ############################
    # first match potential reps to samples
    ############################
    print("calculating losses for each new sample and potential reps")
    # creating a storage tensor into which the best reps are copied
    rep_init_values = torch.zeros((n_samples_new, potential_reps.shape[-1]))
    print("   rep_init_values: ", rep_init_values.shape)
    # compute predictions for all potential reps
    predictions = decoder(potential_reps.to(device))
    # go through data loader to calculate losses batch-wise
    for x, lib, i in data_loader:
        x = x.unsqueeze(1).to(device)
        lib = lib.to(device)
        if data_loader.dataset.modality_switch is not None:
            recon_loss_x = decoder.loss(
                [predictions[comp].unsqueeze(0) for comp in range(len(predictions))],
                [x[:, :, : data_loader.dataset.modality_switch], x[:, :, data_loader.dataset.modality_switch :]],
                scale=[reshape_scaling_factor(lib[:, xxx], 3) for xxx in range(decoder.n_out_groups)],
                reduction="sample",
                mask=data_loader.dataset.get_mask(i)
            )
        else:
            recon_loss_x = decoder.loss(
                [predictions[0].unsqueeze(0)],
                [x],
                scale=[reshape_scaling_factor(lib[:, 0], 3)],
                reduction="sample",
                mask=data_loader.dataset.get_mask(i)
            )
        best_fit_ids = torch.argmin(recon_loss_x, dim=-1).detach().cpu()
        rep_init_values[i, :] = potential_reps.clone()[best_fit_ids, :]
    # print the counts of how often each component has been chosen
    print("   ", rep_init_values.mean(0).shape, rep_init_values.mean(-1).shape)
    print("   counts of how often each component has been chosen: ", np.unique(rep_init_values.mean(-1).numpy(), return_counts=True))

    ############################
    # create new initial representation
    ############################
    # create a new representation from the best components
    new_rep = RepresentationLayer(n_rep=gmm.dim, n_sample=n_samples_new, value_init=rep_init_values[:, : gmm.dim]).to(
        device
    )
    newrep_optimizer = torch.optim.Adam(new_rep.parameters(), lr=lrs[0], weight_decay=1e-4, betas=(0.5, 0.7))
    test_correction_rep = None
    if correction_model is not None:
        if (not start_from_zero[1]) and (init_covariate_supervised is None):
            test_correction_rep = RepresentationLayer(
                n_rep=2, n_sample=n_samples_new, value_init=rep_init_values[:, gmm.dim :]
            ).to(device)
        elif init_covariate_supervised is not None:
            if indices_of_new_distribution is not None:
                init_covariate_supervised = np.array(init_covariate_supervised)
                if (len(np.unique(init_covariate_supervised)) > correction_model.n_mix_comp):
                    raise NotImplementedError("I can currently only handle one new covariate class")
                #print(init_covariate_supervised[indices_of_new_distribution])
                # change the indices of the new distribution to the last component
                # count and print the number of unique values in the init covariate supervised, like with value_counts
                init_covariate_supervised[indices_of_new_distribution] = correction_model.n_mix_comp - 1
                #print(init_covariate_supervised[indices_of_new_distribution])
            with torch.no_grad():
                cov_means = torch.zeros((len(data_loader.dataset), correction_model.dim))
                for i in range(len(data_loader.dataset)):
                    cov_means[i,:] = correction_model.mean[init_covariate_supervised[i],:].clone().detach().cpu()
            test_correction_rep = RepresentationLayer(
                n_rep=2, n_sample=n_samples_new, value_init=cov_means
            ).to(device)
        else:
            test_correction_rep = RepresentationLayer(
                n_rep=2, n_sample=n_samples_new, value_init="zero"
            ).to(device)
        correction_rep_optim = torch.optim.Adam(
            test_correction_rep.parameters(), lr=lrs[0], weight_decay=1e-4, betas=(0.5, 0.7)
        )
        if correction_hook:
            correction_model_optim = torch.optim.Adam(
            correction_model.parameters(), lr=lrs[1], weight_decay=0, betas=(0.5, 0.7)
            )

    rep_init_values = None

    #####################
    # training reps (only)
    #####################
    """
    supervised_warmup = 20
    if supervised and (init_covariate_supervised is not None):
        # fine-tune the relevant neurons in the first layer of the decoder
        for param in decoder.parameters():
            param.requires_grad = False
        # I want to finetune the last neurons of the first layer, corresponding to the covariate model (correction_rep input)
        indices_covariate_input = np.arange(new_rep.n_rep, new_rep.n_rep + test_correction_rep.n_rep)
        # print the old weights
        #print("old weights: ", decoder.main[0].weight[:, indices_covariate_input])
        decoder.main[0].weight.requires_grad = True
        decoder.main[0].bias.requires_grad = True # did not exist for first version
        decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=lrs[0]/100, weight_decay=1e-4, betas=(0.5,0.7)) # was /100 for first version
    """
    print("training selected reps for ", n_epochs, " epochs")
    for epoch in range(n_epochs):
        newrep_optimizer.zero_grad()
        if correction_model is not None:
            correction_rep_optim.zero_grad()
        corr_loss = 0
        e_loss = 0
        recon_loss = 0
        for x, lib, i in data_loader:
            if correction_hook:
                correction_model_optim.zero_grad()
                #if supervised:
                #    decoder_optimizer.zero_grad()
            x = x.to(device)
            lib = lib.to(device)
            if correction_model is not None:
                z = new_rep(i)
                z_correction = test_correction_rep(i)
                y = decoder(torch.cat((z, z_correction), dim=1))
            else:
                z = new_rep(i)
                y = decoder(z)
            # compute losses
            if data_loader.dataset.modality_switch is not None:
                recon_loss_x = decoder.loss(
                    y,
                    [x[:, : data_loader.dataset.modality_switch], x[:, data_loader.dataset.modality_switch :]],
                    scale=[lib[:, xxx].unsqueeze(1) for xxx in range(decoder.n_out_groups)],
                    mask=data_loader.dataset.get_mask(i)
                )
            else:
                recon_loss_x = decoder.loss(
                    y, [x], scale=[lib[:, xxx].unsqueeze(1) for xxx in range(decoder.n_out_groups)],
                    mask=data_loader.dataset.get_mask(i)
                )
            # gmm_error = gmm.forward_split(gmm,z).sum()
            gmm_error = gmm(z).sum()
            correction_error = torch.zeros(1).to(device)
            if correction_model is not None:
                if supervised:
                    # get the component whose mean was used to initialize each sample
                    supervision_idx = init_covariate_supervised[i]
                    correction_error += correction_model(z_correction, supervision_idx).sum()
                else:
                    correction_error += correction_model(z_correction).sum()
                corr_loss += correction_error.item()
            if include_correction_error:
                loss = recon_loss_x.clone() + gmm_error.clone() + correction_error.clone() * cov_beta
            else:
                loss = recon_loss_x.clone() + gmm_error.clone()
            loss.backward()
            if correction_hook:
                correction_model.mean.grad[:n_correction_classes_old,:] = 0
                correction_model.neglogvar.grad[:n_correction_classes_old,:] = 0
                correction_model.weight.grad[:n_correction_classes_old] = 0
                correction_model_optim.step()
                #if supervised and (epoch > supervised_warmup):
                #    decoder.main[0].weight.grad[:, :new_rep.n_rep] = 0
                #    decoder_optimizer.step()
            e_loss += loss.item()
            recon_loss += recon_loss_x.clone().item()

        newrep_optimizer.step()
        if correction_model is not None:
            correction_rep_optim.step()
        e_loss /= (len(data_loader.dataset)*data_loader.dataset.n_features)
        recon_loss /= (len(data_loader.dataset)*data_loader.dataset.n_features)
        if correction_model is not None:
            corr_loss /= (len(data_loader.dataset)*data_loader.dataset.n_features)
            print("epoch: ", epoch, " loss: ", e_loss, " recon: ", recon_loss, " corr: ", corr_loss)
        else:
            print("epoch: ", epoch, " loss: ", e_loss, " recon: ", recon_loss)
    
    #with torch.no_grad():
    #    print("new weights: ", decoder.main[0].weight[:, indices_covariate_input])
    
    if correction_hook:
        if supervised:
            return decoder, new_rep, test_correction_rep, correction_model
        return None, new_rep, test_correction_rep, correction_model
    else:
        return None, new_rep, test_correction_rep, None

In [13]:
original_name = model._model_name
# change the model name (because we did inference once for 10 epochs and once for 50)
model._model_name = original_name + "_test10e"
model.predict_new(testset)
print("   test set inferred")

check: modality names are  ['GEX', 'ATAC']
   all potential reps:  torch.Size([88, 22])
calculating losses for each new sample and potential reps
   rep_init_values:  torch.Size([6925, 22])
    torch.Size([22]) torch.Size([6925])
   counts of how often each component has been chosen:  (array([-0.08899279, -0.02229357, -0.01778638, -0.01645249, -0.0154126 ,
       -0.01352611, -0.0122846 ,  0.05693668,  0.05827057,  0.05931045,
        0.06119695], dtype=float32), array([  60,    2,  942, 5320,  534,    1,    6,   22,   21,   16,    1]))
training selected reps for  10  epochs
epoch:  0  loss:  0.21522441671510625  recon:  0.21107790882056984  corr:  8.265609951379e-05
epoch:  1  loss:  0.21155740183280497  recon:  0.20739153268633087  corr:  8.280899092986277e-05
epoch:  2  loss:  0.20821144059857313  recon:  0.204012878091121  corr:  8.311597839598633e-05
epoch:  3  loss:  0.2052402194660532  recon:  0.2009864627155958  corr:  8.363679819810628e-05
epoch:  4  loss:  0.20266228059737051

In [13]:
original_name = model._model_name
# change the model name (because we did inference once for 10 epochs and once for 50)
model._model_name = original_name + "_test10e_old"
model.predict_new(testset)
print("   test set inferred")

making potential reps
   all potential reps:  torch.Size([88, 22])
calculating losses for each new sample and potential reps
training selected reps for  10  epochs
epoch:  0  loss:  0.30016952074764014
epoch:  1  loss:  0.2942641044195227
epoch:  2  loss:  0.28908638261415565
epoch:  3  loss:  0.28476411996979445
epoch:  4  loss:  0.2813018356517514
epoch:  5  loss:  0.27835773076151776
epoch:  6  loss:  0.27579175957580665
epoch:  7  loss:  0.27352224981477913
epoch:  8  loss:  0.27152744189144434
epoch:  9  loss:  0.26980502880297963
   test set inferred
