In [None]:
# Reload edited project modules (e.g. train_scvi_model, io_utils) before each cell runs.
%load_ext autoreload
%autoreload 2

import os
import sys

# Root of the MM_2023 repository; project modules such as train_scvi_model and io_utils
# are imported from it below. Set MM_2023_REPO_DIR to point at a checkout elsewhere.
repo_dir = os.environ.get("MM_2023_REPO_DIR", '/home/labs/amit/noamsh/repos/MM_2023')
# Guard so re-running this cell does not append the same entry to sys.path again.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
from pathlib import Path
from datetime import date
from omegaconf import OmegaConf

import pandas as pd
import scanpy as sc
import anndata as ad
import scvi
import torch

from train_scvi_model import train_scvi_model
from io_utils import generate_path_in_output_dir

In [None]:
# Project configuration lives at the repository root.
config_path = Path(repo_dir) / 'config.yaml'
conf = OmegaConf.load(config_path)

# Scanpy figure resolution: 150 dpi inline, 300 dpi when saved.
sc.set_figure_params(dpi=150, dpi_save=300)

# Today's ISO date names this run's figure directory.
ts_iso = date.today().isoformat()
figures_dir = Path(conf.outputs.output_dir) / "figures" / ts_iso

In [None]:
# Input selectors: the data release version and the date-stamp embedded in the
# annotated PC-only AnnData file name loaded below.
data_version = "20240619"
load_ts_iso = '2024-06-28'

In [None]:
# Annotated, PC-only AnnData from the output directory; its file name encodes the
# data version and date-stamp chosen above.
adata_pc_with_annot_and_scvi_path = Path(
    conf.outputs.output_dir,
    f"adata_with_scvi_annot_pred_data_v_{data_version}_ts_{load_ts_iso}_only_pc_annotated_filtered.h5ad",
)
adata_only_pc = ad.read_h5ad(adata_pc_with_annot_and_scvi_path)
adata_only_pc

In [None]:
# Fix the scvi-tools seed before any model is trained.
scvi.settings.seed = 0
# Allow PyTorch to use faster, lower-precision float32 matmuls where the hardware supports it.
torch.set_float32_matmul_precision("high")
print("Last run with scvi-tools version:", scvi.__version__)

In [None]:
# UMAP stored in the loaded AnnData, colored by Method and PC annotation.
sc.pl.umap(
    adata_only_pc,
    color=["Method", 'pc_annotation'],
    ncols=2,
    wspace=0.3,
)

In [None]:
# Untouched copy of the loaded data: later cells overwrite adata_only_pc's neighbor graph and UMAP.
backup_adata = adata_only_pc.copy()

### Train LDAE (`scvi.model.LinearSCVI`)

In [None]:
# Linearly decoded scVI (LinearSCVI) on the PC-only data, with "Method" as the batch key.
model_ldae = train_scvi_model(
    adata_only_pc,
    counts_layer="counts",
    batch_key="Method",
    scvi_model_type=scvi.model.LinearSCVI,
    model_kwargs=dict(dropout_rate=0.1),
    trainer_kwargs=dict(
        batch_size=512,
        max_epochs=200,
        plan_kwargs=dict(lr=5e-3),
        check_val_every_n_epoch=10,
    ),
)

In [None]:
# Training ELBO vs. validation ELBO (validation recorded every check_val_every_n_epoch epochs).
# The first training point is skipped, presumably because its large initial value
# would dominate the y-axis.
ldae_history = model_ldae.history
train_elbo = ldae_history["elbo_train"][1:]
test_elbo = ldae_history["elbo_validation"]

ax = train_elbo.plot()
test_elbo.plot(ax=ax)

In [None]:
# Store each latent dimension as an obs column Z_<k> so it can be used as a UMAP color.
Z_hat = model_ldae.get_latent_representation()
for dim in range(Z_hat.shape[1]):
    adata_only_pc.obs[f"Z_{dim}"] = Z_hat[:, dim]

In [None]:
# For each latent factor, list the 5 genes with the lowest (most negative) and the
# 5 with the highest loading; sort_values() is ascending, so head = lowest, tail = highest.
loadings = model_ldae.get_loadings()
print(
    "Top loadings by magnitude\n---------------------------------------------------------------------------------------"
)
for factor_name, factor_loadings in loadings.items():
    ranked = factor_loadings.sort_values()
    lowest = "\t".join(f"{gene}, {ranked[gene]:.2}" for gene in ranked.head(5).index)
    highest = "\t".join(f"{gene}, {ranked[gene]:.2}" for gene in ranked.tail(5).index)
    block_text = factor_name + ":\t" + lowest + "\n\t...\n\t" + highest
    print(
        block_text
        + "\n---------------------------------------------------------------------------------------\n"
    )

