In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../src")
import feature.scrna_dataset as scrna_dataset
import model.sdes as sdes
import model.generate as generate
import model.scrna_ae as scrna_ae
import model.util as model_util
import analysis.fid as fid
import torch
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import os
import h5py

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

### Define the branches and create the data loaders

In [None]:
latent_space = False
latent_dim = 200

In [None]:
# Processed covid_flu single-cell data, restricted to the reduced gene set
data_file = (
    "/gstore/data/resbioai/tsenga5/branched_diffusion/"
    "data/scrna/covid_flu/processed/covid_flu_processed_reduced_genes.h5"
)
# Pretrained autoencoder directory for the chosen latent dimension
autoencoder_path = (
    "/gstore/data/resbioai/tsenga5/branched_diffusion/models/trained_models/"
    "scrna_vaes/covid_flu/covid_flu_processed_reduced_genes_ldvae_d%d/" % latent_dim
)

# Previously used location:
# models_base_path = "/gstore/home/tsenga5/branched_diffusion/models/trained_models/scrna_covid_flu_continuous_latent_class_extension"

# Root directory holding the pretrained checkpoints loaded below
models_base_path = "/gstore/scratch/u/tsenga5/branched_diffusion/models/trained_models/"

In [None]:
# TODO: this is currently rather inefficient; a decision-tree-style structure
# would be better

def class_time_to_branch(c, t, branch_defs):
    """
    Look up which branch handles a single (class, time) pair.
    Arguments:
        `c`: a class label (scalar)
        `t`: a time (scalar)
        `branch_defs`: list of branch definitions, each a triplet of
            (class tuple, start time, end time)
    Returns the index of the FIRST branch whose class tuple contains `c`
    and whose closed interval [start time, end time] contains `t`.
    Raises ValueError if no branch matches.
    """
    for branch_ind, branch_def in enumerate(branch_defs):
        branch_classes, t_start, t_end = branch_def[0], branch_def[1], branch_def[2]
        # Boundaries are inclusive, so a time exactly on a split point goes
        # to whichever matching branch is listed first
        if c in branch_classes and t_start <= t <= t_end:
            return branch_ind
    raise ValueError("Undefined class and time")
        
def class_time_to_branch_tensor(c, t, branch_defs):
    """
    Vectorized wrapper around `class_time_to_branch`.
    Arguments:
        `c`: iterable (e.g. tensor) of class labels
        `t`: iterable (e.g. tensor) of times, paired element-wise with `c`
        `branch_defs`: list of (class tuple, start time, end time) triplets
    Returns a tensor on DEVICE holding one branch index per (c, t) pair.
    """
    branch_inds = [
        class_time_to_branch(class_val, time_val, branch_defs)
        for class_val, time_val in zip(c, t)
    ]
    return torch.tensor(branch_inds, device=DEVICE)

def class_to_class_index_tensor(c, classes):
    """
    Given a tensor of classes, return the corresponding class indices
    as a tensor.
    Arguments:
        `c`: 1D tensor of class labels
        `classes`: list of known class labels; a label's index is its
            position in this list
    Returns a tensor on DEVICE with one index into `classes` per entry of
    `c`. Raises ValueError if any entry of `c` is not in `classes`.
    """
    # matches[i, j] is True iff c[i] == classes[j]
    matches = c[:, None] == torch.tensor(classes, device=c.device)
    # argmax over an all-False row returns 0, which would silently relabel an
    # unknown class as classes[0]; fail loudly instead (consistent with
    # class_time_to_branch raising on undefined input)
    if not bool(matches.any(dim=1).all()):
        raise ValueError("Found class labels that are not in `classes`")
    return torch.argmax(matches.int(), dim=1).to(DEVICE)

In [None]:
# Define the branches
# Each branch definition is a triplet (class tuple, start time, end time). A
# (class, time) pair is routed to the first branch whose class tuple contains
# the class and whose closed time interval contains the time.
classes_01 = [0, 1]
branch_defs_01 = [
    ((0, 1), 0.5795795795795796, 1),
    ((0,), 0, 0.5795795795795796),
    ((1,), 0, 0.5795795795795796),
]

classes_012 = [0, 1, 5]
branch_defs_012 = [
    ((0, 1, 5), 0.6786786786786787, 1),
    ((0, 1), 0.5795795795795796, 0.6786786786786787),
    ((5,), 0, 0.6786786786786787),
    ((0,), 0, 0.5795795795795796),
    ((1,), 0, 0.5795795795795796),
]

