In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../src")
import feature.scrna_dataset as scrna_dataset
import model.sdes as sdes
import model.generate as generate
import model.scrna_ae as scrna_ae
import model.util as model_util
import analysis.fid as fid
import torch
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import os
import h5py

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

### Define the branches and create the data loader

In [3]:
# Dimension used to pick the pretrained autoencoder directory (see the
# "..._ldvae_d%d" path below)
latent_dim = 200
# When True, the datasets are constructed with the autoencoder attached;
# when False, they are constructed with `autoencoder_path=None`
latent_space = False

In [4]:
# HDF5 file that `scrna_dataset.SingleCellDataset` is constructed from
# NOTE(review): these are absolute cluster paths — presumably only valid on
# the original compute environment; adjust before re-running elsewhere
data_file = "/gstore/data/resbioai/tsenga5/branched_diffusion/data/scrna/covid_flu/processed/covid_flu_processed_reduced_genes.h5"
# Directory of the pretrained autoencoder matching `latent_dim`
autoencoder_path = "/gstore/data/resbioai/tsenga5/branched_diffusion/models/trained_models/scrna_vaes/covid_flu/covid_flu_processed_reduced_genes_ldvae_d%d/" % latent_dim

# Root directory holding the pretrained diffusion checkpoints loaded below,
# and under which the "extension" training outputs directory is placed
models_base_path = "/gstore/home/tsenga5/branched_diffusion/models/trained_models/scrna_covid_flu_continuous_class_extension"

In [5]:
# TODO: this is currently rather inefficient; a decision-tree-style structure
# would be better

def class_time_to_branch(c, t, branch_defs):
    """
    Looks up the branch that a single class/time pair belongs to.
    Arguments:
        `c`: a scalar class
        `t`: a scalar time
        `branch_defs`: a list of branch definitions, each a triplet of
            (tuple of classes, start time, end time)
    Returns the index of the first branch definition whose class tuple
    contains `c` and whose (inclusive) time interval contains `t`. Raises
    a ValueError if no branch definition matches.
    """
    def contains(branch_def):
        # Both interval end points are inclusive
        return c in branch_def[0] and t >= branch_def[1] and t <= branch_def[2]

    first_match = next(
        (
            branch_i for branch_i, branch_def in enumerate(branch_defs)
            if contains(branch_def)
        ),
        None
    )
    if first_match is None:
        raise ValueError("Undefined class and time")
    return first_match
        
def class_time_to_branch_tensor(c, t, branch_defs):
    """
    Batched version of `class_time_to_branch`.
    Arguments:
        `c`: a tensor of classes
        `t`: a tensor of times, paired element-wise with `c`
        `branch_defs`: a list of branch definitions
    Returns a tensor of branch indices (one per class/time pair), placed
    on `DEVICE`.
    """
    branch_indices = []
    # The lookup is done pair by pair in Python
    for class_i, time_i in zip(c, t):
        branch_indices.append(class_time_to_branch(class_i, time_i, branch_defs))
    return torch.tensor(branch_indices, device=DEVICE)

def class_to_class_index_tensor(c, classes):
    """
    Converts class labels to their positions within `classes`.
    Arguments:
        `c`: a tensor of classes
        `classes`: a list of all classes, whose order defines the indices
    Returns a tensor of class indices, placed on `DEVICE`.
    """
    class_values = torch.tensor(classes, device=c.device)
    # Compare every entry of `c` against every known class; the position of
    # the (first) match along dimension 1 is the class index
    is_match = (c.unsqueeze(1) == class_values).int()
    return is_match.argmax(dim=1).to(DEVICE)

In [6]:
# Branch definitions are triplets of (classes, start time, end time)

# Time boundary between the shared (0, 1) branch and the class-specific
# branches for classes 0 and 1
branch_time_01 = 0.5795795795795796
# Time boundary between the shared (0, 1, 5) branch and the branches below it
branch_time_012 = 0.6786786786786787

# Original two-class setting
classes_01 = [0, 1]
branch_defs_01 = [
    ((0, 1), branch_time_01, 1),
    ((0,), 0, branch_time_01),
    ((1,), 0, branch_time_01)
]

# Extended setting, with class 5 added
classes_012 = [0, 1, 5]
branch_defs_012 = [
    ((0, 1, 5), branch_time_012, 1),
    ((0, 1), branch_time_01, branch_time_012),
    ((5,), 0, branch_time_012),
    ((0,), 0, branch_time_01),
    ((1,), 0, branch_time_01)
]

