In [1]:
import os
import random

import torch
import torch.nn as nn

import pandas as pd
import numpy as np

In [2]:
from src.models.mae import MaskedAutoencoderViT
from src.utils.misc import create_dataset

In [3]:
from monai import data
from monai import transforms

In [4]:
def init_seed(seed):
    """Seed every random number generator used here and make cuDNN deterministic.

    Args:
        seed: Integer seed handed to Python's ``random``, NumPy's global RNG and
            PyTorch (CPU generator plus the current and all CUDA devices).

    Returns:
        None. The function only mutates global RNG / backend state.
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current CUDA device, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # needed when several GPUs are in use

    # cuDNN: request deterministic kernels and turn the autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

In [5]:
# Fix all RNGs up front so the random crop and anything else stochastic below repeats across runs.
init_seed(42)

In [6]:
# Spatial size (one entry per spatial axis) of the patch fed to the model.
roi = [128, 128, 128]

# Dictionary-based preprocessing pipeline.
# Every step tolerates a missing "label" entry (allow_missing_keys=True), so the
# same pipeline works for image-only samples.
trans = transforms.Compose(
    [
        # Read the files referenced under each key, keeping their metadata.
        transforms.LoadImaged(
            keys=["image", "label"],
            image_only=False,
            allow_missing_keys=True,
        ),
        transforms.EnsureChannelFirstd(
            keys=["image", "label"],
            allow_missing_keys=True,
        ),
        # Bring every volume to a common RAS orientation.
        transforms.Orientationd(
            keys=["image", "label"],
            axcodes="RAS",
            allow_missing_keys=True,
        ),
        # Resample to a spacing of (1.0, 1.0, 1.0); labels use nearest-neighbour
        # interpolation so class ids are not blended.
        transforms.Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
            allow_missing_keys=True,
        ),
        # Clip intensities to [40-150, 40+150] and rescale to [0, 1].
        # Applied to the image ONLY: running this on "label" would clip/rescale
        # the class ids as if they were intensities and corrupt the label map.
        # NOTE(review): the numbers look like a window centred at 40 with width
        # 300 -- confirm the intended intensity window.
        transforms.ScaleIntensityRanged(
            keys=["image"],
            a_min=40-150,
            a_max=40+150,
            b_min=0.0,
            b_max=1.0,
            clip=True,
            allow_missing_keys=True,
        ),
        # Crop both entries to the foreground bounding box computed from the image.
        transforms.CropForegroundd(
            keys=["image", "label"],
            source_key="image",
            allow_smaller=False,
            allow_missing_keys=True,
        ),
        # Random fixed-size crop of `roi` voxels per axis ...
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=tuple(roi),
            random_center=True,
            random_size=False,
            allow_missing_keys=True,
        ),
        # ... then zero-pad (or centre-crop) so the result is exactly `roi`.
        transforms.ResizeWithPadOrCropd(
            keys=["image", "label"],
            spatial_size=tuple(roi),
            method='symmetric',
            mode='constant',
            value=0,
            allow_missing_keys=True,
        ),
        transforms.ToTensord(
            keys=["image", "label"],
            allow_missing_keys=True,
        ),
    ]
)

In [7]:
# --- Inference settings ---------------------------------------------------
batch_size = 1
test_csv_path = '/gpfs/data/denizlab/Users/hh2740/git_backups/MedSSL-3D/datasets/debug.csv'

# --- Build the list of test samples ---------------------------------------
# The CSV is expected to have an 'img_path' column; its values are handed to
# create_dataset together with None in the label slot (image-only samples).
df_test = pd.read_csv(test_csv_path)
img_test = list(df_test.loc[:, 'img_path'])
test_files = create_dataset(img_test, None)

In [8]:
# Dataset that applies the preprocessing pipeline `trans` to each entry of `test_files`.
test_ds = data.Dataset(data=test_files, transform=trans)

# Loader over the test dataset; no shuffle argument is passed.
loader_options = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=True,
)
test_loader = data.DataLoader(dataset=test_ds, **loader_options)

In [9]:
from src.utils.pos_embed import interpolate_pos_embed

In [10]:
# Ask CUDA to number devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [11]:
# Hyper-parameters grouped by the part of the network they configure.
# NOTE(review): mask_ratio is 0.50 here, while the checkpoint and animation
# file names used elsewhere in this notebook say "mask0.75" -- confirm intended.
mae_input_cfg = dict(
    input_size=128,
    patch_size=16,
    mask_ratio=0.50,
    in_chans=1,
    dropout_rate=0.,
    spatial_dims=3,
    patch_embed='conv',
    pos_embed='sincos',
)
mae_encoder_cfg = dict(
    encoder_depth=12,
    encoder_embed_dim=768,
    encoder_mlp_dim=3072,
    encoder_num_heads=12,
)
mae_decoder_cfg = dict(
    decoder_depth=8,
    decoder_embed_dim=768,
    decoder_mlp_dim=3072,
    decoder_num_heads=16,
)
mae_misc_cfg = dict(
    norm_pix_loss=False,
    use_bias=True,
    use_flash_attn=True,
)

# Instantiate the masked autoencoder and place it on the selected device.
model = MaskedAutoencoderViT(
    **mae_input_cfg,
    **mae_encoder_cfg,
    **mae_decoder_cfg,
    **mae_misc_cfg,
)
model = model.to(device)

In [12]:
model_path = '/gpfs/data/denizlab/Users/hh2740/git_backups/MedSSL-3D/model_saved/mae_full_lr1.6e-3_mask0.75_sincos_pflash_ep1600_v2_gpu4_s42.pt'

# Load the pretrained weights. Any tensor whose shape does not match the
# current model is skipped and the model's own (initialised) tensor is kept.
if model_path is not None:
    # NOTE(review): torch.load unpickles the file -- only open trusted checkpoints.
    loaded_state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
    current_model_dict = model.state_dict()

    # Prefer matching tensors BY NAME. Checkpoints written from an
    # nn.DataParallel / DistributedDataParallel wrapper prefix every key with
    # "module.", so that prefix is stripped before comparing.
    ckpt_by_name = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in loaded_state_dict.items()
    }
    if all(k in ckpt_by_name for k in current_model_dict):
        candidates = {k: ckpt_by_name[k] for k in current_model_dict}
    else:
        # Names do not line up: fall back to pairing tensors by position,
        # but say so, because a different key count or order silently
        # misassigns (or drops) weights in this mode.
        print(
            "WARNING: checkpoint keys do not match model keys; pairing tensors by position "
            f"(model: {len(current_model_dict)} tensors, checkpoint: {len(loaded_state_dict)} tensors)"
        )
        candidates = dict(zip(current_model_dict.keys(), loaded_state_dict.values()))

    new_state_dict = {}
    kept_from_init = []  # keys whose checkpoint tensor had the wrong shape
    for k, v in candidates.items():
        if v.size() == current_model_dict[k].size():
            new_state_dict[k] = v
        else:
            new_state_dict[k] = current_model_dict[k]
            kept_from_init.append(k)

    msg = model.load_state_dict(new_state_dict, strict=False)
    print(f"Load Pretrained Model: {msg}")
    # `msg` cannot reveal shape mismatches (the dict above is built from the
    # model's own keys), so report them explicitly.
    if kept_from_init:
        print(f"Tensors left at initialisation (shape mismatch): {kept_from_init}")

    # interpolate position embedding
    # NOTE(review): this runs after load_state_dict and receives a dict whose
    # tensors already match the model's shapes -- confirm the call still has
    # the intended effect.
    interpolate_pos_embed(model, new_state_dict)

Load Pretrained Model: <All keys matched successfully>


In [13]:
# Pull the first batch from the test loader; the rest of the notebook visualises this one batch.
batch_data = next(iter(test_loader))

In [14]:
# Single forward pass without gradient tracking, under float16 autocast.
# NOTE(review): model.eval() is never called in this notebook; dropout_rate is
# 0. in the constructor call, so this is presumably harmless -- confirm.
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):
    x = batch_data['image'].to(device)
    # The model returns the loss, the per-patch prediction and the patch mask.
    loss, y, mask = model(x)

In [15]:
# Sanity check: shapes of the per-patch prediction and of the patch mask.
y.shape, mask.shape

(torch.Size([1, 512, 4096]), torch.Size([1, 512]))

In [16]:
# Fold the per-patch prediction back into a volume and move channels last.
y_unpatch = model.unpatchify(y, x)
y_unpatch = torch.einsum('nchwd->nhwdc', y_unpatch).detach().cpu()

# Repeat each patch's mask flag along a new last axis so it can go through the
# same unpatchify call as `y`. The result is bound to a NEW name: rebinding
# `mask` itself would make this cell fail (or corrupt the mask) when re-run.
mask_expanded = mask.detach().unsqueeze(-1).repeat(1, 1, model.out_chans)

mask_unpatch = model.unpatchify(mask_expanded, x)  # 1 is removing, 0 is keeping
mask_unpatch = torch.einsum('nchwd->nhwdc', mask_unpatch).detach().cpu()

x_ori = torch.einsum('nchwd->nhwdc', x).detach().cpu()

# masked image: keep visible voxels, zero out the masked ones
im_masked = x_ori * (1 - mask_unpatch)

# MAE reconstruction pasted with visible patches
im_paste = im_masked + y_unpatch * mask_unpatch

In [17]:
# Sanity check: all five channel-last volumes should share one shape.
x_ori.shape, y_unpatch.shape, mask_unpatch.shape, im_masked.shape, im_paste.shape

(torch.Size([1, 128, 128, 128, 1]),
 torch.Size([1, 128, 128, 128, 1]),
 torch.Size([1, 128, 128, 128, 1]),
 torch.Size([1, 128, 128, 128, 1]),
 torch.Size([1, 128, 128, 128, 1]))

In [18]:
# Take the first sample of each volume, drop singleton axes and copy to NumPy
# for plotting.
ori, masked, recon, recon_vis = (
    np.array(volume[0].squeeze())
    for volume in (x_ori, im_masked, y_unpatch, im_paste)
)

In [19]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib.animation import FuncAnimation


def visualize_mae(ori, masked, recon, recon_vis, max_slices=64):
    """Animate four aligned 3-D volumes side by side, slice by slice.

    Each volume is indexed as ``volume[:, :, k]``; the number of slices is taken
    from ``ori.shape[2]`` and the other volumes are assumed to have the same shape.

    Args:
        ori: Original volume.
        masked: Volume with the masked patches zeroed out.
        recon: Model reconstruction.
        recon_vis: Reconstruction pasted together with the visible patches.
        max_slices: Upper bound on the number of animated frames. Slices are
            sub-sampled with a constant stride to stay within it. ``None`` (or
            a non-positive value) animates every slice.

    Returns:
        matplotlib.animation.FuncAnimation: the animation (200 ms per frame).
    """
    volumes = (ori, masked, recon, recon_vis)
    titles = ("original", "masked", "reconstruction", "reconstruction + visible")

    fig, axes = plt.subplots(1, 4, figsize=(16, 8))

    depth = ori.shape[2]
    if max_slices is None or max_slices <= 0:
        step = 1
    else:
        # Ceiling division so that at most `max_slices` frames are produced.
        step = max(1, -(-depth // max_slices))

    images = []
    labels = []
    for ax, title, volume in zip(axes, titles, volumes):
        ax.set_title(title)
        ax.axis('off')

        # Fix the colour limits from the whole volume: set_array() below does
        # not rescale them, so autoscaling on slice 0 alone would clip (or wash
        # out) every later slice.
        images.append(
            ax.imshow(
                volume[:, :, 0],
                animated=True,
                cmap="gray",
                vmin=float(volume.min()),
                vmax=float(volume.max()),
            )
        )

        # Slice counter near the bottom-right corner of the displayed slice:
        # x runs over columns (axis 1), y over rows (axis 0).
        labels.append(
            ax.text(
                volume.shape[1] - 10,
                volume.shape[0] - 10,
                f"0/{depth}",
                ha="right",
                va="bottom",
                color="white",
                fontsize=8,
                weight="bold",
            )
        )

    def update(frame):
        """Show slice `frame` in every panel and refresh the slice counters."""
        for image, label, volume in zip(images, labels, volumes):
            image.set_array(volume[:, :, frame])
            label.set_text(f"{frame}/{depth}")
        # With blit=True every artist that changed must be returned, including
        # the text labels, otherwise they are not redrawn.
        return images + labels

    anim = FuncAnimation(fig, update, frames=range(0, depth, step), interval=200, blit=True)

    return anim

In [1]:
# Build the four-panel animation from the NumPy volumes prepared above.
max_slices = 64
anim = visualize_mae(ori, masked, recon, recon_vis, max_slices=max_slices)

In [2]:
# Render the animation inline as an interactive JavaScript/HTML player.
HTML(anim.to_jshtml())

In [22]:
# Save the animation
# NOTE(review): the file name says "mask0.75" while the model above is built
# with mask_ratio=0.50 -- confirm the name reflects what is being shown.
animation_path = './animation/sample_mask0.75_pflash.mp4'
# anim.save() does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(animation_path), exist_ok=True)
anim.save(animation_path, writer='ffmpeg', fps=15, dpi=300)