classes_2 = [5]
branch_defs_2 = [((5,), 0, 0.6786786786786787)]

# Alternative choices of the new class (swap in one of these pairs instead):

# classes_012 = [0, 1, 2]
# branch_defs_012 = [((0, 1, 2), 0.5795795795795796, 1), ((1, 2), 0.22922922922922923, 0.5795795795795796), ((2,), 0, 0.22922922922922923), ((0,), 0, 0.5795795795795796), ((1,), 0, 0.22922922922922923)]

# classes_2 = [2]
# branch_defs_2 = [((2,), 0, 0.22922922922922923)]

# classes_012 = [0, 1, 3]
# branch_defs_012 = [((0, 1, 3), 0.5795795795795796, 1), ((1, 3), 0.5085085085085085, 0.5795795795795796), ((3,), 0, 0.5085085085085085), ((0,), 0, 0.5795795795795796), ((1,), 0, 0.5085085085085085)]

# classes_2 = [3]
# branch_defs_2 = [((3,), 0, 0.5085085085085085)]


def _load_class_subset(classes):
    """
    Load the dataset from `data_file` (attaching the autoencoder only when
    working in latent space) and keep only the cells whose `cell_cluster`
    label is in `classes`.
    """
    subset = scrna_dataset.SingleCellDataset(
        data_file, autoencoder_path=(autoencoder_path if latent_space else None)
    )
    keep_mask = np.isin(subset.cell_cluster, classes)
    subset.data = subset.data[keep_mask]
    subset.cell_cluster = subset.cell_cluster[keep_mask]
    return subset


dataset_01 = _load_class_subset(classes_01)    # original classes
dataset_012 = _load_class_subset(classes_012)  # original classes + new class
dataset_2 = _load_class_subset(classes_2)      # new class only

loader_kwargs = dict(batch_size=128, shuffle=True, num_workers=0)
data_loader_01 = torch.utils.data.DataLoader(dataset_01, **loader_kwargs)
data_loader_012 = torch.utils.data.DataLoader(dataset_012, **loader_kwargs)
data_loader_2 = torch.utils.data.DataLoader(dataset_2, **loader_kwargs)

# Per-example input shape (batch dimension dropped)
input_shape = next(iter(data_loader_01))[0].shape[1:]

In [None]:
# Upper time limit passed to every sampler and to the full-range training runs
t_limit = 1

# Create the SDE: variance-preserving, over inputs of shape `input_shape`
# NOTE(review): 0.1 and 5 are presumably the noise-schedule endpoints
# (beta_min / beta_max); confirm against sdes.VariancePreservingSDE
sde = sdes.VariancePreservingSDE(0.1, 5, input_shape)

In [None]:
# Output directory for the training runs below. It is derived from
# models_base_path (yielding the same ".../trained_models/extension" path that
# was previously hard-coded) so the two locations cannot drift apart.
os.environ["MODEL_DIR"] = os.path.join(models_base_path, "extension")

# Import this AFTER setting the environment variable: the module presumably
# reads MODEL_DIR when it is imported
import model.train_continuous_model as train_continuous_model

#### Train an extra branch on the branched model

In [None]:
def map_branch_def(branch_def, target_branch_defs):
    """
    Find the existing branch whose trained weights can be reused, without
    retraining, for the query branch `branch_def`.

    A target matches when every class in the target's class tuple also
    appears in the query's class tuple, AND the query's time interval
    [start, end] lies within the target's interval.
    Arguments:
        `branch_def`: a branch definition (i.e. triplet of class index
            tuple, start time, and end time)
        `target_branch_defs`: a list of branch definitions
    Returns the index of the first matching definition in
    `target_branch_defs`, or -1 if no suitable match is found.
    """
    query_classes = set(branch_def[0])
    query_start, query_end = branch_def[1], branch_def[2]
    return next(
        (
            target_i
            for target_i, target in enumerate(target_branch_defs)
            if query_classes >= set(target[0])
            and target[1] <= query_start
            and query_end <= target[2]
        ),
        -1,
    )

In [None]:
# Load the pretrained branched model for the original classes
# (checkpoint from the "..._branched_2classes" run)
branched_model_1 = model_util.load_model(
    scrna_ae.MultitaskResNet,
    os.path.join(
        models_base_path,
        "scrna_covid_flu_continuous_branched_2classes/1/last_ckpt.pth",
    ),
).to(DEVICE)