# The new class on its own
classes_2 = [5]
branch_defs_2 = [((5,), 0, branch_time_012)]

def _load_class_subset(classes):
    """
    Creates a fresh `SingleCellDataset` from `data_file` (with the
    autoencoder attached only if `latent_space` is set) and limits its
    `data` and `cell_cluster` attributes to the cells whose cluster is in
    `classes`.
    Returns the limited dataset and the mask that was used to limit it.
    """
    dataset = scrna_dataset.SingleCellDataset(
        data_file, autoencoder_path=(autoencoder_path if latent_space else None)
    )
    keep_mask = np.isin(dataset.cell_cluster, classes)
    dataset.data = dataset.data[keep_mask]
    dataset.cell_cluster = dataset.cell_cluster[keep_mask]
    return dataset, keep_mask


def _make_loader(dataset):
    """
    Wraps a dataset in a shuffling data loader with batch size 128.
    """
    return torch.utils.data.DataLoader(
        dataset, batch_size=128, shuffle=True, num_workers=0
    )


# One independently-loaded dataset per class subset
dataset_01, inds_01 = _load_class_subset(classes_01)
dataset_012, inds_012 = _load_class_subset(classes_012)
dataset_2, inds_2 = _load_class_subset(classes_2)

data_loader_01 = _make_loader(dataset_01)
data_loader_012 = _make_loader(dataset_012)
data_loader_2 = _make_loader(dataset_2)

# Shape of a single input: first element of a batch, without the batch dimension
input_shape = next(iter(data_loader_01))[0].shape[1:]

In [7]:
# Time limit passed to the model constructor and to the samplers/trainers below
t_limit = 1

# Variance-preserving SDE over inputs of shape `input_shape`
# NOTE(review): 0.1 and 5 are presumably the lower/upper bounds of the noise
# schedule — confirm against `sdes.VariancePreservingSDE`
sde = sdes.VariancePreservingSDE(0.1, 5, input_shape)

In [8]:
# Point MODEL_DIR at the "extension" directory under `models_base_path`.
# This must happen before the import below (per the original author's note).
# NOTE(review): presumably the training module reads MODEL_DIR at import time
# to decide where runs are saved — confirm in train_continuous_model
os.environ["MODEL_DIR"] = os.path.join(models_base_path, "extension")

import model.train_continuous_model as train_continuous_model  # Import this AFTER setting environment

#### Train extra branch on branched model

In [9]:
def map_branch_def(branch_def, target_branch_defs):
    """
    Finds which entry of `target_branch_defs` (if any) a query branch
    definition corresponds to, in the sense that the query branch would
    not need to be retrained. A target matches the query if every class
    of the target is also a class of the query, and the query's time
    interval lies within the target's time interval (end points included).
    Arguments:
        `branch_def`: the query branch definition (i.e. triplet of class
            index tuple, start time, and end time)
        `target_branch_defs`: a list of branch definitions to search
    Returns the index of the first matching entry of
    `target_branch_defs`, or -1 if there is no suitable match.
    """
    def is_covered_by(target_branch_def):
        classes_ok = set(branch_def[0]).issuperset(set(target_branch_def[0]))
        return classes_ok \
            and branch_def[1] >= target_branch_def[1] \
            and branch_def[2] <= target_branch_def[2]

    return next(
        (
            target_i
            for target_i, target_branch_def in enumerate(target_branch_defs)
            if is_covered_by(target_branch_def)
        ),
        -1
    )

In [10]:
# Load the last checkpoint of the branched model trained on the 2-class setting
branched_ckpt_path = os.path.join(
    models_base_path, "scrna_covid_flu_continuous_branched_2classes/1/last_ckpt.pth"
)
branched_model_1 = model_util.load_model(scrna_ae.MultitaskResNet, branched_ckpt_path)
branched_model_1 = branched_model_1.to(DEVICE)

In [11]:
# Sample classes 0 and 1 from the original branched model, as a baseline
def _branch_index_before(c, t):
    """Maps class/time tensors to branch indices of the 2-class model."""
    return class_time_to_branch_tensor(c, t, branch_defs_01)

sampling_options = dict(sampler="pc", t_limit=t_limit, num_samples=1000, verbose=True)

