In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../src")
import feature.scrna_dataset as scrna_dataset
import model.sdes as sdes
import model.generate as generate
import model.scrna_ae as scrna_ae
import model.util as model_util
import analysis.fid as fid
import torch
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import os
import h5py

In [2]:
# Select the compute device: use the GPU when CUDA is available, else the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

### Define the branches and create the data loader

In [3]:
# HDF5 file holding the processed single-cell data; it is passed to
# `SingleCellDataset` and also read directly with h5py in this notebook
data_file = "/gstore/data/resbioai/tsenga5/branched_diffusion/data/scrna/covid_flu/processed/covid_flu_processed_reduced_genes.h5"
# Directory of the trained autoencoder, passed to `SingleCellDataset` as
# `autoencoder_path`
autoencoder_path = "/gstore/data/resbioai/tsenga5/branched_diffusion/models/trained_models/scrna_vaes/covid_flu/covid_flu_processed_reduced_genes_ldvae_d200/"
# Base directory under which model checkpoints are looked up and under which
# MODEL_DIR is set for training runs
# NOTE(review): these are absolute, cluster-specific paths; the notebook will
# not run elsewhere without editing them
models_base_path = "/gstore/home/tsenga5/branched_diffusion/models/trained_models/scrna_covid_flu_continuous_latent_class_extension"

In [4]:
# TODO: this is currently rather inefficient; a decision-tree-style structure
# would be better

def class_time_to_branch(c, t, branch_defs):
    """
    Looks up the branch that a single class occupies at a single time.
    Arguments:
        `c`: a scalar class label
        `t`: a scalar time
        `branch_defs`: a list of branch definitions, each a triplet of
            (classes, start time, end time)
    Returns the index of the first entry in `branch_defs` that contains
    `c` and whose time interval (inclusive at both ends) contains `t`.
    Raises a ValueError if no branch definition matches.
    """
    for branch_index, branch_def in enumerate(branch_defs):
        branch_classes, t_start, t_end = branch_def[0], branch_def[1], branch_def[2]
        if c not in branch_classes:
            continue
        # Both endpoints are inclusive, so a time exactly on a branch point
        # resolves to whichever matching branch is listed first
        if t_start <= t <= t_end:
            return branch_index
    raise ValueError("Undefined class and time")
        
def class_time_to_branch_tensor(c, t, branch_defs):
    """
    Batched version of `class_time_to_branch`.
    Arguments:
        `c`: a tensor of classes
        `t`: a tensor of times, paired element-wise with `c`
        `branch_defs`: a list of branch definitions
    Returns a tensor of branch indices on `DEVICE`, one per (class, time)
    pair.
    """
    branch_indices = []
    # Pairs are resolved one at a time in Python (see the TODO above about
    # efficiency)
    for class_i, time_i in zip(c, t):
        branch_indices.append(class_time_to_branch(class_i, time_i, branch_defs))
    return torch.tensor(branch_indices, device=DEVICE)

def class_to_class_index_tensor(c, classes):
    """
    Converts a tensor of class labels into their positions within
    `classes`.
    Arguments:
        `c`: a tensor of class labels (assumed 1-D; it is indexed as
            `c[:, None]`)
        `classes`: a list of class labels; a label's position in this
            list is its class index
    Returns a tensor of class indices on `DEVICE`.
    NOTE: a label that does not occur in `classes` produces an all-zero
    comparison row, for which `argmax` reports index 0 instead of failing.
    """
    class_labels = torch.tensor(classes, device=c.device)
    # One row per entry of `c`, one column per entry of `classes`
    is_match = (c[:, None] == class_labels).int()
    return torch.argmax(is_match, dim=1).to(DEVICE)

In [5]:
# Define the branches. Each branch definition is a triplet of
# (tuple of classes, start time, end time).

# Branch-point times, named once so the definitions below cannot drift apart
_T_BRANCH_01 = 0.5795795795795796  # above this time, classes 0 and 1 share a branch
_T_BRANCH_015 = 0.6786786786786787  # above this time, classes 0, 1 and 5 share a branch