In [None]:
# Generate the samples: 1000 per original class from the original branched
# model, before any extension
branched_samples_before = {}
for class_to_sample in classes_01:
    print("Sampling class: %s" % class_to_sample)
    generated = generate.generate_continuous_branched_samples(
        branched_model_1,
        sde,
        class_to_sample,
        lambda c, t: class_time_to_branch_tensor(c, t, branch_defs_01),
        sampler="pc",
        t_limit=t_limit,
        num_samples=1000,
        verbose=True,
    )
    branched_samples_before[class_to_sample] = generated.cpu().numpy()

In [None]:
# Create a new branched model with one branch per entry of branch_defs_012,
# then warm-start it from branched_model_1
branched_model_2 = scrna_ae.MultitaskResNet(
    len(branch_defs_012), input_shape[0], t_limit=t_limit
).to(DEVICE)

# For each new branch, the index of the old branch whose weights can be reused
# as-is, or -1 if the new branch has no counterpart and must be trained
branch_map_inds = [
    map_branch_def(bd, branch_defs_01) for bd in branch_defs_012
]

# Old branch used to warm-start new branches that have no counterpart
# (-1 = last branch). We'll manually set it for now (TODO)
warm_start_source_ind = -1


def _copy_branch_weights(source_list, target_list):
    """
    Copy per-branch weights from `source_list` (old model) into `target_list`
    (new model) following `branch_map_inds`; new branches without a
    counterpart are warm-started from `warm_start_source_ind`.
    """
    for target_i, source_i in enumerate(branch_map_inds):
        if source_i == -1:
            source_i = warm_start_source_ind
        target_list[target_i].load_state_dict(source_list[source_i].state_dict())


# For each submodule, copy over the weights
# Careful: this assumes a particular kind of architecture! Each entry of
# `layers` / `time_embedders` is a module list that is either shared by all
# branches (length 1) or holds one module per branch
modules_1 = dict(branched_model_1.named_children())
for module_name in ["layers", "time_embedders"]:
    target_lists = branched_model_2.get_submodule(module_name)
    for submodule_i, submodule in enumerate(modules_1[module_name]):
        if len(submodule) == 1:
            # Shared across branches: copy wholesale
            target_lists[submodule_i].load_state_dict(submodule.state_dict())
        elif len(submodule) == len(branch_defs_01):
            _copy_branch_weights(submodule, target_lists[submodule_i])
        else:
            # Neither shared nor per-branch: the assumed architecture does not hold
            raise ValueError("Found module list of length %d" % len(submodule))

# The output heads are per-branch as well
_copy_branch_weights(
    branched_model_1.get_submodule("last_linears"),
    branched_model_2.get_submodule("last_linears"),
)

In [None]:
# Generate the samples again, from the NEW model, for the original classes only.
# If the weight match-up above is correct, these should resemble
# branched_samples_before.
branched_samples_before_2 = {}
for class_to_sample in classes_01:
    print("Sampling class: %s" % class_to_sample)
    generated = generate.generate_continuous_branched_samples(
        branched_model_2,
        sde,
        class_to_sample,
        lambda c, t: class_time_to_branch_tensor(c, t, branch_defs_012),
        sampler="pc",
        t_limit=t_limit,
        num_samples=1000,
        verbose=True,
    )
    branched_samples_before_2[class_to_sample] = generated.cpu().numpy()

In [None]:
# Train the model, for the specific branches only

# Freeze all shared layers and every branch that was copied from the old model;
# only branches with no counterpart (branch_map_inds == -1) remain trainable.
# NOTE(review): requires_grad=False does not stop normalization layers (if the
# architecture has any) from updating running statistics in train mode --
# confirm the frozen branches are really unchanged after training.
def _set_branch_trainability(branch_modules):
    """Enable gradients only for branches that were not copied from the old model."""
    for branch_i, branch_module in enumerate(branch_modules):
        is_new_branch = branch_map_inds[branch_i] == -1
        for param in branch_module.parameters():
            param.requires_grad = is_new_branch


for module_name in ["layers", "time_embedders"]:
    for submodule in branched_model_2.get_submodule(module_name):
        if len(submodule) == 1:
            # Shared by all branches: always frozen
            submodule.requires_grad_(False)
        elif len(submodule) == len(branch_defs_012):
            _set_branch_trainability(submodule)
        else:
            raise ValueError("Found module list of length %d" % len(submodule))
_set_branch_trainability(branched_model_2.get_submodule("last_linears"))