branched_samples_before = {}
for class_to_sample in classes_01:
    print("Sampling class: %s" % class_to_sample)
    sample = generate.generate_continuous_branched_samples(
        branched_model_1, sde, class_to_sample, _branch_index_before,
        **sampling_options
    )
    # Keep the samples as NumPy arrays on the CPU
    branched_samples_before[class_to_sample] = sample.cpu().numpy()

Sampling class: 0


100%|██████████████████████████████████████████████████████| 500/500 [00:58<00:00,  8.58it/s]


Sampling class: 1


100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.62it/s]


In [12]:
# Create new model (one branch per entry of `branch_defs_012`) and copy over
# parameters from the 2-class model
branched_model_2 = scrna_ae.MultitaskResNet(
    len(branch_defs_012), input_shape[0], t_limit=t_limit
).to(DEVICE)

# Figure out which branches should be copied over to which ones: for every
# branch of the new model, the index of the old-model branch it corresponds
# to, or -1 if there is no such branch
branch_map_inds = [
    map_branch_def(bd, branch_defs_01) for bd in branch_defs_012
]


def _copy_branch_weights(source_branches, target_branches, branch_map_inds):
    """
    Copies per-branch weights from the old model into the new model.
    Arguments:
        `source_branches`: indexable collection of per-branch modules of
            the old model
        `target_branches`: indexable collection of per-branch modules of
            the new model, one per entry of `branch_map_inds`
        `branch_map_inds`: for each target branch, the index of the source
            branch to copy from, or -1 if no source branch matches
    Target branches without a match are warm-started from the last source
    branch.
    """
    for target_i, source_i in enumerate(branch_map_inds):
        if source_i == -1:
            # Copy over some other branch for a warm start
            # We'll manually set it for now (TODO)
            source_i = -1  # Last branch
        target_branches[target_i].load_state_dict(
            source_branches[source_i].state_dict()
        )


# For each submodule, copy over the weights
# Careful: this assumes a particular kind of architecture!
modules_1 = dict(branched_model_1.named_children())
modules_2 = dict(branched_model_2.named_children())

for module_name in ["layers", "time_embedders"]:
    for submodule_i, submodule in enumerate(modules_1[module_name]):
        if len(submodule) == 1:
            # Single entry (shared by all branches): copy it as a whole
            modules_2[module_name][submodule_i].load_state_dict(
                submodule.state_dict()
            )
        elif len(submodule) == len(branch_defs_01):
            # One entry per old branch: copy branch by branch
            _copy_branch_weights(
                submodule, modules_2[module_name][submodule_i], branch_map_inds
            )
        else:
            # The original formatted this message with the undefined name
            # `module_list`, which raised a NameError instead
            raise ValueError("Found module list of length %d" % len(submodule))

# The final linear layers are always one per branch
submodule = branched_model_1.get_submodule("last_linears")
target_submodule_list = branched_model_2.get_submodule("last_linears")
_copy_branch_weights(submodule, target_submodule_list, branch_map_inds)

In [13]:
# Sanity check of the weight match-up: sample classes 0 and 1 again, this
# time from the new model (before any training) with the new branch definitions
def _branch_index_before_2(c, t):
    """Maps class/time tensors to branch indices of the extended model."""
    return class_time_to_branch_tensor(c, t, branch_defs_012)

sampling_options = dict(sampler="pc", t_limit=t_limit, num_samples=1000, verbose=True)

branched_samples_before_2 = {}
for class_to_sample in classes_01:
    print("Sampling class: %s" % class_to_sample)
    sample = generate.generate_continuous_branched_samples(
        branched_model_2, sde, class_to_sample, _branch_index_before_2,
        **sampling_options
    )
    branched_samples_before_2[class_to_sample] = sample.cpu().numpy()

Sampling class: 0


100%|██████████████████████████████████████████████████████| 500/500 [00:58<00:00,  8.62it/s]


Sampling class: 1


100%|██████████████████████████████████████████████████████| 500/500 [00:58<00:00,  8.61it/s]


In [14]:
# Train the model, for the specific branches only

def _set_trainable(module, trainable):
    """
    Sets `requires_grad` of every parameter of `module` to `trainable`.
    """
    for p in module.parameters():
        p.requires_grad = trainable