classes_01 = [0, 1]
branch_defs_01 = [
    ((0, 1), _T_BRANCH_01, 1),
    ((0,), 0, _T_BRANCH_01),
    ((1,), 0, _T_BRANCH_01),
]

classes_015 = [0, 1, 5]
branch_defs_015 = [
    ((0, 1, 5), _T_BRANCH_015, 1),
    ((0, 1), _T_BRANCH_01, _T_BRANCH_015),
    ((5,), 0, _T_BRANCH_015),
    ((0,), 0, _T_BRANCH_01),
    ((1,), 0, _T_BRANCH_01),
]

classes_5 = [5]
branch_defs_5 = [((5,), 0, _T_BRANCH_015)]


def _load_class_subset(classes, batch_size=128):
    """
    Builds a dataset restricted to the given classes, plus a shuffling
    data loader over it.
    Arguments:
        `classes`: list of class labels (values of `cell_cluster`) to keep
        `batch_size`: batch size of the returned data loader
    Returns a triplet of (dataset, boolean mask that was applied to the
    full dataset, data loader).
    """
    dataset = scrna_dataset.SingleCellDataset(data_file, autoencoder_path=autoencoder_path)
    # Limit classes: keep only the rows whose cluster label is in `classes`
    mask = np.isin(dataset.cell_cluster, classes)
    dataset.data = dataset.data[mask]
    dataset.cell_cluster = dataset.cell_cluster[mask]
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=0
    )
    return dataset, mask, loader


dataset_01, inds_01, data_loader_01 = _load_class_subset(classes_01)
dataset_015, inds_015, data_loader_015 = _load_class_subset(classes_015)
dataset_5, inds_5, data_loader_5 = _load_class_subset(classes_5)

# Shape of one example: first element of a batch, without its leading
# (batch) dimension
input_shape = next(iter(data_loader_01))[0].shape[1:]

[rank: 0] Global seed set to 0


