In [None]:
# Set environment as current working directory
import sys
sys.path.append('..')

from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt

import torch
from torch.utils.data import ConcatDataset
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint


from datasets.data_module import SimpleDataModule
from datasets import LIDCDataset
from vqgan.model import VQVAE, VQGAN, VAE, VAEGAN

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import os
import numpy as np
from tqdm import tqdm

In [None]:
# NOTE: Run this notebook three times, changing the settings below each time:
#   1. USE_2D = False                           -> writes CT.npy per case
#   2. USE_2D = True, PROJECTION_PLANE = 'lat'  -> writes lat.npy per case
#   3. USE_2D = True, PROJECTION_PLANE = 'ap'   -> writes ap.npy per case

USE_2D = False  # True: 2D projections, False: 3D volumes
PROJECTION_PLANE = 'lat'  # 'ap' or 'lat'; only used when USE_2D is True
PATH_TO_PREPROCESSED_DATA = ''  # Replace with the folder containing the preprocessed LIDC-IDRI dataset (i.e., <PATH_TO_PREPROCESSED_DATA>)
BEST_VQ_GAN_CKPT_2D = ''  # Replace with the best VQ-GAN checkpoint for the 2D model
BEST_VQ_GAN_CKPT_3D = ''  # Replace with the best VQ-GAN checkpoint for the 3D model
STORAGE_DIR = ''  # Replace with the desired path for storing the indices (e.g. /data/lidc_indices/)

# Fail fast on typos: these values are branched on / forwarded to the dataset in later cells.
if not isinstance(USE_2D, bool):
    raise TypeError(f'USE_2D must be True or False, got {USE_2D!r}')
if PROJECTION_PLANE not in ('ap', 'lat'):
    raise ValueError(f"PROJECTION_PLANE must be 'ap' or 'lat', got {PROJECTION_PLANE!r}")

In [None]:
def create_dir(dir_path):
    """Create ``dir_path`` (including any missing parents) if it does not exist yet.

    Safe to call repeatedly. A single ``makedirs(exist_ok=True)`` call avoids the
    race of a separate ``exists()`` check followed by ``makedirs()``.

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory (the old
            check silently accepted this, deferring the failure to a later write).
        FileNotFoundError: if ``dir_path`` is an empty string.
    """
    os.makedirs(dir_path, exist_ok=True)

In [None]:
# Device list: GPU 0 when CUDA is available, otherwise None (CPU).
# NOTE(review): not referenced by any later cell in this notebook.
if torch.cuda.is_available():
    gpus = [0]
else:
    gpus = None

In [None]:
# Build the train/val/test datasets and wrap them in a data module.
# 2D mode requests projections for PROJECTION_PLANE; 3D mode uses the dataset's
# defaults (no projection arguments).
dataset_kwargs = dict(root_dir=PATH_TO_PREPROCESSED_DATA, augmentation=False)
if USE_2D:
    dataset_kwargs.update(projection=True, projection_plane=PROJECTION_PLANE)
    # NOTE(review): reason for 1 worker in 2D vs 30 in 3D is undocumented — confirm
    num_workers = 1
else:
    num_workers = 30

lidc_dataset_train = LIDCDataset(split='train', **dataset_kwargs)
lidc_dataset_val = LIDCDataset(split='val', **dataset_kwargs)
lidc_dataset_test = LIDCDataset(split='test', **dataset_kwargs)