In [None]:
# Color the current UMAP (still the one loaded from disk) by each LDAE latent factor.
zs = ["Z_" + str(k) for k in range(model_ldae.n_latent)]
sc.pl.umap(adata_only_pc, color=zs, ncols=4)

### Create a new neighborhood graph from the LDAE latent space

In [None]:
# Rebuild the kNN graph and UMAP on the LDAE latent space (replaces the loaded UMAP).
LDVI_LATENT_KEY = "X_LDVI"
LDVI_CLUSTERS_KEY = "leiden_LDVI"  # only referenced by the optional Leiden step below

adata_only_pc.obsm[LDVI_LATENT_KEY] = Z_hat
sc.pp.neighbors(
    adata_only_pc,
    n_neighbors=20,
    use_rep=LDVI_LATENT_KEY,
)
sc.tl.umap(adata_only_pc, min_dist=0.3)
# Optional clustering on this graph:
# sc.tl.leiden(adata_only_pc, key_added=LDVI_CLUSTERS_KEY, resolution=1.2)

In [None]:
# Method / annotation view on the UMAP computed from the LDAE latent space.
sc.pl.umap(
    adata_only_pc,
    wspace=0.3,
    ncols=2,
    color=["Method", 'pc_annotation'],
)

In [None]:
# Latent factors Z_* on the LDAE-based UMAP.
sc.pl.umap(
    adata_only_pc,
    ncols=4,
    color=zs,
)

### Re-train scVI on the PC-only data

In [None]:
# Standard scVI on the same PC-only data and hyperparameters as the LDAE run,
# trained for 250 epochs.
model_scvi = train_scvi_model(
    adata_only_pc,
    counts_layer="counts",
    batch_key="Method",
    scvi_model_type=scvi.model.SCVI,
    model_kwargs=dict(dropout_rate=0.1),
    trainer_kwargs=dict(
        batch_size=512,
        max_epochs=250,
        plan_kwargs=dict(lr=5e-3),
        check_val_every_n_epoch=10,
    ),
)

# Training vs. validation ELBO; the first training point is skipped.
scvi_history = model_scvi.history
train_elbo = scvi_history["elbo_train"][1:]
test_elbo = scvi_history["elbo_validation"]

ax = train_elbo.plot()
test_elbo.plot(ax=ax)

In [None]:
# Rebuild the kNN graph and UMAP from the scVI latent space.
# NOTE(review): this overwrites the LDAE-based graph/UMAP, and no later cell recomputes
# them on adata_only_pc, so the AnnData written in the save cell carries this scVI UMAP.
scvi_latent_key = "X_SCVI"
scvi_latent = model_scvi.get_latent_representation()
adata_only_pc.obsm[scvi_latent_key] = scvi_latent
sc.pp.neighbors(adata_only_pc, n_neighbors=20, use_rep=scvi_latent_key)
sc.tl.umap(adata_only_pc, min_dist=0.3)

In [None]:
# Method / annotation view on the scVI-based UMAP.
sc.pl.umap(
    adata_only_pc,
    color=["Method", "pc_annotation"],
)

### Train only on PC variable genes

In [None]:
# Shape (n_obs, n_vars) of the "counts" layer passed to train_scvi_model above.
adata_only_pc.layers["counts"].shape

In [None]:
from io_utils import generate_path_in_output_dir
from sc_classification.var_genes import normalize_and_choose_genes

In [None]:
# Reload the processed AnnData and keep only the cells present in adata_only_pc.
adata_path = generate_path_in_output_dir(conf, conf.outputs.processed_adata_file_name, with_version=data_version)
adata_pc_new_vars = ad.read_h5ad(adata_path)
# `.copy()` turns the subset into a real AnnData instead of a view. A view keeps the whole
# processed AnnData alive in memory, and a later cell writes a layer into this object,
# which on a view triggers an implicit copy plus an ImplicitModificationWarning.
adata_pc_new_vars = adata_pc_new_vars[adata_only_pc.obs_names, :].copy()
adata_pc_new_vars

In [None]:
# Normalize and select the top 2000 genes separately for each Method (MARS / SPID),
# to compare the two selections below.
cell_method = adata_pc_new_vars.obs["Method"]
adata_pc_new_vars_mars = normalize_and_choose_genes(
    adata_pc_new_vars[cell_method == "MARS"], conf, n_top_genes=2000
)
adata_pc_new_vars_SPID = normalize_and_choose_genes(
    adata_pc_new_vars[cell_method == "SPID"], conf, n_top_genes=2000
)