[34mINFO    [0m File                                                                                                      
         [35m/gstore/data/resbioai/tsenga5/branched_diffusion/models/trained_models/scrna_vaes/covid_flu/covid_flu_proc[0m
         [35messed_reduced_genes_ldvae_d200/[0m[95mmodel.pt[0m already downloaded                                                


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


[34mINFO    [0m File                                                                                                      
         [35m/gstore/data/resbioai/tsenga5/branched_diffusion/models/trained_models/scrna_vaes/covid_flu/covid_flu_proc[0m
         [35messed_reduced_genes_ldvae_d200/[0m[95mmodel.pt[0m already downloaded                                                




[34mINFO    [0m File                                                                                                      
         [35m/gstore/data/resbioai/tsenga5/branched_diffusion/models/trained_models/scrna_vaes/covid_flu/covid_flu_proc[0m
         [35messed_reduced_genes_ldvae_d200/[0m[95mmodel.pt[0m already downloaded                                                




In [6]:
# Upper time limit, passed as `t_limit` to the training and sampling calls
t_limit = 1

# Create the variance-preserving SDE over inputs of shape `input_shape`
# NOTE(review): 0.1 and 5 are passed positionally; presumably the noise
# schedule endpoints -- confirm against `sdes.VariancePreservingSDE`
sde = sdes.VariancePreservingSDE(0.1, 5, input_shape)

In [7]:
# Read the gene names and the per-class marker genes from the data file
with h5py.File(data_file, "r") as h5_file:
    gene_names = h5_file["gene_names"][:].astype(str)
    # Marker genes are stored per class, keyed by the class label as a string
    marker_genes = {
        class_i: h5_file["marker_genes"][str(class_i)][:].astype(str)
        for class_i in classes_015
    }

# NOTE(review): "TLR4" and "IL1B" each appear twice in this list; it is left
# unchanged here in case downstream code relies on its length -- confirm
# whether the duplicates are intended
genes_of_interest = np.array([
    "NFKB1", "NFKB2", "IRF1", "CXCR3", "CXCL10", "STAT1", "TLR4", "TGFB1",
    "IL1B", "IFNG", "TLR4", "TNFSF4", "IL1R2", "IL1B", "IL7R", "IL32",
])

In [8]:
# Point MODEL_DIR at the "extension" subdirectory of the models base path
os.environ["MODEL_DIR"] = os.path.join(models_base_path, "extension")

# NOTE: ordering matters here. Presumably `train_continuous_model` reads
# MODEL_DIR when it is imported (TODO confirm), so this import must stay
# below the assignment above rather than move to the top-level import cell.
import model.train_continuous_model as train_continuous_model  # Import this AFTER setting environment

#### Train extra branch on branched model

In [9]:
def map_branch_def(branch_def, target_branch_defs):
    """
    Finds the entry of `target_branch_defs` that a query branch
    definition corresponds to, for the purpose of deciding whether the
    query branch can reuse an existing branch or must be retrained.
    A target matches the query when every class of the target is also a
    class of the query, and the query's time interval lies within the
    target's time interval (endpoints inclusive).
    Arguments:
        `branch_def`: the query branch definition (i.e. triplet of class
            index tuple, start time, and end time)
        `target_branch_defs`: a list of branch definitions to search
    Returns the index of the first matching entry of
    `target_branch_defs`, or -1 if none matches.
    """
    query_classes = set(branch_def[0])
    query_start, query_end = branch_def[1], branch_def[2]
    for target_index, target_def in enumerate(target_branch_defs):
        classes_covered = set(target_def[0]) <= query_classes
        if classes_covered and target_def[1] <= query_start and query_end <= target_def[2]:
            return target_index
    return -1

In [10]:
# Load the saved branched model from its last checkpoint and move it to DEVICE
branched_model_1_ckpt = os.path.join(models_base_path, "branched/01/1/last_ckpt.pth")
branched_model_1 = model_util.load_model(
    scrna_ae.MultitaskResNet, branched_model_1_ckpt
).to(DEVICE)

In [11]:
# # Create new model and copy over parameters
# branched_model_2 = scrna_ae.MultitaskResNet(
#     len(branch_defs_015), input_shape[0], t_limit=t_limit
# ).to(DEVICE)

# # Figure out which branches should be copied over to which ones
# branch_map_inds = [
#     map_branch_def(bd, branch_defs_01) for bd in branch_defs_015
# ]

# # For each submodule, copy over the weights
# # Careful: this assumes a particular kind of architecture!
# modules_1 = dict(branched_model_1.named_children())
# modules_2 = dict(branched_model_2.named_children())

# for module_name in ["layers", "time_embedders"]:
#     for submodule_i, submodule in enumerate(modules_1[module_name]):
#         if len(submodule) == 1:
#             branched_model_2.get_submodule(module_name)[submodule_i].load_state_dict(
#                 submodule.state_dict()
#             )
#         elif len(submodule) == len(branch_defs_01):
#             target_submodule_list = branched_model_2.get_submodule(module_name)[submodule_i]
#             for target_i, source_i in enumerate(branch_map_inds):
#                 if source_i != -1:
#                     target_submodule_list[target_i].load_state_dict(
#                         submodule[source_i].state_dict()
#                     )
#                 else:
#                     # Copy over some other branch for a warm start
#                     # We'll manually set it for now (TODO)
#                     source_i = -1  # Last branch
#                     target_submodule_list[target_i].load_state_dict(
#                         submodule[source_i].state_dict()
#                     )
#         else:
#             raise ValueError("Found module list of length %d" % len(submodule))

# submodule = branched_model_1.get_submodule("last_linears")
# target_submodule_list = branched_model_2.get_submodule("last_linears")
# for target_i, source_i in enumerate(branch_map_inds):
#     if source_i != -1:
#         target_submodule_list[target_i].load_state_dict(
#             submodule[source_i].state_dict()
#         )
#     else:
#         # Copy over some other branch for a warm start
#         # We'll manually set it for now (TODO)
#         source_i = -1  # Last branch
#         target_submodule_list[target_i].load_state_dict(
#             submodule[source_i].state_dict()
#         )

In [12]:
# # Train the model, for the specific branches only

# # Freeze all shared layers of the model, and freeze all task-specific
# # layers other than the ones we want to train
# for module_name in ["layers", "time_embedders"]:
#     for submodule in branched_model_2.get_submodule(module_name):
#         if len(submodule) == 1:
#             for p in submodule.parameters():
#                 p.requires_grad = False
#         elif len(submodule) == len(branch_defs_015):
#             for i in range(len(submodule)):
#                 if branch_map_inds[i] != -1:
#                     for p in submodule[i].parameters():
#                         p.requires_grad = False
#                 else:
#                     for p in submodule[i].parameters():
#                         p.requires_grad = True
#         else:
#             raise ValueError("Found module list of length %d" % len(submodule))
# submodule = branched_model_2.get_submodule("last_linears")
# for i in range(len(submodule)):
#     if branch_map_inds[i] != -1:
#         for p in submodule[i].parameters():
#             p.requires_grad = False
#     else:
#         for p in submodule[i].parameters():
#             p.requires_grad = True

# # Train
# train_continuous_model.train_ex.run(
#     "train_branched_model",
#     config_updates={
#         "model": branched_model_2,
#         "sde": sde,
#         "data_loader": data_loader_5,
#         "class_time_to_branch_index": lambda c, t: class_time_to_branch_tensor(c, t, branch_defs_015),
#         "num_epochs": 50,
#         "learning_rate": 0.001,
#         "t_limit": branch_defs_5[0][2],
#         "loss_weighting_type": "empirical_norm"
#     }
# )

In [13]:
# Load the extended branched model from its last checkpoint (instead of
# re-running the commented-out training above) and move it to DEVICE
branched_model_2_ckpt = os.path.join(models_base_path, "extension/4/last_ckpt.pth")
branched_model_2 = model_util.load_model(
    scrna_ae.MultitaskResNet, branched_model_2_ckpt
).to(DEVICE)

In [38]:
# Generate samples from the extended branched model, one batch per class
def _branch_index_015(c, t):
    """Maps class and time tensors to branch indices under `branch_defs_015`."""
    return class_time_to_branch_tensor(c, t, branch_defs_015)


branched_samples = {}
for class_to_sample in classes_015:
    print("Sampling class: %s" % class_to_sample)
    sample = generate.generate_continuous_branched_samples(
        branched_model_2, sde, class_to_sample, _branch_index_015,
        sampler="pc", t_limit=t_limit, num_samples=1000
    )
    # Samples are stored as generated, i.e. without passing them through
    # `dataset_015.decode_batch`
    branched_samples[class_to_sample] = sample.cpu().numpy()

Sampling class: 0
Sampling class: 1
Sampling class: 5


#### Train label-guided model with only new label

In [9]:
# Import the label-guided model from its last checkpoint and move it to DEVICE
label_guided_model_1_ckpt = os.path.join(
    models_base_path, "labelguided/01/2/last_ckpt.pth"
)
label_guided_model_1 = model_util.load_model(
    scrna_ae.LabelGuidedResNet, label_guided_model_1_ckpt
).to(DEVICE)

In [None]:
# Train on only new label: the loader contains class 5 only, while class
# indices are still taken relative to the full `classes_015` list.
# NOTE: `label_guided_model_1` is passed in directly, so this cell updates
# that same model object rather than a copy.
label_guided_newonly_config = {
    "model": label_guided_model_1,
    "sde": sde,
    "data_loader": data_loader_5,
    "class_to_class_index": lambda c: class_to_class_index_tensor(c, classes_015),
    "num_epochs": 50,
    "learning_rate": 0.001,
    "t_limit": t_limit,
    "loss_weighting_type": "empirical_norm",
}
train_continuous_model.train_ex.run(
    "train_label_guided_model", config_updates=label_guided_newonly_config
)

INFO - train - Running command 'train_label_guided_model'
INFO - train - Started run with ID "9"
Loss: 22.94: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.81it/s]


Epoch 1 average Loss: 23.61


Loss: 17.10: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.73it/s]


Epoch 2 average Loss: 18.59


Loss: 17.50: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.88it/s]


Epoch 3 average Loss: 18.74


Loss: 16.87: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.83it/s]


Epoch 4 average Loss: 18.21


Loss: 19.30: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.89it/s]


Epoch 5 average Loss: 17.66


Loss: 19.93: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.84it/s]


Epoch 6 average Loss: 17.30


Loss: 15.83: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.82it/s]


Epoch 7 average Loss: 17.50


Loss: 17.16: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.84it/s]


Epoch 8 average Loss: 17.07


Loss: 18.03: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.81it/s]


Epoch 9 average Loss: 17.84


Loss: 20.83: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.82it/s]


Epoch 10 average Loss: 17.63


Loss: 16.12: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.86it/s]


Epoch 11 average Loss: 16.80


Loss: 19.43: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.96it/s]


Epoch 12 average Loss: 16.72


Loss: 16.31: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.91it/s]


Epoch 13 average Loss: 17.59


Loss: 17.82: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.91it/s]


Epoch 14 average Loss: 17.33


Loss: 14.72: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.92it/s]


Epoch 15 average Loss: 16.63


Loss: 18.18: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.99it/s]


Epoch 16 average Loss: 16.92


Loss: 18.80: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.91it/s]


Epoch 17 average Loss: 16.72


Loss: 17.24: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.98it/s]


Epoch 18 average Loss: 16.70


Loss: 18.27: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.91it/s]


Epoch 19 average Loss: 16.94


Loss: 15.59: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.84it/s]


Epoch 20 average Loss: 16.56


Loss: 18.46: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.87it/s]


Epoch 21 average Loss: 16.63


Loss: 14.99: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.89it/s]


Epoch 22 average Loss: 16.90


Loss: 15.84: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.99it/s]


Epoch 23 average Loss: 16.58


Loss: 17.00: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.88it/s]