# Train on the new class's data only, and only up to the end time of the new
# class's own branch
train_continuous_model.train_ex.run(
    "train_branched_model",
    config_updates=dict(
        model=branched_model_2,
        sde=sde,
        data_loader=data_loader_2,
        class_time_to_branch_index=lambda c, t: class_time_to_branch_tensor(c, t, branch_defs_012),
        num_epochs=100,
        learning_rate=0.001,
        t_limit=branch_defs_2[0][2],
        loss_weighting_type="empirical_norm",
    ),
)

In [None]:
# Generate the samples from the extended model, now for all classes
# (original and new)
branched_samples_after = {}
for class_to_sample in classes_012:
    print("Sampling class: %s" % class_to_sample)
    generated = generate.generate_continuous_branched_samples(
        branched_model_2,
        sde,
        class_to_sample,
        lambda c, t: class_time_to_branch_tensor(c, t, branch_defs_012),
        sampler="pc",
        t_limit=t_limit,
        num_samples=1000,
        verbose=True,
    )
    branched_samples_after[class_to_sample] = generated.cpu().numpy()

#### Train the label-guided model on the new label only

In [None]:
# Import the label-guided model (checkpoint from the "..._labelguided_2classes"
# run); this copy will be fine-tuned on the new label only
label_guided_model_1 = model_util.load_model(
    scrna_ae.LabelGuidedResNet,
    os.path.join(
        models_base_path,
        "scrna_covid_flu_continuous_labelguided_2classes/1/last_ckpt.pth",
    ),
).to(DEVICE)

In [None]:
# Baseline: label-guided samples for every class (including the new one)
# before any additional training
linear_samples_before = {}
for class_to_sample in classes_012:
    print("Sampling class: %s" % class_to_sample)
    generated = generate.generate_continuous_label_guided_samples(
        label_guided_model_1,
        sde,
        class_to_sample,
        lambda c: class_to_class_index_tensor(c, classes_012),
        sampler="pc",
        t_limit=t_limit,
        num_samples=1000,
        verbose=True,
    )
    linear_samples_before[class_to_sample] = generated.cpu().numpy()

In [None]:
# Train on only new label: fine-tune label_guided_model_1 using data_loader_2,
# which holds the new class alone
train_continuous_model.train_ex.run(
    "train_label_guided_model",
    config_updates=dict(
        model=label_guided_model_1,
        sde=sde,
        data_loader=data_loader_2,
        class_to_class_index=lambda c: class_to_class_index_tensor(c, classes_012),
        num_epochs=10,
        learning_rate=0.001,
        t_limit=t_limit,
        loss_weighting_type="empirical_norm",
    ),
)

In [None]:
# Label-guided samples after fine-tuning on the new label only
linear_samples_after_newonly = {}
for class_to_sample in classes_012:
    print("Sampling class: %s" % class_to_sample)
    generated = generate.generate_continuous_label_guided_samples(
        label_guided_model_1,
        sde,
        class_to_sample,
        lambda c: class_to_class_index_tensor(c, classes_012),
        sampler="pc",
        t_limit=t_limit,
        num_samples=1000,
        verbose=True,
    )
    linear_samples_after_newonly[class_to_sample] = generated.cpu().numpy()

#### Train a fresh copy of the label-guided model on all data

In [None]:
# Import a fresh copy of the same label-guided checkpoint; this one will be
# fine-tuned on all classes
label_guided_model_2 = model_util.load_model(
    scrna_ae.LabelGuidedResNet,
    os.path.join(
        models_base_path,
        "scrna_covid_flu_continuous_labelguided_2classes/1/last_ckpt.pth",
    ),
).to(DEVICE)

In [None]:
# Train on all data: fine-tune label_guided_model_2 using data_loader_012
# (original classes plus the new class)
train_continuous_model.train_ex.run(
    "train_label_guided_model",
    config_updates=dict(
        model=label_guided_model_2,
        sde=sde,
        data_loader=data_loader_012,
        class_to_class_index=lambda c: class_to_class_index_tensor(c, classes_012),
        num_epochs=30,
        learning_rate=0.001,
        t_limit=t_limit,
        loss_weighting_type="empirical_norm",
    ),
)

In [None]:
# Label-guided samples after fine-tuning on all classes
linear_samples_after_all = {}
for class_to_sample in classes_012:
    print("Sampling class: %s" % class_to_sample)
    generated = generate.generate_continuous_label_guided_samples(
        label_guided_model_2,
        sde,
        class_to_sample,
        lambda c: class_to_class_index_tensor(c, classes_012),
        sampler="pc",
        t_limit=t_limit,
        num_samples=1000,
        verbose=True,
    )
    linear_samples_after_all[class_to_sample] = generated.cpu().numpy()

