In [25]:
import gc
import os
from typing import Optional

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from lib.model.vqvae2 import VQVAE
from lib.model.pixelsnail_prior import PixelSNAIL
from lib.dataset.cataracts_dataset import CATARACTSDataset
from lib.utils.pre_train import get_configs
from lib.utils.misc import WrappedModel, sample_model

In [26]:
# --- Configuration -----------------------------------------------------------
# DATA_PATH can be overridden through the CATARACTS_DATA_PATH environment
# variable instead of editing the hard-coded local path below.
DATA_PATH = os.environ.get(
    'CATARACTS_DATA_PATH',
    '/media/yannik/samsung_data_ssd/data/CATARACTS-videos-processed/',
)
VQVAE_PATH = 'results/vqvae2/2023.02.25 14_51_30/'
TOP_PATH = 'results/top_prior/2023.02.28 07_31_36/'
BOTOM_PATH = 'results/bottom_prior/2023.03.01 09_41_01/'  # (sic) later cells reference this exact name
TARGET_PATH = 'results/vqvae2_qual_samples/'
# Fall back to CPU when no GPU is visible; the model cells wrap models with
# WrappedModel instead of DataParallel when DEV == 'cpu'.
DEV = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed the global torch RNG so that "Restart & Run All" reproduces the samples.
SEED = 0
torch.manual_seed(SEED)
data_conf, vqvae_model_conf, _, _ = get_configs(VQVAE_PATH + "config.yaml")
_, bottom_model_conf, _, _ = get_configs(BOTOM_PATH + "config.yaml")
_, top_model_conf, _, _ = get_configs(TOP_PATH + "config.yaml")
os.makedirs(TARGET_PATH, exist_ok=True)
BATCH_SIZE = 1
STEPS = 1  # 30000//BATCH_SIZE -- NOTE: currently unused; the sampling cells pass `steps` explicitly
TARGET_SHAPE = (3, 270, 480)  # (C, H, W); samples are resized to (H, W) before saving
print("Avail. GPUs: ", torch.cuda.device_count())

Avail. GPUs:  1


In [27]:
# Held-out CATARACTS split. In the cells shown, `test_ds` is only used for its
# label vocabulary / class counts; `test_dl` is not consumed.
# NOTE(review): SHAPE / NORM are presumably Python-literal strings in
# config.yaml and are parsed with eval(); ast.literal_eval would be safer if
# they are plain literals.
img_shape = eval(data_conf['SHAPE'])
test_ds = CATARACTSDataset(
    root=DATA_PATH,
    resize_shape=img_shape[1:],
    normalize=eval(data_conf['NORM']),
    mode='test',
    frame_step=data_conf['FRAME_STEP'],
    sample_img=False,
)
# TODO: Weighted sampling / sampling from p(toolset|phase)
test_dl = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=8,
    pin_memory=False,
)
print(f"{len(test_ds)} test samples")

35492 test samples


In [28]:
# VQ-VAE-2 whose `decode_code` turns sampled top/bottom code maps into images.
vqvae = VQVAE(
    in_channel=eval(data_conf['SHAPE'])[0],
    channel=vqvae_model_conf['CHANNELS'],
    n_res_block=vqvae_model_conf['N_RES_BLOCKS'],
    n_res_channel=vqvae_model_conf['RES_CHANNELS'],
    embed_dim=vqvae_model_conf['EMBED_DIM'],
    n_embed=vqvae_model_conf['N_EMBEDDINGS'],
    decay=vqvae_model_conf['EMA_DECAY']
).to(DEV)
# DataParallel on GPU; on CPU the project's WrappedModel (presumably also exposing `.module`).
vqvae = torch.nn.DataParallel(vqvae, device_ids=['cuda:0']) if DEV != 'cpu' else WrappedModel(vqvae)
# Read the checkpoint from disk once; element 0 holds the model state dict.
vqvae_state = torch.load(VQVAE_PATH + "ckpt.pth", map_location='cpu')[0]
try:
    # Keys saved from the bare model (no "module." prefix).
    vqvae.module.load_state_dict(vqvae_state)
