In [None]:
import os
import numpy as np
from tqdm import tqdm
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils as vutils
from gptnano.translator import VQGANTransformer
from utils import load_data, plot_images
import matplotlib.pyplot as plt
import imageio
import itertools
from contextlib import nullcontext

from metrics import Structural_Similarity
from scipy.ndimage.filters import gaussian_filter

## Load Models

In [None]:
class Args:
    """Static configuration for this notebook.

    An instance (`args`) is what gets passed to `VQGANTransformer`, `load_gpt` and
    `load_data` in the cells below.
    """

    # --- Checkpoints (empty by default; presumably read by the model loaders — fill in) ---
    checkpoint_path_3d_vqgan = ''  # Path to the best 3D VQ-GAN model checkpoint
    checkpoint_path_2d_vqgan = ''  # Path to the best 2D VQ-GAN model checkpoint
    checkpoint_path_gpt = ''  # Path to the best GPT model checkpoint (located under .../gpt_results/run_X/checkpoint/transformer_X_X.pt)

    # --- Transformer / tokenizer settings ---
    num_codebook_vectors = 8192
    sos_token = 0
    pkeep = 0.5
    # NOTE(review): presumably 4096 CT tokens plus 256 tokens per X-ray view (two views) — confirm.
    block_size = 4096 + 2 * 256
    # NOTE(review): presumably the two 256-token X-ray prefixes plus the SOS token — confirm.
    n_unmasked = 2 * 256 + 1

    # --- Runtime settings ---
    device = "cuda:0"
    batch_size = 1
    num_workers = 1

    # --- Training hyperparameters (presumably consumed by the project code reading `args`) ---
    epochs = 100
    learning_rate = 2.25e-05


args = Args()

In [None]:
# Build the VQ-GAN + GPT wrapper on the configured device and load the GPT weights.
model = VQGANTransformer(args).to(device=args.device)

model.load_gpt(args, strict=True)
# This notebook only samples from the model: switch to eval mode so dropout (and any other
# train-time-only behaviour) is disabled during generation.
model.eval()

## Sample Images

In [None]:
train_dataloader, val_dataloader, test_dataloader = load_data(args)

# Which test-set batch to visualise (0-based position in test_dataloader iteration order).
index = 192
# islice skips the first `index` batches. The `None` default of next() turns an out-of-range
# index into a descriptive error instead of a bare StopIteration.
data = next(itertools.islice(iter(test_dataloader), index, None), None)
if data is None:
    raise IndexError(f"index={index} is past the end of test_dataloader")

# Presumably codebook token indices for the CT volume and the AP / lateral views — confirm
# against load_data. Moved to the device the model lives on.
imgs_ct = data['indices_ct'].to(device=args.device)
imgs_ap = data['indices_ap'].to(device=args.device)
imgs_lat = data['indices_lat'].to(device=args.device)

# Original CT volume for visual comparison; `file_name` holds one path per batch item.
orig_ct = np.load(data['file_name'][0])

In [None]:
device = args.device

# Autocast is only applied on CUDA (see `ctx` below). Prefer bfloat16, but fall back to
# float16 on GPUs that cannot run bfloat16 autocast.
device_type = 'cuda' if 'cuda' in device else 'cpu'
if device_type == 'cuda' and not torch.cuda.is_bf16_supported():
    dtype = 'float16'
else:
    dtype = 'bfloat16'
ptdtype = {'float32': torch.float32,
           'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
    device_type=device_type, dtype=ptdtype)

# Sampling uses temperature / top-k, so seed torch's RNG (seeds CPU and CUDA generators) to
# make the figures below reproducible — assumes log_images draws from torch's RNG.
sample_seed = 1337
torch.manual_seed(sample_seed)

# Pure inference: no_grad avoids recording an autograd graph over the whole sampling loop.
# `[0][None]` takes the first item of the batch and re-adds a leading batch dimension.
with torch.no_grad(), ctx:
    log, sampled_imgs_ct, sampled_imgs_ap, sampled_imgs_lat = model.log_images(
        (imgs_ct[0][None], imgs_ap[0][None], imgs_lat[0][None]), temperature=1.0, top_k=100)

