#### This script integrates the new replicates into a joint embedding with the old replicates. This embedding is ultimately recalculated at the end of the script.

In [None]:
import scanpy as sc
import scvi
from tqdm.notebook import tqdm
import os
import numpy as np
import pandas as pd
from scvi.model.utils import mde
import matplotlib.pyplot as plt
import pickle
import glob

import seaborn as sns
from sklearn.metrics import classification_report
import torch

In [None]:
# Default all scanpy figures in this notebook to a compact 4 x 4 (inch) canvas.
sc.set_figure_params(
    figsize=(4, 4),
)

Add the paths to the folders of all experiments to be integrated.

In [None]:
# Every matching "day*" directory is one new replicate to map onto the reference.
# glob() returns paths in arbitrary, filesystem-dependent order, so sort them to make
# the processing order (and therefore the run) reproducible. Only directories are
# kept because each entry is later read as <folder>/adatas/05_unrolled.h5ad.
input_folders = sorted(
    path
    for path in glob.glob("/mnt/sata1/Analysis_Alex/timecourse_replicates/day*")
    if os.path.isdir(path)
)

In [None]:
# Destination directory for the cleaned analysis outputs.
# NOTE(review): not referenced in this part of the notebook — confirm it is used later.
output_folder = (
    "/projects/2023_Spatial_Paper/Analysis_Alex/"
    "timecourse_replicates/analysis/cleaned"
)

Put the path to the final AnnData object (.h5ad) from the replicate 1 processing.

In [None]:
# Path to the final cell-typed .h5ad from the replicate 1 processing; the variable is
# rebound to the loaded AnnData in the next cell.
# NOTE(review): this is a Windows drive path, while the replicate and output paths in
# this notebook are POSIX paths — confirm which machine this notebook runs on.
reference_adata = (
    "D:/amonell/timecourse_final/analysis/cleaned/"
    "final_celltyped_and_axes.h5ad"
)

In [None]:
# Load the reference AnnData from the .h5ad path above. read_h5ad is used instead of
# the generic sc.read so the path is always treated as an .h5ad file, without sc.read's
# extension/cache dispatch or its fallback of treating the string as a key under
# sc.settings.writedir.
reference_adata = sc.read_h5ad(reference_adata)

Run scVI integration

In [None]:
# Register the reference with scVI. The model reads expression values from the "raw"
# layer (presumably raw counts, as the NB likelihood below expects — confirm), and the
# "batch" column in .obs is used as the batch covariate.
scvi.model.SCVI.setup_anndata(
    reference_adata,
    layer="raw",
    batch_key="batch",
)

In [None]:
# Let PyTorch use faster reduced-precision internal kernels (e.g. TF32, where the
# hardware supports it) for float32 matrix multiplications during model training.
torch.set_float32_matmul_precision(precision="high")

In [None]:
# Build and train the reference scVI model that the new replicates are mapped onto:
# 2 hidden layers, a 30-dimensional latent space and a negative-binomial gene
# likelihood, trained with scvi-tools' default training settings.
# NOTE(review): the scvi-tools scArches tutorial recommends that reference models later
# used with load_query_data be built with use_layer_norm="both",
# use_batch_norm="none" and encode_covariates=True — confirm the defaults are intended.
scvi_ref = scvi.model.SCVI(
    reference_adata,
    gene_likelihood="nb",
    n_latent=30,
    n_layers=2,
)
scvi_ref.train()

In [None]:
SCVI_LATENT_KEY = "X_scVI_replicates"  # .obsm key for the scVI latent embedding

# Store the reference cells' latent embedding, then build the neighbor graph on it
# (use_rep) instead of on a default expression representation.
reference_adata.obsm[SCVI_LATENT_KEY] = scvi_ref.get_latent_representation()
sc.pp.neighbors(adata=reference_adata, use_rep=SCVI_LATENT_KEY)

In [None]:
# Embed the scVI-based neighbor graph with UMAP and color the cells by "Subtype".
sc.tl.umap(adata=reference_adata)
sc.pl.umap(
    reference_adata,
    color="Subtype",
)

### Map new replicates to this joint embedding

In [None]:
# Map every new replicate onto the trained reference model (scArches-style query
# training) and save it, with its latent embedding, as 06_reference_mapped.h5ad.
for folder in tqdm(input_folders, desc="Mapping replicates"):
    adata_dir = os.path.join(folder, "adatas")
    target_adata = sc.read_h5ad(os.path.join(adata_dir, "05_unrolled.h5ad"))

    # The replicate's folder name (matched by "day*") becomes its batch label under
    # the same "batch" key the reference was set up with.
    target_adata.obs["batch"] = os.path.basename(folder)

    # Align the query's features with the reference model's genes (in place), then
    # build a query model initialised from the reference weights.
    scvi.model.SCVI.prepare_query_anndata(target_adata, scvi_ref)
    scvi_query = scvi.model.SCVI.load_query_data(target_adata, scvi_ref)

    # weight_decay=0.0 follows the scvi-tools scArches workflow, so query training
    # does not change the latent representation of the reference cells.
    scvi_query.train(max_epochs=200, plan_kwargs={"weight_decay": 0.0})
    target_adata.obsm[SCVI_LATENT_KEY] = scvi_query.get_latent_representation()
    target_adata.write(os.path.join(adata_dir, "06_reference_mapped.h5ad"))