Epoch 24 average Loss: 16.62


Loss: 20.26: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.80it/s]


Epoch 25 average Loss: 17.43


Loss: 16.05: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.75it/s]


Epoch 26 average Loss: 16.40


Loss: 19.32: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.86it/s]


Epoch 27 average Loss: 16.62


Loss: 16.60: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.85it/s]


Epoch 28 average Loss: 16.17


Loss: 17.04: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.92it/s]


Epoch 29 average Loss: 16.28


Loss: 19.52: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.74it/s]


Epoch 30 average Loss: 16.67


Loss: 16.49: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.82it/s]


Epoch 31 average Loss: 16.14


Loss: 17.93: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.87it/s]


Epoch 32 average Loss: 16.41


Loss: 15.48: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.84it/s]


Epoch 33 average Loss: 16.32


Loss: 18.01: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.77it/s]


Epoch 34 average Loss: 16.51


Loss: 16.62: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.79it/s]


Epoch 35 average Loss: 16.09


Loss: 15.58: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.74it/s]


Epoch 36 average Loss: 15.87


Loss: 15.98: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.81it/s]


Epoch 37 average Loss: 16.00


Loss: 17.53: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.78it/s]


Epoch 38 average Loss: 16.12


Loss: 20.07: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.88it/s]


Epoch 39 average Loss: 15.99


