In [1]:
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from lib.model.cgan import Generator
from lib.dataset.cataracts_dataset import CATARACTSDataset
from lib.utils.pre_train import get_configs

In [2]:
DATA_PATH = '/media/yannik/samsung_data_ssd/data/CATARACTS-videos-processed/'
LOG_PATH = 'results/cgan/2023.02.26 20_57_53/'
TARGET_PATH = os.path.join(LOG_PATH, "eval/qual_samples/")
DEV = 'cuda'
data_conf, model_conf, diffusion_conf, train_conf = get_configs(LOG_PATH + "config.yaml")
os.makedirs(TARGET_PATH, exist_ok=True)
BATCH_SIZE = 1
STEPS = 1  # 30000//BATCH_SIZE
TARGET_SHAPE = (3, 270, 480)
print(f"Avail. GPUs: ", torch.cuda.device_count())

Avail. GPUs:  1


In [4]:
test_ds = CATARACTSDataset(
    root=DATA_PATH,
    resize_shape=eval(data_conf['SHAPE'])[1:],
    normalize=eval(data_conf['NORM']),
    mode='test',
    frame_step=data_conf['FRAME_STEP'],
    sample_img=False
)
# TODO: Weighted sampling / sampling from p(toolset|phase)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, num_workers=8,
                     drop_last=True, shuffle=True, pin_memory=False)
print(f"{len(test_ds)} test samples")

35492 test samples


In [5]:
netG = Generator(
    label_embedding_dim=model_conf['EMBED_DIM'],
    latent_dim=model_conf['LATENT_DIM'],
    base_hidden_dim=model_conf['BASE_HIDDEN_DIM'],
    n_phase_classes=test_ds.num_phases_classes,
    n_tool_dims=test_ds.num_tool_classes
).to(DEV)
m = torch.nn.DataParallel(netG, device_ids=['cuda:0']) if not DEV == 'cpu' else netG
netG.load_state_dict(torch.load(LOG_PATH + "ckpt.pth", map_location='cpu')[0])
netG.eval()

RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "latent.1.weight", "latent.1.bias", "latent.1.running_mean", "latent.1.running_var", "model.5.weight", "model.5.bias", "model.5.running_mean", "model.5.running_var", "model.8.weight", "model.9.bias", "model.9.running_mean", "model.9.running_var", "model.13.weight", "model.13.bias", "model.13.running_mean", "model.13.running_var", "model.15.weight". 
	Unexpected key(s) in state_dict: "model.3.weight", "model.4.bias", "model.4.running_mean", "model.4.running_var", "model.4.num_batches_tracked", "model.6.weight", "model.7.weight", "model.7.bias", "model.7.running_mean", "model.7.running_var", "model.7.num_batches_tracked", "model.10.weight", "model.10.bias", "model.10.running_mean", "model.10.running_var", "model.10.num_batches_tracked". 
	size mismatch for phase_label_condition_generator.3.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([1024, 512]).
	size mismatch for phase_label_condition_generator.3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for tool_label_condition_generator.2.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([1024, 512]).
	size mismatch for tool_label_condition_generator.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for latent.0.weight: copying a param with shape torch.Size([8192, 512]) from checkpoint, the shape in current model is torch.Size([16384, 512]).
	size mismatch for latent.0.bias: copying a param with shape torch.Size([8192]) from checkpoint, the shape in current model is torch.Size([16384]).
	size mismatch for model.0.weight: copying a param with shape torch.Size([544, 1024, 4, 4]) from checkpoint, the shape in current model is torch.Size([1152, 1024, 4, 4]).
	size mismatch for model.4.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024, 512, 4, 4]).
	size mismatch for model.9.weight: copying a param with shape torch.Size([256, 128, 4, 4]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for model.12.weight: copying a param with shape torch.Size([128, 3, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 128, 4, 4]).

## Fixed condition

In [9]:
from PIL import Image
import matplotlib.pyplot as plt
from lib.utils.misc import label_names_to_vectors

In [10]:
def sampling(phase_label: torch.Tensor, tool_label: torch.Tensor, steps: int):
    for i, (_, _, _, yp, ys) in enumerate(tqdm(test_dl)):
        with torch.no_grad():

            if i == steps:
                break

            eval_noise = torch.randn(BATCH_SIZE, model_conf['LATENT_DIM'], device=DEV)

            gen_sample = netG(eval_noise, phase_label, tool_label)
            N, C, H, W = gen_sample.shape

            gen = (gen_sample + 1.) * .5

            gen = F.interpolate(gen, size=TARGET_SHAPE[1:], mode='bilinear')
            gen = (gen * 255.).type(torch.uint8).squeeze(0)
            gen = gen.permute(1, 2, 0).cpu().numpy()

            im = Image.fromarray(gen)
            im.save(os.path.join(TARGET_PATH + f"{phase}_{toolset}_sample{i}.png"))

In [11]:
phase = 'Nucleus Breaking'
toolset = ['Phacoemulsifier Handpiece', 'Bonn Forceps']
phase_label, tool_label = label_names_to_vectors(phase, toolset, test_ds)
phase_label, tool_label = phase_label.long().to(DEV), tool_label.float().to(DEV)
sampling(phase_label, tool_label, steps=10)

  0%|          | 10/35492 [00:01<1:10:57,  8.33it/s]


In [12]:
phase = 'Implant Ejection'
toolset = ['Capsulorhexis Forceps', 'Bonn Forceps']
phase_label, tool_label = label_names_to_vectors(phase, toolset, test_ds)
phase_label, tool_label = phase_label.long().to(DEV), tool_label.float().to(DEV)
sampling(phase_label, tool_label, steps=10)

  0%|          | 10/35492 [00:00<36:19, 16.28it/s] 


In [13]:
phase = 'Suturing'
toolset = ['Vannas Scissors', 'Needle Holder']
phase_label, tool_label = label_names_to_vectors(phase, toolset, test_ds)
phase_label, tool_label = phase_label.long().to(DEV), tool_label.float().to(DEV)
sampling(phase_label, tool_label, steps=10)

  0%|          | 10/35492 [00:00<39:26, 14.99it/s] 


## Condition sampled from test-set

In [None]:
from lib.utils.misc import label_vectors_to_names

TARGET_PATH = os.path.join(LOG_PATH, "eval/gen_samples/")

for i, (_, _, file_name, phase_label, tool_label) in enumerate(tqdm(test_dl)):

    phase, toolset = label_vectors_to_names(test_ds, phase_label[0], tool_label[0])
    phase = phase.replace("/", " ")
    toolset = [tool.replace("/", " ") for tool in toolset]

    phase_label = phase_label.long().to(DEV)
    tool_label = tool_label.float().to(DEV)

    eval_noise = torch.randn(BATCH_SIZE, model_conf['LATENT_DIM'], device=DEV)

    gen_img = netG(eval_noise, phase_label, tool_label)
    gen_img = (gen_img + 1.) * .5
    gen = F.interpolate(gen_img, size=TARGET_SHAPE[1:], mode='bilinear')
    gen_img = (gen_img * 255.).type(torch.uint8).squeeze(0)
    gen_img = gen_img.permute(1, 2, 0).cpu().numpy()

    gen_img = Image.fromarray(gen_img)
    gen_img.save(os.path.join(TARGET_PATH + f"{phase}_{toolset}_sample{i}.png"))