In [None]:
# Sizes of the per-method gene selections and their overlap.
spid_selected = set(adata_pc_new_vars_SPID.var_names)
mars_selected = set(adata_pc_new_vars_mars.var_names)
print(len(spid_selected), len(mars_selected), len(spid_selected & mars_selected))

In [None]:
# Precomputed PC gene lists, one CSV per Method (MARS / SPID).
MARS_var_genes_path = '/home/labs/amit/noamsh/data/mm_2023/feats/pc_mars_genes.csv'
SPID_var_genes_path = '/home/labs/amit/noamsh/data/mm_2023/feats/pc_spid_genes.csv'

MARS_var_genes, SPID_var_genes = (
    pd.read_csv(csv_path) for csv_path in (MARS_var_genes_path, SPID_var_genes_path)
)

In [None]:
# Gene names sit in each CSV's first column, which has no header
# (pandas names such a column 'Unnamed: 0').
MARS_pc_genes = list(MARS_var_genes['Unnamed: 0'])
SPID_pc_genes = list(SPID_var_genes['Unnamed: 0'])

# Compare the two precomputed lists with each other: sizes and overlap.
# (The HVG overlap from the cell above is not repeated here.)
shared_pc_genes = set(MARS_pc_genes) & set(SPID_pc_genes)
print("MARS:", len(MARS_pc_genes),
      "SPID:", len(SPID_pc_genes),
      "shared:", len(shared_pc_genes))

In [None]:
# genes = list(set(list(adata_pc_new_vars_SPID.var_names) + list(adata_pc_new_vars_mars.var_names))) + list(adata_pc_new_vars_mars.var_names)))
# Union of the MARS and SPID PC gene lists, restricted to genes present in the data.
# The selection follows the AnnData's own var order. Iterating a Python set of strings
# gives an order that changes between interpreter sessions (hash randomization), which
# would reorder the model's input features from run to run.
pc_gene_union = set(SPID_pc_genes) | set(MARS_pc_genes)
gene_mask = adata_pc_new_vars.var_names.isin(pc_gene_union)
genes = list(adata_pc_new_vars.var_names[gene_mask])
assert genes, "none of the PC genes from the CSV lists were found in adata_pc_new_vars.var_names"

# Subset to a real (non-view) AnnData first, then store a copy of X as the counts layer
# used for training. NOTE(review): assumes X of the processed AnnData still holds raw
# counts; TODO confirm.
adata_pc_new_vars = adata_pc_new_vars[:, gene_mask].copy()
adata_pc_new_vars.layers[conf.scvi_settings.counts_layer_name] = adata_pc_new_vars.X.copy()

In [None]:
# Copy the PC annotation over; pandas aligns the assignment on obs_names, which match
# because adata_pc_new_vars was subset to adata_only_pc.obs_names.
adata_pc_new_vars.obs["pc_annotation"] = adata_only_pc.obs["pc_annotation"]

In [None]:
# Inspect the gene-subset AnnData before training.
adata_pc_new_vars

#### LDAE

In [None]:
# LinearSCVI (LDAE) on the selected PC gene set. The counts layer name is read from the
# config, the same name under which the counts were stored on adata_pc_new_vars.
# A hard-coded "counts" would break if the configured name differs.
model_ldae_new = train_scvi_model(
    adata_pc_new_vars,
    counts_layer=conf.scvi_settings.counts_layer_name,
    batch_key="Method",
    scvi_model_type=scvi.model.LinearSCVI,
    model_kwargs={"dropout_rate": 0.1},
    trainer_kwargs={"batch_size": 512, 'max_epochs': 300, 'plan_kwargs': {"lr": 5e-3}, 'check_val_every_n_epoch': 10},
)

# Training vs. validation ELBO. The first training point is dropped, presumably because
# its large initial value would dominate the y-axis.
train_elbo = model_ldae_new.history["elbo_train"][1:]
test_elbo = model_ldae_new.history["elbo_validation"]

ax = train_elbo.plot()
test_elbo.plot(ax=ax)