except Exception:  # not a bare except: Ctrl-C / SystemExit must still propagate
    # Presumably keys saved from the wrapper itself (with the "module." prefix).
    vqvae.load_state_dict(vqvae_state)
vqvae.eval()

DataParallel(
  (module): VQVAE(
    (enc_b): Encoder(
      (blocks): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): ResBlock(
          (conv): Sequential(
            (0): ReLU()
            (1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU(inplace=True)
            (3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (6): ResBlock(
          (conv): Sequential(
            (0): ReLU()
            (1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU(inplace=True)
            (3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (7): ReLU(inplace=True)
      )


In [29]:
# Top-level PixelSNAIL prior over a 16x16 grid of codes, conditioned on phase and tool labels.
top_prior = PixelSNAIL(
    shape=[16, 16],
    n_class=top_model_conf['N_CLASS'],
    channel=top_model_conf['CHANNELS'],
    kernel_size=5,
    n_block=top_model_conf['N_BLOCKS'],
    # NOTE(review): the bottom prior reads N_RES_BLOCKS here. N_BLOCKS is kept because
    # the saved output shows the checkpoint loading (strict) with it -- confirm
    # against the training script before changing.
    n_res_block=top_model_conf['N_BLOCKS'],
    res_channel=top_model_conf['RES_CHANNELS'],
    dropout=top_model_conf['DROPOUT'],
    n_out_res_block=top_model_conf['N_OUT_RES_BLOCKS'],
    n_phase_labels=CATARACTSDataset.num_phases_classes,
    n_tool_labels=CATARACTSDataset.num_tool_classes,
    label_cond_ch=top_model_conf['LABEL_COND_CH']
).to(DEV)
top_prior = torch.nn.DataParallel(top_prior, device_ids=['cuda:0']) if DEV != 'cpu' else WrappedModel(top_prior)
# Read the checkpoint from disk once; element 0 holds the model state dict.
top_state = torch.load(TOP_PATH + "ckpt.pth", map_location='cpu')[0]
try:
    # Keys saved from the bare model (no "module." prefix).
    top_prior.module.load_state_dict(top_state)
except Exception:
    # Presumably keys saved from the wrapper itself (with the "module." prefix).
    top_prior.load_state_dict(top_state)
top_prior.eval()

DataParallel(
  (module): PixelSNAIL(
    (horizontal): CausalConv2d(
      (pad): ZeroPad2d((2, 2, 1, 0))
      (conv): WNConv2d(
        (conv): Conv2d(512, 256, kernel_size=(2, 5), stride=(1, 1))
      )
    )
    (vertical): CausalConv2d(
      (pad): ZeroPad2d((1, 0, 2, 0))
      (conv): WNConv2d(
        (conv): Conv2d(512, 256, kernel_size=(3, 2), stride=(1, 1))
      )
    )
    (phase_label_emb): Sequential(
      (0): Embedding(19, 128)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (phase_label_conv): Sequential(
      (0): ConvTranspose2d(4, 2, kernel_size=(2, 2), stride=(2, 2))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (tool_label_emb): Sequential(
      (0): Linear(in_features=21, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Linear(in_features=128, out_features=256,

In [30]:
# Bottom-level PixelSNAIL prior over a 32x32 grid of codes (no attention), conditioned on
# the sampled top codes (via `condition=` in `sample_model`) plus phase and tool labels.
bottom_prior = PixelSNAIL(
    shape=[32, 32],
    n_class=bottom_model_conf['N_CLASS'],
    channel=bottom_model_conf['CHANNELS'],
    kernel_size=5,
    n_block=bottom_model_conf['N_BLOCKS'],
    n_res_block=bottom_model_conf['N_RES_BLOCKS'],
    res_channel=bottom_model_conf['RES_CHANNELS'],
    attention=False,
    dropout=bottom_model_conf['DROPOUT'],
    n_cond_res_block=bottom_model_conf['N_COND_RES_BLOCKS'],
    cond_res_channel=bottom_model_conf['RES_CHANNELS'],
    n_phase_labels=CATARACTSDataset.num_phases_classes,
    n_tool_labels=CATARACTSDataset.num_tool_classes,
    label_cond_ch=bottom_model_conf['LABEL_COND_CH']
).to(DEV)
bottom_prior = torch.nn.DataParallel(bottom_prior, device_ids=['cuda:0']) if DEV != 'cpu' else WrappedModel(bottom_prior)
# Read the checkpoint from disk once; element 0 holds the model state dict.
bottom_state = torch.load(BOTOM_PATH + "ckpt.pth", map_location='cpu')[0]
try:
    # Keys saved from the bare model (no "module." prefix).
    bottom_prior.module.load_state_dict(bottom_state)
except Exception:
    # Presumably keys saved from the wrapper itself (with the "module." prefix).
    bottom_prior.load_state_dict(bottom_state)
bottom_prior.eval()

DataParallel(
  (module): PixelSNAIL(
    (horizontal): CausalConv2d(
      (pad): ZeroPad2d((2, 2, 1, 0))
      (conv): WNConv2d(
        (conv): Conv2d(512, 256, kernel_size=(2, 5), stride=(1, 1))
      )
    )
    (vertical): CausalConv2d(
      (pad): ZeroPad2d((1, 0, 2, 0))
      (conv): WNConv2d(
        (conv): Conv2d(512, 256, kernel_size=(3, 2), stride=(1, 1))
      )
    )
    (cond_resnet): CondResNet(
      (blocks): Sequential(
        (0): WNConv2d(
          (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1): GatedResBlock(
          (activation): ELU(alpha=1.0)
          (conv1): WNConv2d(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (dropout): Dropout(p=0.1, inplace=False)
          (conv2): WNConv2d(
            (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (gate): GLU(dim=1)
        )
        (2): GatedRes

## Sampling with a fixed condition (phase + toolset)

In [31]:
from PIL import Image
import matplotlib.pyplot as plt
from lib.utils.misc import label_names_to_vectors

In [32]:
def _condition_tag(phase_label: Optional[torch.Tensor],
                   tool_label: Optional[torch.Tensor]) -> str:
    """Return a file-name prefix that identifies the conditioning labels.

    The phase part lists the values of ``phase_label`` (presumably a class
    index -- TODO confirm). The tool part lists the positions where
    ``tool_label`` is non-zero (assumes a multi-hot tool vector -- TODO
    confirm). A ``None`` tool label is encoded explicitly as ``notools``.
    """
    parts = []
    if phase_label is not None:
        phase_ids = phase_label.detach().flatten().long().tolist()
        parts.append("phase" + "-".join(str(p) for p in phase_ids))
    if tool_label is None:
        parts.append("notools")
    else:
        tool_ids = torch.nonzero(tool_label.detach().flatten()).flatten().tolist()
        parts.append("tools" + ("-".join(str(t) for t in tool_ids) or "none"))
    return "_".join(parts)


def sampling(phase_label: Optional[torch.Tensor],
             tool_label: Optional[torch.Tensor],
             steps: int,
             temp: float = 1.0,
             target_path: str = TARGET_PATH,
             file_prefix: Optional[str] = None):
    """Sample from the top/bottom priors, decode with the VQ-VAE and save PNGs.

    Args:
        phase_label: phase conditioning forwarded to both priors; moved to ``DEV``
            so callers may pass CPU tensors.
        tool_label: tool conditioning forwarded to both priors (moved to ``DEV``),
            or ``None`` for phase-only sampling.
        steps: number of sampling rounds; each round samples a batch of ``BATCH_SIZE``.
        temp: sampling temperature forwarded to ``sample_model``.
        target_path: output directory; created if it does not exist.
        file_prefix: file-name prefix for the PNGs. Defaults to a tag derived from
            the labels, so different conditions never overwrite each other.

    Files are written as ``{file_prefix}_sample{k}.png``, where ``k`` counts images
    across all rounds. Each image is resized to ``TARGET_SHAPE[1:]`` (H, W).
    """
    if file_prefix is None:
        file_prefix = _condition_tag(phase_label, tool_label)
    os.makedirs(target_path, exist_ok=True)
    if phase_label is not None:
        phase_label = phase_label.to(DEV)
    if tool_label is not None:
        tool_label = tool_label.to(DEV)

    with torch.no_grad():
        for step in tqdm(range(steps)):
            top_sample = sample_model(top_prior, DEV, BATCH_SIZE, [16, 16], temp,
                                      phase_label=phase_label,
                                      tool_label=tool_label)
            # The bottom prior is conditioned on the freshly sampled top codes.
            bottom_sample = sample_model(bottom_prior, DEV, BATCH_SIZE, [32, 32], temp,
                                         phase_label=phase_label,
                                         tool_label=tool_label,
                                         condition=top_sample)

            decoded = vqvae.module.decode_code(top_sample, bottom_sample).clamp(-1, 1)
            images = (decoded + 1.) * .5  # [-1, 1] -> [0, 1]
            images = F.interpolate(images, size=TARGET_SHAPE[1:], mode='bilinear')
            images = (images * 255.).type(torch.uint8)
            images = images.permute(0, 2, 3, 1).cpu().numpy()  # NCHW -> NHWC for PIL

            # Save every image of the batch (the old squeeze(0) only handled batch size 1).
            for b, img in enumerate(images):
                idx = step * len(images) + b
                Image.fromarray(img).save(
                    os.path.join(target_path, f"{file_prefix}_sample{idx}.png"))

In [33]:
# Condition: Nucleus Breaking with phaco handpiece + Bonn forceps.
# The names `phase` / `toolset` are kept because `sampling` may read them to name files.
phase = 'Nucleus Breaking'
toolset = ['Phacoemulsifier Handpiece', 'Bonn Forceps']
phase_vec, tool_vec = label_names_to_vectors(test_ds, phase, toolset)
phase_label = phase_vec.long().to(DEV)
tool_label = tool_vec.float().to(DEV)
sampling(phase_label, tool_label, steps=10)

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/16 [00:00<?, ?it/s][A
  6%|▋         | 1/16 [00:00<00:08,  1.67it/s][A
 12%|█▎        | 2/16 [00:01<00:09,  1.49it/s][A
 19%|█▉        | 3/16 [00:01<00:08,  1.57it/s][A
 25%|██▌       | 4/16 [00:02<00:07,  1.64it/s][A
 31%|███▏      | 5/16 [00:03<00:06,  1.66it/s][A
 38%|███▊      | 6/16 [00:03<00:05,  1.68it/s][A
 44%|████▍     | 7/16 [00:04<00:05,  1.66it/s][A
 50%|█████     | 8/16 [00:04<00:04,  1.65it/s][A
 56%|█████▋    | 9/16 [00:05<00:04,  1.64it/s][A
 62%|██████▎   | 10/16 [00:06<00:03,  1.61it/s][A
 69%|██████▉   | 11/16 [00:06<00:03,  1.55it/s][A
 75%|███████▌  | 12/16 [00:07<00:02,  1.52it/s][A
 81%|████████▏ | 13/16 [00:08<00:02,  1.44it/s][A
 88%|████████▊ | 14/16 [00:09<00:01,  1.38it/s][A
 94%|█████████▍| 15/16 [00:09<00:00,  1.30it/s][A
100%|██████████| 16/16 [00:10<00:00,  1.47it/s][A

  0%|          | 0/32 [00:00<?, ?it/s][A
  3%|▎         | 1/32 [00:01<00:36,  1.19s/it][A
  6%|▋         | 2/

In [34]:
# Condition: Implant Ejection with capsulorhexis forceps + Bonn forceps.
# The names `phase` / `toolset` are kept because `sampling` may read them to name files.
phase = 'Implant Ejection'
toolset = ['Capsulorhexis Forceps', 'Bonn Forceps']
phase_vec, tool_vec = label_names_to_vectors(test_ds, phase, toolset)
phase_label = phase_vec.long().to(DEV)
tool_label = tool_vec.float().to(DEV)
sampling(phase_label, tool_label, steps=10)

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/16 [00:00<?, ?it/s][A
  6%|▋         | 1/16 [00:00<00:08,  1.81it/s][A
 12%|█▎        | 2/16 [00:01<00:07,  1.83it/s][A
 19%|█▉        | 3/16 [00:01<00:07,  1.84it/s][A
 25%|██▌       | 4/16 [00:02<00:06,  1.84it/s][A
 31%|███▏      | 5/16 [00:02<00:06,  1.83it/s][A
 38%|███▊      | 6/16 [00:03<00:05,  1.83it/s][A
 44%|████▍     | 7/16 [00:03<00:05,  1.79it/s][A
 50%|█████     | 8/16 [00:04<00:04,  1.77it/s][A
 56%|█████▋    | 9/16 [00:05<00:03,  1.75it/s][A
 62%|██████▎   | 10/16 [00:05<00:03,  1.75it/s][A
 69%|██████▉   | 11/16 [00:06<00:02,  1.69it/s][A
 75%|███████▌  | 12/16 [00:06<00:02,  1.65it/s][A
 81%|████████▏ | 13/16 [00:07<00:01,  1.54it/s][A
 88%|████████▊ | 14/16 [00:08<00:01,  1.47it/s][A
 94%|█████████▍| 15/16 [00:09<00:00,  1.38it/s][A
100%|██████████| 16/16 [00:10<00:00,  1.60it/s][A

  0%|          | 0/32 [00:00<?, ?it/s][A
  3%|▎         | 1/32 [00:00<00:26,  1.18it/s][A
  6%|▋         | 2/

In [35]:
# Condition: Suturing with Vannas scissors + needle holder.
# The names `phase` / `toolset` are kept because `sampling` may read them to name files.
phase = 'Suturing'
toolset = ['Vannas Scissors', 'Needle Holder']
phase_vec, tool_vec = label_names_to_vectors(test_ds, phase, toolset)
phase_label = phase_vec.long().to(DEV)
tool_label = tool_vec.float().to(DEV)
sampling(phase_label, tool_label, steps=10)

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/16 [00:00<?, ?it/s][A
  6%|▋         | 1/16 [00:00<00:08,  1.68it/s][A
 12%|█▎        | 2/16 [00:01<00:08,  1.70it/s][A
 19%|█▉        | 3/16 [00:01<00:07,  1.71it/s][A
 25%|██▌       | 4/16 [00:02<00:06,  1.72it/s][A
 31%|███▏      | 5/16 [00:02<00:06,  1.71it/s][A
 38%|███▊      | 6/16 [00:03<00:05,  1.69it/s][A
 44%|████▍     | 7/16 [00:04<00:05,  1.67it/s][A
 50%|█████     | 8/16 [00:04<00:04,  1.65it/s][A
 56%|█████▋    | 9/16 [00:05<00:04,  1.63it/s][A
 62%|██████▎   | 10/16 [00:06<00:03,  1.62it/s][A
 69%|██████▉   | 11/16 [00:06<00:03,  1.56it/s][A
 75%|███████▌  | 12/16 [00:07<00:02,  1.52it/s][A
 81%|████████▏ | 13/16 [00:08<00:02,  1.42it/s][A
 88%|████████▊ | 14/16 [00:09<00:01,  1.37it/s][A
 94%|█████████▍| 15/16 [00:09<00:00,  1.27it/s][A
100%|██████████| 16/16 [00:10<00:00,  1.48it/s][A

  0%|          | 0/32 [00:00<?, ?it/s][A
  3%|▎         | 1/32 [00:00<00:28,  1.10it/s][A
  6%|▋         | 2/

## Sample from every phase (phase-only conditioning, no tool labels)

In [None]:
# Release cached memory left over from the fixed-condition runs above.
gc.collect()
torch.cuda.empty_cache()

# NOTE(review): the loop used to start at 10 despite the "every phase" header --
# presumably resuming an interrupted run. Set FIRST_PHASE = 0 to cover all phases.
FIRST_PHASE = 10
# NOTE(review): this nests "results/..." inside TARGET_PATH; confirm that is intended.
PHASE_SAMPLES_PATH = os.path.join(TARGET_PATH, "results/vqvae2_qual_samples2/")
for yp in range(FIRST_PHASE, test_ds.num_phases_classes):
    # `sampling` may name files after the globals `phase` / `toolset`; set them per
    # phase so that samples of different phases do not overwrite each other.
    phase, toolset = f"phase{yp}", "notools"
    sampling(phase_label=torch.tensor([[yp]], device=DEV),  # shape (1, 1), on the model device
             tool_label=None,
             steps=10,
             target_path=PHASE_SAMPLES_PATH,
             temp=1.0)