In [29]:
import numpy as np 
import matplotlib.pyplot as plt 
import argparse 
from diffusion_models.diffusion.ddpm_lightning import DDPM 
from diffusion_models.diffusion.denoising.unet import UNet 
import torch 
from tqdm import trange, tqdm 
import copy 
import sys 
import os 
import random 
from class_dict import class_dict
sys.path.insert(0, "../")
from loaders import get_dataset 
from model_builder import get_pretrained_model_v2 

DATASET_PATH = "/home-local/Frederic/evaluation-data/NeuralActivityStates"
LATENT_ENCODER = "mae-lightning-small"
WEIGHTS = "MAE_SMALL_JUMP"
CHECKPOINT = "/home-local/Frederic/baselines/DiffusionModels/classifier-guidance"
NUM_SAMPLES = 15 
GUIDANCE = "class"

def get_save_folder(key: str) -> str: 
    if key is None:
        return "from-scratch"
    elif "imagenet" in key.lower():
        return "ImageNet"
    elif "sted" in key.lower():
        return "STED"
    elif "jump" in key.lower():
        return "JUMP"
    elif "sim" in key.lower():
        return "SIM"
    elif "hpa" in key.lower():
        return "HPA"
    elif "sim" in key.lower():
        return "SIM"
    else:
        raise NotImplementedError("The requested weights do not exist.")

SAVENAME = get_save_folder(key=WEIGHTS)

In [30]:
np.random.seed(42)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_channels = 3 if SAVENAME == "ImageNet" else 1  
latent_encoder, model_config = get_pretrained_model_v2(
    name=LATENT_ENCODER,
    weights=WEIGHTS,
    path=None,
    mask_ratio=0.0,
    pretrained=True if n_channels == 3 else False,
    in_channels=n_channels,
    as_classifier=True,
    blocks="all",
    num_classes=4
)
denoising_model = UNet(
    dim=64,
    channels=1,
    dim_mults=(1,2,4),
    cond_dim=model_config.dim,
    condition_type=GUIDANCE,
    num_classes=24 if GUIDANCE == "class" else 4
)
model = DDPM(
    denoising_model=denoising_model,
    timesteps=1000,
    beta_schedule="linear",
    condition_type=GUIDANCE,
    latent_encoder=latent_encoder if GUIDANCE == "latent" else None,
)

path = f"{CHECKPOINT}/{WEIGHTS}/checkpoint-69.pth" if GUIDANCE == "latent" else f"{CHECKPOINT}/checkpoint-69.pth"
print(path)
ckpt = torch.load(path)
model.load_state_dict(ckpt["state_dict"])
model = model.to(DEVICE)

mask_ratio 0.0
pretrained False
in_channels 1
blocks all
num_classes 4
--- mae-lightning-small | /home-local/Frederic/baselines/mae-small_JUMP/checkpoint-999.pth ---

--- Loaded model mae-lightning-small with weights MAE_SMALL_JUMP ---
--- Freezing every parameter in mae-lightning-small ---
--- Added linear probe to all frozen blocks ---
/home-local/Frederic/baselines/DiffusionModels/classifier-guidance/checkpoint-69.pth


In [31]:
train_loader, valid_loader, test_loader = get_dataset(
    name="neural-activity-states",
    transform=None,
    training=True,
    path=None,
    batch_size=model_config.batch_size,
    n_channels=1,
    balance=False
)
dataset = test_loader.dataset

Processing dataset..: 100%|██████████| 6021/6021 [00:04<00:00, 1246.07it/s]
Processing dataset..: 100%|██████████| 1176/1176 [00:00<00:00, 1491.09it/s]
Processing dataset..: 100%|██████████| 1510/1510 [00:01<00:00, 1458.64it/s]