#### Compute FIDs

In [None]:
# Sample objects from the original dataset: up to 1000 real cells per class,
# used as the reference sets for the FIDs below
num_true_samples = 1000
true_samples = {}
for class_to_sample in classes_012:
    print("Sampling class: %s" % class_to_sample)
    inds = np.where(dataset_012.cell_cluster == class_to_sample)[0]
    # Sampling without replacement raises if a class has fewer cells than
    # requested, so cap the sample size at the class size
    if len(inds) < num_true_samples:
        print(
            "Class %s has only %d cells; using all of them"
            % (class_to_sample, len(inds))
        )
    sample_inds = np.random.choice(
        inds, size=min(num_true_samples, len(inds)), replace=False
    )
    true_samples[class_to_sample] = dataset_012.data[sample_inds]

In [None]:
# Outside latent-space mode the datasets above were built without the
# autoencoder, so load a separate copy with it attached; compute_fid uses it
# to encode/decode samples
if not latent_space:
    dataset_with_ae = scrna_dataset.SingleCellDataset(
        data_file, autoencoder_path=autoencoder_path
    )

def compute_fid(gen_samples, true_samples, latent=True):
    """
    Compute the FID between a set of generated samples and a set of real
    samples.
    Arguments:
        `gen_samples`: NumPy array of generated samples
        `true_samples`: real samples from the dataset (anything accepted by
            `torch.tensor`)
        `latent`: if True, compare the two sets in the autoencoder's latent
            space (encode); otherwise compare them in decoded space
    Returns the result of `fid.compute_fid` on the two transformed sets.
    Neither input array is modified.
    """
    if latent_space:
        # gen_samples are treated as already latent; only true_samples need
        # encoding (and only gen_samples need decoding)
        if latent:
            return fid.compute_fid(
                gen_samples,
                dataset_01.encode_batch(torch.tensor(true_samples, device=DEVICE)).cpu().numpy()
            )
        else:
            return fid.compute_fid(
                dataset_01.decode_batch(torch.tensor(gen_samples, device=DEVICE)).cpu().numpy(),
                true_samples
            )
    else:
        # Generated values should never be below 0, so clamp negatives to 0.
        # np.maximum returns a new array, so the caller's sample arrays (e.g.
        # the dicts of generated samples) are not mutated as a side effect.
        gen_samples = np.maximum(gen_samples, 0)
        if latent:
            return fid.compute_fid(
                dataset_with_ae.encode_batch(torch.tensor(gen_samples, device=DEVICE)).cpu().numpy(),
                dataset_with_ae.encode_batch(torch.tensor(true_samples, device=DEVICE)).cpu().numpy()
            )
        else:
            # Both sets are passed through encode then decode before comparison
            return fid.compute_fid(
                dataset_with_ae.decode_batch(dataset_with_ae.encode_batch(torch.tensor(gen_samples, device=DEVICE))).cpu().numpy(),
                dataset_with_ae.decode_batch(dataset_with_ae.encode_batch(torch.tensor(true_samples, device=DEVICE))).cpu().numpy()
            )

In [None]:
# Compare every set of generated samples against the real samples of the same
# class; `latent` selects the comparison space used by compute_fid
latent = True


def _per_class_fids(samples_by_class):
    """Return {class: FID of that class's generated samples vs. true_samples[class]}."""
    return {
        c: compute_fid(samples_by_class[c], true_samples[c], latent)
        for c in samples_by_class
    }


branched_before_fids = _per_class_fids(branched_samples_before)
branched_before_2_fids = _per_class_fids(branched_samples_before_2)
branched_after_fids = _per_class_fids(branched_samples_after)
linear_before_fids = _per_class_fids(linear_samples_before)
linear_after_newonly_fids = _per_class_fids(linear_samples_after_newonly)
linear_after_all_fids = _per_class_fids(linear_samples_after_all)

In [None]:
# Summary: B = branched model, L = label-guided model (the `linear_*` results)
for label, fids in [
    ("B-before", branched_before_fids),
    ("B-before2", branched_before_2_fids),
    ("B-after", branched_after_fids),
    ("L-before", linear_before_fids),
    ("L-afterone", linear_after_newonly_fids),
    ("L-afterall", linear_after_all_fids),
]:
    print(label, fids)