# Only branches without a counterpart in the old model (map index -1) are
# trained; shared layers and all other branch-specific layers are frozen
for module_name in ["layers", "time_embedders"]:
    for submodule in branched_model_2.get_submodule(module_name):
        if len(submodule) == 1:
            # Shared by all branches
            _set_trainable(submodule, False)
        elif len(submodule) == len(branch_defs_012):
            # One entry per branch
            for i, branch_module in enumerate(submodule):
                _set_trainable(branch_module, branch_map_inds[i] == -1)
        else:
            raise ValueError("Found module list of length %d" % len(submodule))
submodule = branched_model_2.get_submodule("last_linears")
for i, branch_module in enumerate(submodule):
    _set_trainable(branch_module, branch_map_inds[i] == -1)

# Train on the new class only, and only up to the end time of its branch
branch_extension_config = {
    "model": branched_model_2,
    "sde": sde,
    "data_loader": data_loader_2,
    "class_time_to_branch_index": lambda c, t: class_time_to_branch_tensor(c, t, branch_defs_012),
    "num_epochs": 100,
    "learning_rate": 0.001,
    "t_limit": branch_defs_2[0][2],
    "loss_weighting_type": "empirical_norm"
}
train_continuous_model.train_ex.run(
    "train_branched_model", config_updates=branch_extension_config
)

INFO - train - Running command 'train_branched_model'
INFO - train - Started run with ID "1"
Loss: 541.89: 100%|██████████████████████████████████████████| 24/24 [00:06<00:00,  3.57it/s]


Epoch 1 average Loss: 934.59


Loss: 322.41: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 2 average Loss: 394.38


Loss: 262.48: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 3 average Loss: 290.47


Loss: 236.19: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 4 average Loss: 248.73


Loss: 223.13: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 5 average Loss: 229.84


Loss: 214.03: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 6 average Loss: 214.70


Loss: 207.60: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 7 average Loss: 209.96


Loss: 207.48: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 8 average Loss: 208.24


Loss: 199.02: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 9 average Loss: 204.64


Loss: 219.66: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 10 average Loss: 205.58


Loss: 201.84: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 11 average Loss: 204.01


Loss: 200.55: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 12 average Loss: 202.39


Loss: 201.48: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 13 average Loss: 200.10


Loss: 208.25: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 14 average Loss: 202.08


Loss: 206.25: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 15 average Loss: 201.40


Loss: 193.77: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.94it/s]


Epoch 16 average Loss: 199.07


Loss: 194.77: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 17 average Loss: 200.06


Loss: 195.51: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]


Epoch 18 average Loss: 200.14


Loss: 197.59: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.94it/s]


Epoch 19 average Loss: 199.59


Loss: 214.07: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.93it/s]


Epoch 20 average Loss: 199.34


Loss: 203.86: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.93it/s]


Epoch 21 average Loss: 200.05


Loss: 203.20: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 22 average Loss: 199.25


Loss: 204.09: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 23 average Loss: 198.07


Loss: 197.67: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 24 average Loss: 197.75


Loss: 197.21: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 25 average Loss: 196.30


Loss: 193.27: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 26 average Loss: 195.22


Loss: 194.59: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 27 average Loss: 195.88


Loss: 194.37: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 28 average Loss: 193.41


Loss: 191.63: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 29 average Loss: 194.82


Loss: 194.82: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 30 average Loss: 193.57


Loss: 197.86: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 31 average Loss: 192.92


Loss: 193.93: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 32 average Loss: 193.21


Loss: 195.69: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 33 average Loss: 193.24


Loss: 188.00: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.93it/s]


Epoch 34 average Loss: 190.22


Loss: 191.57: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 35 average Loss: 189.02


Loss: 180.67: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 36 average Loss: 189.45


Loss: 187.40: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]


Epoch 37 average Loss: 187.70


Loss: 192.07: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 38 average Loss: 188.94


Loss: 179.81: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 39 average Loss: 186.43


Loss: 186.99: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 40 average Loss: 188.21


Loss: 180.00: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 41 average Loss: 184.47


Loss: 180.65: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 42 average Loss: 184.01


Loss: 178.03: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 43 average Loss: 182.47


Loss: 176.69: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 44 average Loss: 182.72


Loss: 182.24: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 45 average Loss: 183.89


Loss: 179.27: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 46 average Loss: 181.67


Loss: 186.57: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]


Epoch 47 average Loss: 181.42


Loss: 181.07: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.98it/s]


Epoch 48 average Loss: 181.16