=== NAS dataset ===
(array([0, 1, 2, 3]), array([1545, 1423,  512, 2541]))
(array([0, 1, 2, 3]), array([230, 420, 113, 413]))
(array([0, 1, 2, 3]), array([492, 299, 200, 519]))
Training size: 6021
Validation size: 1176
Test size: 1510







In [32]:
def save_image(image: np.ndarray, generation: np.ndarray, i: int, class_name: str) -> None:
    fig = plt.figure()
    plt.imshow(image, cmap='hot', vmin=0, vmax=1)
    plt.axis("off")
    plt.savefig(f"./classification-study/{GUIDANCE}-guidance/templates/template{i}_{class_name.upper()}.png", dpi=1200, bbox_inches="tight")
    plt.close(fig)


    weights = "classifier-guidance" if GUIDANCE == "class" else WEIGHTS
    fig = plt.figure()
    plt.imshow(generation, cmap='hot', vmin=0, vmax=1)
    plt.axis("off")
    plt.savefig(f"./classification-study/{GUIDANCE}-guidance/candidates/{weights}_template{i}_{class_name.upper()}.png", dpi=1200, bbox_inches="tight")
    plt.close(fig)

In [33]:
indices = np.arange(len(dataset))

np.random.seed(42)
np.random.shuffle(indices)
from typing import Union
def denormalize(img: Union[np.ndarray, torch.Tensor], mu: float = 0.06957887037697921, std: float = 0.1254630260057964) -> Union[np.ndarray, torch.Tensor]:
    """
    Denormalizes an image. Note that the parameters mu and sigma seem hard-coded but they have been computed from the training sets and can be found
    in the attribute_datasets.py file.
    """
    return img * std + mu

counter = 0
model.eval()
with torch.no_grad():
    for idx in tqdm(indices, total=len(indices), desc="Processing samples..."):
        original_img, metadata = dataset[idx]
        protein = "psd95"
        if counter >= NUM_SAMPLES:
            break
        else:
            counter += 1

       
        if SAVENAME == "ImageNet":
                image = torch.tensor(original_img, dtype=torch.float32).repeat(3, 1, 1).unsqueeze(0).to(DEVICE)
                assert torch.equal(image[0, 0, :, :], image[0, 1, :, :]) and torch.equal(image[0, 1, :, :], image[0, 2, :, :]), "All three channels in the image tensor are not equal"
        else:
            image = torch.tensor(original_img, dtype=torch.float32).unsqueeze(0).to(DEVICE)

        condition = model.latent_encoder.forward_features(image) if GUIDANCE == "latent" else torch.tensor(class_dict[protein], dtype=torch.int8).to(DEVICE).long() 

        original_img = original_img[0]
        generation = model.p_sample_loop(shape=(image.shape[0], 1, image.shape[2], image.shape[3]), cond=condition, progress=True)

        original_img = original_img.cpu().numpy()    
        generation = generation.squeeze().cpu().numpy()
        if SAVENAME == "ImageNet":
            generation = denormalize(generation)
            m, M = generation.min(), generation.max()
            generation = (generation - m) / (M - m)
        
        # if SAVENAME == "ImageNet":
        #     print(original_img.min(), original_img.max())
        #     original_img = denormalize(original_img, mu=0.014, std=0.03)
        #     print(original_img.min(), original_img.max())
        #     print(generation.min(), generation.max())
        #     # generation = denormalize(generation)
        #     m, M = generation.min(), generation.max()
        #     print("\n\n")
        
        save_image(original_img, generation, idx, protein)


  image = torch.tensor(original_img, dtype=torch.float32).unsqueeze(0).to(DEVICE)
Iterative sampling...: 100%|██████████| 1000/1000 [00:36<00:00, 27.09it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.92it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.81it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.61it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.42it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.36it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.36it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.36it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:38<00:00, 26.29it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:38<00:00, 26.27it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:38<00:00, 26.26it/s]]
Iterative sampling...: 100%|██████████| 1000/1000 [00:38<00:00, 26.30it/s]]
Iterative sampli