In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import ViTMAEConfig, ViTMAEForPreTraining, ViTConfig, ViTModel

from src.datasets import UnetCustomDataset, unet_train_collate, unet_valid_collate
from src.plotters import visualize_predictions
from src.unetr_4x4 import CustomUNETR
from src.unetr_trainer import UNETR_TRAINER
from src.utils import select_device

In [None]:
# Hyperparameters of the original ViT-MAE run (copied from logs/info).
vitmaeconfig = dict(
    attention_probs_dropout_prob=0.0,
    decoder_hidden_size=192,
    decoder_intermediate_size=768,
    decoder_num_attention_heads=6,
    decoder_num_hidden_layers=6,
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    hidden_size=192,
    image_size=64,
    initializer_range=0.02,
    intermediate_size=768,
    layer_norm_eps=1e-05,
    mask_ratio=0.5,
    model_type="vit_mae",
    norm_pix_loss=1,
    num_attention_heads=6,
    num_channels=16,
    num_hidden_layers=6,
    patch_size=4,
    qkv_bias=True,
    transformers_version="4.42.3",
)

# ViT encoder configuration for UNETR: each listed setting is copied verbatim
# from the ViT-MAE config above, so the two can never drift apart.
vitconfig = {
    key: vitmaeconfig[key]
    for key in (
        "hidden_size",
        "num_hidden_layers",
        "num_attention_heads",
        "intermediate_size",
        "hidden_act",
        "hidden_dropout_prob",
        "attention_probs_dropout_prob",
        "initializer_range",
        "layer_norm_eps",
        "image_size",
        "patch_size",
        "num_channels",
        "qkv_bias",
    )
}
# The encoder stride is tied to the patch size rather than copied from a key
# of the same name.
vitconfig["encoder_stride"] = vitmaeconfig["patch_size"]

# --- Extract the pretrained ViT-MAE encoder ---
pretrained_model_path = "/home/mhill/Projects/cathepsin/logs/vitmae-grid/02/model.pth"
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on
# any machine (including CPU-only ones) and keeps the temporary checkpoint
# tensors off the accelerator; the model is moved to its device later.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(pretrained_model_path, map_location="cpu")
vitmae_model = ViTMAEForPreTraining(config=ViTMAEConfig(**vitmaeconfig))
vitmae_model.load_state_dict(checkpoint['model_state_dict'])

# --- Transfer the encoder weights into a plain ViT built from vitconfig ---
vitmae_encoder = vitmae_model.vit
vit = ViTModel(config=ViTConfig(**vitconfig))
# strict=False tolerates keys that exist on only one side. The call is left as
# the cell's final expression so the returned missing/unexpected key report is
# displayed -- inspect it to confirm the encoder weights were actually copied.
vit.load_state_dict(vitmae_encoder.state_dict(), strict=False)

In [None]:
import numpy as np

# --- Training data ---
# The .npz archive is read through its 'images' and 'labels' entries.
train_data = np.load("/home/mhill/Projects/cathepsin/data/unet_training_dataset.npz")
train_images = train_data['images']
train_labels = train_data['labels']
train_dataset = UnetCustomDataset(train_images, train_labels)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    # Reshuffle every epoch; without this the optimizer sees the exact same
    # batches in the exact same order on each pass over the data.
    shuffle=True,
    collate_fn=unet_train_collate
)

# --- Validation data ---
valid_data = np.load('/home/mhill/Projects/cathepsin/data/unet_validation_dataset.npz')
valid_images = valid_data['images']
valid_labels = valid_data['labels']
valid_dataset = UnetCustomDataset(valid_images, valid_labels)

# Validation keeps the default fixed order so results are comparable between
# runs and epochs.
valid_dataloader = DataLoader(
    dataset=valid_dataset,
    batch_size=16,
    collate_fn=unet_valid_collate
)

In [None]:
# Pick the compute device once and reuse it everywhere below.
device = select_device()
unet_model = CustomUNETR(encoder=vit, num_classes=16, feature_size=32).to(device)
# Target pixels labelled 255 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3)

# Hand the trainer the same device the model was moved to. A hard-coded 'cuda'
# here disagrees with select_device() whenever that helper picks anything else.
# NOTE(review): assumes UNETR_TRAINER accepts whatever select_device() returns
# (string or torch.device) -- confirm against the trainer implementation.
trainer = UNETR_TRAINER(model=unet_model,
                        optimizer=optimizer,
                        criterion=criterion,
                        device=device)

# The training loader is also passed as train_eval_batches.
model = trainer.fit(num_epochs=5,
                    train_batches=train_dataloader,
                    valid_batches=valid_dataloader,
                    train_eval_batches=train_dataloader)

In [None]:
# Qualitative check: run one validation batch through the trained model and
# plot predictions next to the ground truth.
#
# The forward pass below uses `model` (the object returned by trainer.fit), so
# that is the module that must be in eval mode. `unet_model` is switched as
# well: it keeps the original behaviour and is a no-op if both names refer to
# the same module.
unet_model.eval()
model.eval()
with torch.inference_mode():
    images, labels = next(iter(valid_dataloader))
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    # Take the highest-scoring entry along dim 1 of the model output.
    predictions = torch.argmax(outputs, dim=1)

    # Convert to numpy arrays for visualization
    images_np = images.cpu().numpy()
    labels_np = labels.cpu().numpy()
    predictions_np = predictions.cpu().numpy()
    print(
        f"Shape of image_np : {images_np.shape} | Label_np : {labels_np.shape} | Preidcitons : {predictions_np.shape}")
    # Visualize the predictions
    visualize_predictions(5, images_np, labels_np, predictions_np)