Loss: 176.92: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 49 average Loss: 180.07


Loss: 184.67: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]


Epoch 50 average Loss: 179.34


Loss: 176.39: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 51 average Loss: 178.82


Loss: 177.84: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 52 average Loss: 180.16


Loss: 170.51: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 53 average Loss: 177.33


Loss: 185.37: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]


Epoch 54 average Loss: 178.77


Loss: 186.09: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 55 average Loss: 178.24


Loss: 176.18: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 56 average Loss: 177.40


Loss: 170.73: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 57 average Loss: 178.27


Loss: 177.48: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 58 average Loss: 176.12


Loss: 179.95: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 59 average Loss: 177.06


Loss: 183.54: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 60 average Loss: 175.83


Loss: 172.75: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 61 average Loss: 174.99


Loss: 176.54: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 62 average Loss: 173.56


Loss: 167.04: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 63 average Loss: 175.57


Loss: 171.84: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 64 average Loss: 171.87


Loss: 180.66: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 65 average Loss: 175.43


Loss: 165.39: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 66 average Loss: 172.45


Loss: 180.04: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]


Epoch 67 average Loss: 171.16


Loss: 190.52: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 68 average Loss: 171.96


Loss: 180.63: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 69 average Loss: 173.21


Loss: 169.40: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 70 average Loss: 172.09


Loss: 172.24: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 71 average Loss: 171.75


Loss: 163.40: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 72 average Loss: 169.62


Loss: 176.77: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]


Epoch 73 average Loss: 171.32


Loss: 179.21: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 74 average Loss: 169.29


Loss: 173.94: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  5.00it/s]


Epoch 75 average Loss: 167.70


Loss: 167.65: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 76 average Loss: 169.27


Loss: 173.67: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 77 average Loss: 168.66


Loss: 161.21: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 78 average Loss: 168.91


Loss: 169.10: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 79 average Loss: 170.50


Loss: 162.25: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 80 average Loss: 168.97


Loss: 177.53: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 81 average Loss: 169.51


Loss: 167.69: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]


Epoch 82 average Loss: 167.77


Loss: 164.09: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 83 average Loss: 166.51


Loss: 173.80: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 84 average Loss: 168.70


Loss: 158.83: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 85 average Loss: 167.01


Loss: 170.86: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 86 average Loss: 166.36


Loss: 165.91: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 87 average Loss: 167.18


Loss: 171.36: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 88 average Loss: 166.90


Loss: 165.73: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.94it/s]


Epoch 89 average Loss: 166.48


Loss: 167.68: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]


Epoch 90 average Loss: 164.47


Loss: 157.61: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 91 average Loss: 165.10


Loss: 165.24: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]


Epoch 92 average Loss: 165.88


Loss: 166.33: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]


Epoch 93 average Loss: 164.01


Loss: 165.82: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 94 average Loss: 166.55


Loss: 159.49: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 95 average Loss: 164.26


Loss: 160.49: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 96 average Loss: 164.98


Loss: 163.95: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 97 average Loss: 164.67


Loss: 160.04: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 98 average Loss: 163.01


Loss: 170.42: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 99 average Loss: 163.36


Loss: 173.83: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]


Epoch 100 average Loss: 162.93


INFO - train - Completed after 0:13:26


<sacred.run.Run at 0x2aab88d76eb0>

In [15]:
# Sample all three classes from the extended branched model after training
def _branch_index_after(c, t):
    """Maps class/time tensors to branch indices of the extended model."""
    return class_time_to_branch_tensor(c, t, branch_defs_012)

sampling_options = dict(sampler="pc", t_limit=t_limit, num_samples=1000, verbose=True)

branched_samples_after = {}
for class_to_sample in classes_012:
    print("Sampling class: %s" % class_to_sample)
    sample = generate.generate_continuous_branched_samples(
        branched_model_2, sde, class_to_sample, _branch_index_after,
        **sampling_options
    )
    branched_samples_after[class_to_sample] = sample.cpu().numpy()

Sampling class: 0


100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.66it/s]


Sampling class: 1


100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.63it/s]


Sampling class: 5


100%|██████████████████████████████████████████████████████| 500/500 [00:58<00:00,  8.62it/s]


#### Train label-guided model with only new label