# batch_size=1 matters: the export loop names output folders from sample['file_name'][0].
dm = SimpleDataModule(
    ds_train=lidc_dataset_train,
    ds_val=lidc_dataset_val,
    ds_test=lidc_dataset_test,
    batch_size=1,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
# Hyper-parameters common to the 2D and 3D VQ-GAN variants.
# NOTE(review): presumably these must match how each checkpoint was trained — confirm.
shared_vqgan_kwargs = dict(
    in_channels=1,
    out_channels=1,
    num_embeddings=8192,
    kernel_sizes=[3, 3, 3, 3],
    strides=[1, 2, 2, 2],
    embedding_loss_weight=1,
    beta=1,
    pixel_loss=torch.nn.L1Loss,
    use_attention='none',
)

if USE_2D:
    model = VQGAN(
        **shared_vqgan_kwargs,
        emb_channels=512,
        spatial_dims=2,
        hid_chs=[64, 128, 256, 512],
        deep_supervision=1,
        sample_every_n_steps=50,
    )
    checkpoint_path = BEST_VQ_GAN_CKPT_2D
else:
    model = VQGAN(
        **shared_vqgan_kwargs,
        emb_channels=256,
        spatial_dims=3,
        hid_chs=[32, 64, 128, 256],
        deep_supervision=0,
        norm_name=("GROUP", {'num_groups': 4, "affine": True}),
        sample_every_n_steps=200,
    )
    checkpoint_path = BEST_VQ_GAN_CKPT_3D

model.load_pretrained(checkpoint_path)
model.eval()

In [None]:
# Take a single batch from the test split for a quick visual sanity check.
test_loader = dm.test_dataloader()
test_sample = next(iter(test_loader))

In [None]:
# Index along the first spatial axis used to display 3D volumes; unused in 2D mode.
SLICE_NUM = 60

# [0][0] picks the first item of the batch and its first channel
# (assumes source is (batch, channel, ...) — TODO confirm).
source_image = test_sample['source'][0][0]
if not USE_2D:
    source_image = source_image[SLICE_NUM]
plt.imshow(source_image, cmap='gray')
plt.axis('off');

In [None]:
# Reconstruct the test sample with the full model. Inference only, so disable
# autograd to avoid building (and holding memory for) a gradient graph.
with torch.no_grad():
    out_sample = model(test_sample['source'])

In [None]:
# Show the model's reconstruction of the same sample/slice as the input plot.
reconstruction = out_sample[0][0][0]
if not USE_2D:
    reconstruction = reconstruction[SLICE_NUM]
plt.imshow(reconstruction.detach().cpu(), cmap='gray')
plt.axis('off');

In [None]:
# Encode the test sample into discrete indices; embedding_shape is required by
# decode_from_indices later. No gradients are needed for encoding.
with torch.no_grad():
    indices, embedding_shape = model.vqvae.encode_to_indices(test_sample['source'])

In [None]:
# Shape of the index tensor produced for one test sample.
print(indices.size())

In [None]:
# Shape information returned alongside the indices by encode_to_indices.
print(str(embedding_shape))

In [None]:
# Decode the indices back to an image to check that they alone suffice to
# reconstruct the sample. Inference only, so autograd is disabled.
with torch.no_grad():
    out_sample_2 = model.vqvae.decode_from_indices(indices, embedding_shape)

decoded_image = out_sample_2[0][0]
if not USE_2D:
    decoded_image = decoded_image[SLICE_NUM]
plt.imshow(decoded_image.detach().cpu(), cmap='gray')
plt.axis('off');

# Convert all images to indices

In [None]:
# Encode every sample of every split and store its indices as
# <STORAGE_DIR>/<split>/<case folder>/<CT|ap|lat>.npy.
if not STORAGE_DIR:
    raise ValueError('STORAGE_DIR is empty; set it in the configuration cell before exporting indices.')

storage_dir = STORAGE_DIR
train_path = os.path.join(storage_dir, 'train')
val_path = os.path.join(storage_dir, 'val')
test_path = os.path.join(storage_dir, 'test')
split_loaders = [
    (train_path, dm.train_dataloader()),
    (val_path, dm.val_dataloader()),
    (test_path, dm.test_dataloader()),
]
index_file_name = f'{PROJECTION_PLANE}.npy' if USE_2D else 'CT.npy'

# Inference only: no autograd graph per sample across the whole dataset.
with torch.no_grad():
    for split_path, loader in split_loaders:
        create_dir(split_path)
        for sample in tqdm(loader, desc=os.path.basename(split_path)):
            # Local names so the `indices` / `embedding_shape` inspected in earlier cells are not clobbered.
            sample_indices, _ = model.vqvae.encode_to_indices(sample['source'])
            # Only item [0] is named: relies on the data module's batch_size=1.
            # NOTE(review): assumes a '/'-separated path whose second-to-last component identifies the case — confirm
            case_name = sample['file_name'][0].split('/')[-2]
            folder_path = os.path.join(split_path, case_name)
            create_dir(folder_path)
            np.save(os.path.join(folder_path, index_file_name), sample_indices.detach().cpu().numpy())