Loss: 18.36: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.85it/s]


Epoch 40 average Loss: 16.51


Loss: 16.15: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.78it/s]


Epoch 41 average Loss: 16.15


Loss: 16.36: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.90it/s]


Epoch 42 average Loss: 15.77


Loss: 16.45: 100%|█████████████████████████████████████████| 24/24 [00:08<00:00,  2.98it/s]


Epoch 43 average Loss: 16.43


Loss: 15.41:  25%|██████████▌                               | 6/24 [00:02<00:06,  2.77it/s]

In [16]:
# Generate samples from the label-guided model trained on the new label only
def _newonly_class_index(c):
    """Maps a tensor of class labels to indices within `classes_015`."""
    return class_to_class_index_tensor(c, classes_015)


linear_samples_newonly = {}
for class_to_sample in classes_015:
    print("Sampling class: %s" % class_to_sample)
    sample = generate.generate_continuous_label_guided_samples(
        label_guided_model_1, sde, class_to_sample, _newonly_class_index,
        sampler="pc", t_limit=t_limit, num_samples=1000
    )
    # Samples are stored as generated, i.e. without passing them through
    # `dataset_015.decode_batch`
    linear_samples_newonly[class_to_sample] = sample.cpu().numpy()

Sampling class: 1
Sampling class: 5