In [16]:
# Load the last checkpoint of the label-guided model from the 2-class run;
# this copy is the one fine-tuned on the new label only
label_guided_ckpt_path = os.path.join(
    models_base_path, "scrna_covid_flu_continuous_labelguided_2classes/1/last_ckpt.pth"
)
label_guided_model_1 = model_util.load_model(
    scrna_ae.LabelGuidedResNet, label_guided_ckpt_path
)
label_guided_model_1 = label_guided_model_1.to(DEVICE)

In [17]:
# Sample all three classes from the label-guided model as loaded from the
# checkpoint, before any further training
def _class_index_before(c):
    """Maps a class tensor to indices within `classes_012`."""
    return class_to_class_index_tensor(c, classes_012)

sampling_options = dict(sampler="pc", t_limit=t_limit, num_samples=1000, verbose=True)

linear_samples_before = {}
for class_to_sample in classes_012:
    print("Sampling class: %s" % class_to_sample)
    sample = generate.generate_continuous_label_guided_samples(
        label_guided_model_1, sde, class_to_sample, _class_index_before,
        **sampling_options
    )
    linear_samples_before[class_to_sample] = sample.cpu().numpy()

Sampling class: 0


100%|██████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.78it/s]


Sampling class: 1


100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.76it/s]


Sampling class: 5


100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.77it/s]


In [18]:
# Continue training the label-guided model using only the data loader of the
# new class (`data_loader_2`)
newonly_config = {
    "model": label_guided_model_1,
    "sde": sde,
    "data_loader": data_loader_2,
    "class_to_class_index": lambda c: class_to_class_index_tensor(c, classes_012),
    "num_epochs": 10,
    "learning_rate": 0.001,
    "t_limit": t_limit,
    "loss_weighting_type": "empirical_norm"
}
train_continuous_model.train_ex.run(
    "train_label_guided_model", config_updates=newonly_config
)

INFO - train - Running command 'train_label_guided_model'
INFO - train - Started run with ID "2"
Loss: 188.88: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.52it/s]


Epoch 1 average Loss: 228.78


Loss: 164.17: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.58it/s]


Epoch 2 average Loss: 171.67


Loss: 159.49: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.59it/s]


Epoch 3 average Loss: 161.44


Loss: 156.18: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.59it/s]


Epoch 4 average Loss: 152.68


Loss: 150.89: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.61it/s]


Epoch 5 average Loss: 148.06


Loss: 140.66: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.60it/s]


Epoch 6 average Loss: 146.70


Loss: 145.87: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.61it/s]


Epoch 7 average Loss: 142.78


Loss: 140.37: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.60it/s]


Epoch 8 average Loss: 140.87


Loss: 143.83: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.61it/s]


Epoch 9 average Loss: 137.89


Loss: 137.81: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.61it/s]


Epoch 10 average Loss: 136.16


INFO - train - Completed after 0:00:52


<sacred.run.Run at 0x2aab88dda910>

In [19]:
# Sample all three classes from the label-guided model after it was trained
# on the new label only
def _class_index_after_newonly(c):
    """Maps a class tensor to indices within `classes_012`."""
    return class_to_class_index_tensor(c, classes_012)

sampling_options = dict(sampler="pc", t_limit=t_limit, num_samples=1000, verbose=True)

linear_samples_after_newonly = {}
for class_to_sample in classes_012:
    print("Sampling class: %s" % class_to_sample)
    sample = generate.generate_continuous_label_guided_samples(
        label_guided_model_1, sde, class_to_sample, _class_index_after_newonly,
        **sampling_options
    )
    linear_samples_after_newonly[class_to_sample] = sample.cpu().numpy()

Sampling class: 0


100%|██████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.80it/s]


Sampling class: 1


100%|██████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.79it/s]


Sampling class: 5


100%|██████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.78it/s]


#### Train label-guided model with all data

In [34]:
# Load a second, fresh copy of the label-guided model from the same 2-class
# checkpoint; this copy is the one trained on all data
label_guided_ckpt_path = os.path.join(
    models_base_path, "scrna_covid_flu_continuous_labelguided_2classes/1/last_ckpt.pth"
)
label_guided_model_2 = model_util.load_model(
    scrna_ae.LabelGuidedResNet, label_guided_ckpt_path
)
label_guided_model_2 = label_guided_model_2.to(DEVICE)

