In [3]:
import numpy as np 
import matplotlib.pyplot as plt 
import torch 
import argparse 
from diffusion_models.diffusion.ddpm_lightning import DDPM  
from diffusion_models.diffusion.denoising.unet import UNet 
from tqdm import trange, tqdm 
from torch import nn 
import os 
import sys 
sys.path.insert(0, "../")
from model_builder import get_pretrained_model_v2
from datasets import get_dataset 

DATASET_PATH = "/home-local/Frederic/Datasets/FLCDataset/dataset-250k.tar"
MODEL = "mae-lightning-small"
WEIGHTS = "MAE_SMALL_STED"
CHECKPOINT = "/home-local/Frederic/baselines/DiffusionModels/latent-guidance/MAE_SMALL_STED"
TIMESTEPS = 1000
DATASET = "STED"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [6]:
latent_encoder, model_config = get_pretrained_model_v2(
    name=MODEL,
    weights=WEIGHTS,
    path=None,
    mask_ratio=0.0,
    pretrained=False,
    in_channels=1,
    as_classifier=True,
    blocks="all",
    num_classes=4,
)

denoising_model = UNet(
    dim=64,
    channels=1,
    cond_dim=model_config.dim,
    dim_mults=(1,2,4),
    condition_type="latent",
    num_classes=4,
)

model = DDPM(
    denoising_model=denoising_model,
    timesteps=TIMESTEPS,
    beta_schedule="linear",
    condition_type="latent",
    latent_encoder=latent_encoder
)
checkpoint = torch.load(os.path.join(CHECKPOINT, "checkpoint-69.pth"))
model.load_state_dict(checkpoint["state_dict"])
model.to(DEVICE)
model.eval()




mask_ratio 0.0
pretrained False
in_channels 1
blocks all
num_classes 4
--- mae-lightning-small | /home-local/Frederic/baselines/mae-small_STED/pl_checkpoint-999.pth ---

--- Loaded model mae-lightning-small with weights MAE_SMALL_STED ---
--- ViT case with none-ImageNet weights or from scratch ---
--- Freezing every parameter in mae-lightning-small ---
--- Added linear probe to all frozen blocks ---


DDPM(
  (model): UNet(
    (init_conv): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (time_mlp): Sequential(
      (0): SinusoidalPosEmb()
      (1): Linear(in_features=64, out_features=256, bias=True)
      (2): GELU(approximate='none')
      (3): Linear(in_features=256, out_features=256, bias=True)
    )
    (label_embed): Embedding(4, 256)
    (cond_mlp): Sequential(
      (0): Linear(in_features=384, out_features=256, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=256, out_features=256, bias=True)
    )
    (downs): ModuleList(
      (0): ModuleList(
        (0-1): 2 x ResnetBlock(
          (mlp): Sequential(
            (0): SiLU()
            (1): Linear(in_features=256, out_features=128, bias=True)
          )
          (block1): Block(
            (proj): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): RMSNorm()
            (act): SiLU()
            (dropout): Dropout(p=0.0, inplace=Fal

In [7]:
dataset = get_dataset(DATASET, DATASET_PATH)

In [11]:
N = len(dataset)
indices = np.random.randint(0, N, size=20)
with torch.no_grad():
    for i in tqdm(indices, total=len(indices)):
        img = dataset[i]
        img = img.unsqueeze(0).to(DEVICE)
        condition = model.latent_encoder.forward_features(img)
        sample = model.p_sample_loop(shape=(img.shape[0], 1, img.shape[2], img.shape[3]), cond=condition, progress=True)
        sample = sample[:, [0], :, :].squeeze().cpu().detach().numpy()
        img = img[:, [0], :, :].squeeze().cpu().detach().numpy()

        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].imshow(img, cmap="gray")
        axs[1].imshow(sample, cmap="gray")
        for ax in axs:
            ax.axis("off")
        plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01)
        fig.savefig(f"./quick_gens/dataset250k_{i}.pdf", dpi=1200, bbox_inches="tight")
        plt.close(fig)
        

Iterative sampling...: 100%|██████████| 1000/1000 [00:36<00:00, 27.08it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.93it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.88it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.81it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.73it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.67it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.60it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.58it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.53it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.52it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.48it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.45it/s]
Iterative sampling...: 100%|██████████| 1000/1000 [00:37<00:00, 26.43it/s]
Iterative sampling...: 10

In [12]:
img = dataset[40923]
noisy = model.q_sample(img, torch.tensor([100]))
fig = plt.figure(figsize=(10, 5))
plt.imshow(noisy.squeeze().cpu().detach().numpy(), cmap="gray")
plt.axis("off")
plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01)
fig.savefig("./quick_gens/dataset250k_noise100_40923.pdf", dpi=1200, bbox_inches="tight")
plt.close(fig)