#### Train label-guided model with all data

In [15]:
# # Import the label-guided model
# label_guided_model_2 = model_util.load_model(
#     scrna_ae.LabelGuidedResNet,
#     os.path.join(models_base_path, "labelguided/01/2/last_ckpt.pth")
# ).to(DEVICE)

In [16]:
# # Train on all data
# train_continuous_model.train_ex.run(
#     "train_label_guided_model",
#     config_updates={
#         "model": label_guided_model_2,
#         "sde": sde,
#         "data_loader": data_loader_015,
#         "class_to_class_index": lambda c: class_to_class_index_tensor(c, classes_015),
#         "num_epochs": 30,
#         "learning_rate": 0.001,
#         "t_limit": t_limit,
#         "loss_weighting_type": "empirical_norm"
#     }
# )

INFO - train - Running command 'train_label_guided_model'
INFO - train - Started run with ID "6"
Loss: 18.18: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.91it/s]


Epoch 1 average Loss: 21.03


Loss: 18.02: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.91it/s]


Epoch 2 average Loss: 18.97


Loss: 19.17: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:44<00:00,  2.94it/s]


Epoch 3 average Loss: 18.57


Loss: 17.08: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.92it/s]


Epoch 4 average Loss: 18.14


Loss: 20.54: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:44<00:00,  2.98it/s]


Epoch 5 average Loss: 18.33


Loss: 18.32: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.93it/s]


Epoch 6 average Loss: 17.83


Loss: 16.66: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.93it/s]


Epoch 7 average Loss: 17.66


Loss: 17.50: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:44<00:00,  2.96it/s]


Epoch 8 average Loss: 17.46


Loss: 17.39: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.89it/s]


Epoch 9 average Loss: 17.39


Loss: 15.76: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:44<00:00,  2.93it/s]


Epoch 10 average Loss: 16.87


Loss: 18.34: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.92it/s]


Epoch 11 average Loss: 16.83


Loss: 15.29: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:44<00:00,  2.95it/s]


Epoch 12 average Loss: 16.62


Loss: 15.76: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.90it/s]


Epoch 13 average Loss: 16.25


Loss: 15.44: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:44<00:00,  2.96it/s]


Epoch 14 average Loss: 16.36


Loss: 16.54: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.93it/s]


Epoch 15 average Loss: 16.01


Loss: 15.48: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.93it/s]


Epoch 16 average Loss: 16.13


Loss: 18.14: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.93it/s]


Epoch 17 average Loss: 15.77


Loss: 18.70: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.92it/s]


Epoch 18 average Loss: 15.61


Loss: 15.44: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.91it/s]


Epoch 19 average Loss: 15.31


Loss: 15.82: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.92it/s]


Epoch 20 average Loss: 15.24


Loss: 18.29: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.93it/s]


Epoch 21 average Loss: 14.96


Loss: 15.03: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:44<00:00,  2.93it/s]


Epoch 22 average Loss: 14.93


Loss: 18.78: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.90it/s]


Epoch 23 average Loss: 14.90