In [35]:
# Continue training the label-guided model using the data loader that covers
# all three classes (`data_loader_012`)
all_data_config = {
    "model": label_guided_model_2,
    "sde": sde,
    "data_loader": data_loader_012,
    "class_to_class_index": lambda c: class_to_class_index_tensor(c, classes_012),
    "num_epochs": 30,
    "learning_rate": 0.001,
    "t_limit": t_limit,
    "loss_weighting_type": "empirical_norm"
}
train_continuous_model.train_ex.run(
    "train_label_guided_model", config_updates=all_data_config
)

INFO - train - Running command 'train_label_guided_model'
INFO - train - Started run with ID "5"
Loss: 94.82: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 1 average Loss: 84.54


Loss: 72.43: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 2 average Loss: 79.13


Loss: 81.77: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 3 average Loss: 78.28


Loss: 82.70: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 4 average Loss: 78.02


Loss: 84.01: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 5 average Loss: 77.01


Loss: 91.03: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 6 average Loss: 76.50


Loss: 74.29: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 7 average Loss: 75.05


Loss: 76.08: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 8 average Loss: 76.26


Loss: 71.18: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 9 average Loss: 74.69


Loss: 68.64: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.41it/s]


Epoch 10 average Loss: 75.23


Loss: 79.51: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 11 average Loss: 74.38


Loss: 87.25: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 12 average Loss: 74.67


Loss: 75.56: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 13 average Loss: 74.07


Loss: 84.54: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 14 average Loss: 73.26


Loss: 91.69: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 15 average Loss: 73.42


Loss: 75.05: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 16 average Loss: 74.01


Loss: 70.85: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 17 average Loss: 73.73


Loss: 75.24: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 18 average Loss: 72.69


Loss: 71.01: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 19 average Loss: 72.76


Loss: 71.26: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.43it/s]


Epoch 20 average Loss: 71.95


Loss: 86.02: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.41it/s]


Epoch 21 average Loss: 71.86


Loss: 75.04: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 22 average Loss: 72.07


Loss: 67.58: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 23 average Loss: 71.33


Loss: 75.39: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.41it/s]


Epoch 24 average Loss: 71.99


Loss: 76.64: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 25 average Loss: 71.06


Loss: 72.59: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 26 average Loss: 71.66


Loss: 70.42: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 27 average Loss: 70.71


Loss: 76.20: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.43it/s]


Epoch 28 average Loss: 70.91


Loss: 83.29: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 29 average Loss: 71.17


Loss: 70.21: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]


Epoch 30 average Loss: 70.25


INFO - train - Completed after 0:07:55


<sacred.run.Run at 0x2aac7a367280>

In [36]:
# Sample all three classes from the label-guided model after it was trained
# on all data
def _class_index_after_all(c):
    """Maps a class tensor to indices within `classes_012`."""
    return class_to_class_index_tensor(c, classes_012)

sampling_options = dict(sampler="pc", t_limit=t_limit, num_samples=1000, verbose=True)

linear_samples_after_all = {}
for class_to_sample in classes_012:
    print("Sampling class: %s" % class_to_sample)
    sample = generate.generate_continuous_label_guided_samples(
        label_guided_model_2, sde, class_to_sample, _class_index_after_all,
        **sampling_options
    )
    linear_samples_after_all[class_to_sample] = sample.cpu().numpy()

Sampling class: 0


100%|██████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.81it/s]


Sampling class: 1


100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.75it/s]


Sampling class: 5


100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.73it/s]


#### Compute FIDs

In [37]:
# Draw 1000 real cells per class (without replacement) from the dataset, to
# compare the generated samples against
# NOTE(review): no random seed is set, so the drawn cells differ between runs
true_samples = {}
for class_to_sample in classes_012:
    print("Sampling class: %s" % class_to_sample)
    is_class = dataset_012.cell_cluster == class_to_sample
    inds = np.nonzero(is_class)[0]
    sample_inds = np.random.choice(inds, size=1000, replace=False)
    true_samples[class_to_sample] = dataset_012.data[sample_inds]

Sampling class: 0
Sampling class: 1
Sampling class: 5


In [38]:
# When not working in latent space, the datasets created earlier were built
# with `autoencoder_path=None`, so a separate dataset object with the
# autoencoder attached is created here; `compute_fid` uses it to encode and
# decode samples
if not latent_space:
    dataset_with_ae = scrna_dataset.SingleCellDataset(data_file, autoencoder_path=autoencoder_path)