In [None]:
# Per latent factor of the new LDAE: the 5 lowest (most negative) and 5 highest loadings
# (ascending sort, so head = lowest, tail = highest).
loadings = model_ldae_new.get_loadings()
print(
    "Top loadings by magnitude\n---------------------------------------------------------------------------------------"
)
for factor in loadings.columns:
    ordered = loadings[factor].sort_values()
    bottom_genes = ordered.head(5).index
    top_genes = ordered.tail(5).index
    parts = [
        factor + ":\t",
        "\t".join(f"{g}, {ordered[g]:.2}" for g in bottom_genes),
        "\n\t...\n\t",
        "\t".join(f"{g}, {ordered[g]:.2}" for g in top_genes),
        "\n---------------------------------------------------------------------------------------\n",
    ]
    print("".join(parts))

In [None]:
# One obs column Z_<k> per latent dimension of the new LDAE, for UMAP coloring.
Z_hat = model_ldae_new.get_latent_representation()
n_dims = Z_hat.shape[1]
for k in range(n_dims):
    adata_pc_new_vars.obs["Z_%d" % k] = Z_hat[:, k]

In [None]:
# kNN graph and UMAP for the gene-subset data, built on the new LDAE latent space.
LDVI_LATENT_KEY = "X_LDVI"
LDVI_CLUSTERS_KEY = "leiden_LDVI"

adata_pc_new_vars.obsm[LDVI_LATENT_KEY] = Z_hat
sc.pp.neighbors(
    adata_pc_new_vars,
    n_neighbors=20,
    use_rep=LDVI_LATENT_KEY,
)
sc.tl.umap(adata_pc_new_vars, min_dist=0.3)

In [None]:
# Method / annotation view on the new LDAE-based UMAP.
sc.pl.umap(
    adata_pc_new_vars,
    wspace=0.3,
    color=["Method", "pc_annotation"],
)

In [None]:
# Color the new UMAP by each latent factor of the new LDAE.
zs = ["Z_" + str(k) for k in range(model_ldae_new.n_latent)]
sc.pl.umap(adata_pc_new_vars, color=zs)

In [None]:
# Optional (disabled): Leiden clustering on the new LDAE neighbor graph.
# sc.tl.leiden(adata_pc_new_vars, key_added=LDVI_CLUSTERS_KEY, resolution=1.2)

#### Train scVI on the new feature set

In [None]:
# Standard scVI on the selected PC gene set. The counts layer name is read from the
# config, the same name under which the counts were stored on adata_pc_new_vars.
# A hard-coded "counts" would break if the configured name differs.
model_scvi_new = train_scvi_model(
    adata_pc_new_vars,
    counts_layer=conf.scvi_settings.counts_layer_name,
    batch_key="Method",
    scvi_model_type=scvi.model.SCVI,
    model_kwargs={"dropout_rate": 0.1},
    trainer_kwargs={"batch_size": 512, 'max_epochs': 350, 'plan_kwargs': {"lr": 5e-3}, 'check_val_every_n_epoch': 10},
)

# Training vs. validation ELBO. The first training point is dropped, presumably because
# its large initial value would dominate the y-axis.
train_elbo = model_scvi_new.history["elbo_train"][1:]
test_elbo = model_scvi_new.history["elbo_validation"]

ax = train_elbo.plot()
test_elbo.plot(ax=ax)

In [None]:
# Replace the LDAE-based graph/UMAP of the gene-subset data with ones built on the new scVI latent space.
scvi_latent_key = "X_SCVI"
scvi_new_latent = model_scvi_new.get_latent_representation()
adata_pc_new_vars.obsm[scvi_latent_key] = scvi_new_latent
sc.pp.neighbors(adata_pc_new_vars, n_neighbors=20, use_rep=scvi_latent_key)
sc.tl.umap(adata_pc_new_vars, min_dist=0.3)


In [None]:
# Method / annotation view on the new scVI-based UMAP.
sc.pl.umap(
    adata_pc_new_vars,
    wspace=0.3,
    color=["Method", "pc_annotation"],
)

### save

In [None]:
# Persist the PC-only AnnData and the LinearSCVI model trained on it.
# At this point adata_only_pc carries obsm "X_LDVI" and "X_SCVI" plus the Z_* factor columns.
# NOTE(review): its neighbor graph/UMAP were last computed from the scVI latent, not the LDAE one.
pc_LDAE_path = Path(conf.outputs.output_dir, f"adata_with_pc_LDAE_data_v_{data_version}_ts_{load_ts_iso}.h5ad")
adata_only_pc.write(pc_LDAE_path)
model_path = generate_path_in_output_dir(conf, f"LDAE_only_pc_model_data_v_{data_version}", add_date_timestamp=True)
# `model` is never defined in this notebook (NameError). The LDAE trained on adata_only_pc is `model_ldae`.
model_ldae.save(model_path)