Loss: 16.40: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:44<00:00,  2.93it/s]


Epoch 24 average Loss: 14.65


Loss: 17.50: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:44<00:00,  2.93it/s]


Epoch 25 average Loss: 14.66


Loss: 15.35: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.90it/s]


Epoch 26 average Loss: 14.79


Loss: 14.18: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.90it/s]


Epoch 27 average Loss: 14.67


Loss: 13.88: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.93it/s]


Epoch 28 average Loss: 14.82


Loss: 17.76: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:44<00:00,  2.93it/s]


Epoch 29 average Loss: 14.62


Loss: 14.46: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:45<00:00,  2.89it/s]


Epoch 30 average Loss: 14.28


INFO - train - Completed after 0:24:10


<sacred.run.Run at 0x2aaab4dca940>

In [31]:
# Load the label-guided model from its last checkpoint (instead of re-running
# the commented-out training above) and move it to DEVICE
label_guided_model_2_ckpt = os.path.join(models_base_path, "extension/6/last_ckpt.pth")
label_guided_model_2 = model_util.load_model(
    scrna_ae.LabelGuidedResNet, label_guided_model_2_ckpt
).to(DEVICE)

In [37]:
# Generate samples from the label-guided model loaded above, one batch per class
def _alldata_class_index(c):
    """Maps a tensor of class labels to indices within `classes_015`."""
    return class_to_class_index_tensor(c, classes_015)


linear_samples = {}
for class_to_sample in classes_015:
    print("Sampling class: %s" % class_to_sample)
    sample = generate.generate_continuous_label_guided_samples(
        label_guided_model_2, sde, class_to_sample, _alldata_class_index,
        sampler="pc", t_limit=t_limit, num_samples=1000
    )
    # Samples are stored as generated, i.e. without passing them through
    # `dataset_015.decode_batch`
    linear_samples[class_to_sample] = sample.cpu().numpy()

Sampling class: 0
Sampling class: 1
Sampling class: 5


#### Compute FIDs

In [13]:
# Sample cells of each class from the original dataset and encode them, so
# that they can be compared against the generated samples
num_true_samples = 1000  # matches `num_samples` used when generating
true_samples = {}
for class_to_sample in classes_015:
    print("Sampling class: %s" % class_to_sample)
    inds = np.where(dataset_015.cell_cluster == class_to_sample)[0]
    # Sampling without replacement raises a ValueError if more items are
    # requested than exist, so cap the request at the class size
    sample_inds = np.random.choice(
        inds, size=min(num_true_samples, len(inds)), replace=False
    )
    # Encode the selected rows with the dataset's `encode_batch`
    # (the unencoded alternative would be `dataset_015.data[sample_inds]`)
    batch = torch.tensor(dataset_015.data[sample_inds], device=DEVICE)
    samples = dataset_015.encode_batch(batch).cpu().numpy()
    true_samples[class_to_sample] = samples

Sampling class: 0
Sampling class: 1
Sampling class: 5


In [48]:
# Per-class FID between each set of generated samples and the true samples
branched_fids = {
    c: fid.compute_fid(generated, true_samples[c])
    for c, generated in branched_samples.items()
}
linear_fids = {
    c: fid.compute_fid(generated, true_samples[c])
    for c, generated in linear_samples.items()
}
linear_newonly_fids = {
    c: fid.compute_fid(generated, true_samples[c])
    for c, generated in linear_samples_newonly.items()
}

In [49]:
# Print the per-class FIDs in order: branched model, label-guided model
# (`linear_samples`), label-guided model trained on the new label only
for fid_table in (branched_fids, linear_fids, linear_newonly_fids):
    print(fid_table)

{0: 1332.232965549725, 1: 1455.5579325521771, 5: 837.9917418850766}
{0: 1539.2042545417362, 1: 1501.6442270218297, 5: 1543.5487609656486}
{0: 1904.5301412384847, 1: 1898.953339182816, 5: 1876.229466558976}