def compute_fid(gen_samples, true_samples, latent=True):
    """
    Given an array of generated samples and an array of true samples,
    computes the FID between them (via `fid.compute_fid`) after bringing
    both into a common space with an autoencoder.
    Arguments:
        `gen_samples`: array of generated samples; if the global
            `latent_space` is set, these are treated as already being in
            the autoencoder's latent space, otherwise as being in the
            original data space; this array is never modified
        `true_samples`: array of true samples in the original data space
        `latent`: if True, compare the two sets in the autoencoder's
            latent space; otherwise compare them in the decoded data space
    Returns whatever `fid.compute_fid` returns for the two sets.
    """
    if latent_space:
        if latent:
            # Generated samples are already latent; only encode the true ones
            return fid.compute_fid(
                gen_samples,
                dataset_01.encode_batch(torch.tensor(true_samples, device=DEVICE)).cpu().numpy()
            )
        else:
            # Decode the generated samples to compare against raw true samples
            return fid.compute_fid(
                dataset_01.decode_batch(torch.tensor(gen_samples, device=DEVICE)).cpu().numpy(),
                true_samples
            )
    else:
        # Generated values should never be below 0, so clip negatives to 0.
        # This is done on a fresh array so that the caller's samples are not
        # silently modified by computing a metric on them.
        gen_samples = np.maximum(gen_samples, 0)
        if latent:
            return fid.compute_fid(
                dataset_with_ae.encode_batch(torch.tensor(gen_samples, device=DEVICE)).cpu().numpy(),
                dataset_with_ae.encode_batch(torch.tensor(true_samples, device=DEVICE)).cpu().numpy()
            )
        else:
            # Pass both sets through the autoencoder (encode then decode) so
            # that they are compared on equal footing
            return fid.compute_fid(
                dataset_with_ae.decode_batch(dataset_with_ae.encode_batch(torch.tensor(gen_samples, device=DEVICE))).cpu().numpy(),
                dataset_with_ae.decode_batch(dataset_with_ae.encode_batch(torch.tensor(true_samples, device=DEVICE))).cpu().numpy()
            )

[34mINFO    [0m File                                                                                                      
         [35m/gstore/data/resbioai/tsenga5/branched_diffusion/models/trained_models/scrna_vaes/covid_flu/covid_flu_proc[0m
         [35messed_reduced_genes_ldvae_d200/[0m[95mmodel.pt[0m already downloaded                                                




In [39]:
# Whether `compute_fid` compares samples in the autoencoder's latent space
latent = True

# One result dictionary (class -> FID) per sampling condition
branched_before_fids = {}
branched_before_2_fids = {}
branched_after_fids = {}
linear_before_fids = {}
linear_after_newonly_fids = {}
linear_after_all_fids = {}

# Each set of generated samples paired with the dictionary that receives its
# FIDs, in the order they are to be evaluated
fid_targets = [
    (branched_samples_before, branched_before_fids),
    (branched_samples_before_2, branched_before_2_fids),
    (branched_samples_after, branched_after_fids),
    (linear_samples_before, linear_before_fids),
    (linear_samples_after_newonly, linear_after_newonly_fids),
    (linear_samples_after_all, linear_after_all_fids),
]

for samples_by_class, fids_by_class in fid_targets:
    # Compare each class's generated samples against that class's true samples
    for c, class_gen_samples in samples_by_class.items():
        fids_by_class[c] = compute_fid(class_gen_samples, true_samples[c], latent)

In [40]:
# Report the per-class FIDs of every condition, one line per condition
fid_report = (
    ("B-before", branched_before_fids),
    ("B-before2", branched_before_2_fids),
    ("B-after", branched_after_fids),
    ("L-before", linear_before_fids),
    ("L-afterone", linear_after_newonly_fids),
    ("L-afterall", linear_after_all_fids),
)
for condition_label, condition_fids in fid_report:
    print(condition_label, condition_fids)

B-before {0: 17.91029206726665, 1: 20.768634314998813}
B-before2 {0: 18.08620742509962, 1: 21.22780248367854}
B-after {0: 18.08532807016608, 1: 20.537572760461792, 5: 23.560019574154172}
L-before {0: 20.431972237481226, 1: 22.235712766699923, 5: 27.880261054677923}
L-afterone {0: 25.624702214950076, 1: 26.794907896162005, 5: 32.10889066394775}
L-afterall {0: 16.95639804145786, 1: 19.29993126245404, 5: 28.023053225371605}