# Cast back from the autocast dtype: numpy (used by the plotting cells) cannot convert
# bfloat16 tensors.
sampled_imgs_ct = sampled_imgs_ct.float()
sampled_imgs_ap = sampled_imgs_ap.float()
sampled_imgs_lat = sampled_imgs_lat.float()

In [None]:

# Detach and move the samples to CPU once, rather than once per panel per slice.
recon_ct = sampled_imgs_ct.detach().cpu()
# Only step through depths that exist in both the original volume and the reconstructions,
# so a shallower reconstruction cannot raise IndexError part-way through.
num_slices = min(len(orig_ct), recon_ct.shape[2])

# NOTE(review): entries 2 / 1 / 0 along the first dimension are assumed to be the full, half
# and no-GPT reconstructions, as the titles state — confirm against model.log_images.
for slide in range(0, num_slices, 10):
    # Plot the results side by side for every 10th slice.
    images = {
        'CT Full reconstruction': recon_ct[2][0][slide],
        'CT Half reconstruction': recon_ct[1][0][slide],
        'Reconstructed (no GPT)': recon_ct[0][0][slide],
        'Original CT': orig_ct[slide],
    }
    fig, ax = plt.subplots(1, len(images), figsize=(20, 20))
    for i, (title, image) in enumerate(images.items()):
        ax[i].imshow(image, cmap='gray')
        ax[i].axis('off')
        ax[i].set_title(f"{title} (slice {slide})")
    plt.show()

In [None]:
# Plot the results side by side: mean-intensity projections of every volume along two axes
# (two panels per volume, vertically flipped for display).
recon_ct = sampled_imgs_ct.detach().cpu()
images = {
    'CT Full reconstruction': recon_ct[2][0],
    'CT Half reconstruction': recon_ct[1][0],
    'Reconstructed (no GPT)': recon_ct[0][0],
    'Original CT': orig_ct,
}
projection_axes = (1, 2)

fig, ax = plt.subplots(1, len(projection_axes) * len(images), figsize=(20, 20))
for col, (title, image) in enumerate(images.items()):
    volume = np.asarray(image)
    for offset, axis in enumerate(projection_axes):
        panel = ax[len(projection_axes) * col + offset]
        panel.imshow(np.flip(volume.mean(axis=axis), 0), cmap='gray')
        panel.axis('off')
        # Include the projection axis so the two panels of one volume can be told apart.
        panel.set_title(f"{title}\n(mean over axis {axis})")
plt.show()

In [None]:
# Plot the results side by side: one row per volume, one column per orthogonal plane through
# index SLICE_IDX (each plane vertically flipped for display).
SLICE_IDX = 60

recon_ct = sampled_imgs_ct.detach().cpu()
images = {
    'CT Full reconstruction': recon_ct[2][0],
    'CT Half reconstruction': recon_ct[1][0],
    'Reconstructed (no GPT)': recon_ct[0][0],
    'Original CT': orig_ct,
}
# Columns fix SLICE_IDX along volume axes 1, 2 and 0 respectively.
plane_axes = (1, 2, 0)

fig, ax = plt.subplots(len(images), len(plane_axes), figsize=(20, 20))
for row, (title, image) in enumerate(images.items()):
    volume = np.asarray(image)
    for col, axis in enumerate(plane_axes):
        # Same values as volume[:, SLICE_IDX, :] etc.; raises IndexError if SLICE_IDX is out of range.
        plane = np.take(volume, SLICE_IDX, axis=axis)
        ax[row][col].imshow(np.flip(plane, 0), cmap='gray')
        ax[row][col].axis('off')
        # Name the sliced axis so the three columns are distinguishable.
        ax[row][col].set_title(f"{title}\n(axis {axis} = {SLICE_IDX})")
plt.show()