In [None]:
import os
import sys
import shutil
import argparse
import IPython 
from PIL import Image, ImageFont, ImageDraw
from enum import Enum
from typing import Callable, List, Optional, Tuple, Union
from functools import partial

import torch
import torch.nn as nn
import torchvision
from torchvision.datasets import VisionDataset
from torchvision.transforms import transforms
import numpy as np
import pandas as pd
import skimage
from scipy import sparse
import matplotlib.pyplot as plt 
import torchxrayvision as xrv


from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
import dinov2.distributed as distributed
from dinov2.eval.metrics import MetricAveraging, build_metric, build_segmentation_metrics
from dinov2.models.unet import UNet
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.datasets import NIHChestXray, MC, Shenzhen, SARSCoV2CT
from dinov2.data.datasets.medical_dataset import MedicalVisionDataset
from dinov2.data.loaders import make_data_loader
from dinov2.data.transforms import (make_segmentation_train_transforms, make_classification_eval_transform, make_segmentation_eval_transforms,
                                    make_classification_train_transform)
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import (is_padded_matrix, ModelWithIntermediateLayers, ModelWithNormalize, evaluate, extract_features, collate_fn_3d,
                               make_datasets)
from dinov2.eval.classification.utils import LinearClassifier, create_linear_input, setup_linear_classifiers, AllClassifiers
from dinov2.eval.metrics import build_segmentation_metrics
from dinov2.eval.segmentation.utils import LinearDecoder, setup_decoders, DINOV2Encoder
from dinov2.utils import show_image_from_tensor

In [None]:
# Evaluation configuration, shaped like the CLI namespace that setup_and_build_model expects.
# To switch datasets, edit DATASET_NAME / DATASET_ROOT below
# (e.g. DATASET_NAME = "NIHChestXray", DATASET_ROOT = f"{DATA_DIR}/NIH").
HOME_DIR = "/home/baharoon"
DATA_DIR = f"{HOME_DIR}/Data"
DATASET_NAME = "MC"
DATASET_ROOT = f"{DATA_DIR}/MC"

args = argparse.Namespace(
    backbone="dinov2",
    config_file="dinov2/configs/eval/vitl14_pretrain.yaml",
    pretrained_weights=f"{HOME_DIR}/models/dinov2_vitl14_pretrain.pth",
    output_dir="results/test",
    opts=[],
    train_dataset_str=f"{DATASET_NAME}:split=TRAIN:root={DATASET_ROOT}",
    val_dataset_str=f"{DATASET_NAME}:split=VAL:root={DATASET_ROOT}",
    test_dataset_str=f"{DATASET_NAME}:split=TEST:root={DATASET_ROOT}",
    nb_knn=[5, 20, 50, 100, 200],
    temperature=0.07,
    gather_on_cpu=False,
    batch_size=8,
    n_per_class_list=[-1],
    n_tries=1,
    ngpus=1,
    nodes=1,
    timeout=2800,
    partition="learnlab",
    use_volta32=False,
    comment="",
    exclude="",
)
model, autocast_dtype = setup_and_build_model(args)
# Zero-argument factory for a CUDA autocast context in the dtype chosen by the model setup.
# torch.autocast(device_type="cuda", ...) is the non-deprecated form of torch.cuda.amp.autocast.
autocast_ctx = partial(torch.autocast, device_type="cuda", enabled=True, dtype=autocast_dtype)

In [None]:
# Segmentation transform pairs (image, target). Only the eval pair is used in this notebook;
# the train pair is built but not referenced by the cells below.
train_image_transform, train_target_transform = make_segmentation_train_transforms()
eval_image_transform, eval_target_transform = make_segmentation_eval_transforms()

# Bicubic resize to 448x448, applied to the test image that is loaded without explicit transforms.
resize = transforms.Resize(
    (448, 448),
    interpolation=transforms.InterpolationMode.BICUBIC,
)

# The test split is loaded twice: once with the eval transforms (model input + target)
# and once with no transforms passed in.
eval_dataset_kwargs = dict(
    transform=eval_image_transform,
    target_transform=eval_target_transform,
)
test_transformed = make_dataset(dataset_str=args.test_dataset_str, **eval_dataset_kwargs)
test = make_dataset(dataset_str=args.test_dataset_str)

In [None]:
# Decoder hyper-parameters. These feed setup_decoders and must match the configuration
# of the checkpoint loaded afterwards, or the weights will not line up.
learning_rates = [0.01]
decoder_type = "linear"
resize_size = 448

# Dataset- and backbone-derived sizes.
num_of_classes = test.get_num_classes()
is_3d = test.is_3d()
embed_dim = model.embed_dim

decoders, optim_param_groups = setup_decoders(
    embed_dim, learning_rates, num_of_classes, decoder_type,
    is_3d=is_3d, image_size=resize_size,
)

In [None]:
# Trained decoder weights to visualise. This is a single checkpoint *file*, not a directory.
checkpoint_path = "/home/baharoon/dinov2/results/mcdinov2vitllinear/optimal/model_final.pth"
# Fail loudly rather than silently rendering masks from an untrained decoder.
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(f"Decoder checkpoint not found: {checkpoint_path}")
# fvcore's Checkpointer takes a *directory* as save_dir. This notebook only reads weights,
# so point it at the checkpoint's folder and disable writing.
checkpointer = Checkpointer(decoders, save_dir=os.path.dirname(checkpoint_path), save_to_disk=False)
# Load the decoder weights only (no other checkpointables are registered). Any extra saved
# entries come back in the returned dict; "iteration" defaults to -1 when absent.
start_iter = checkpointer.load(checkpoint_path, checkpointables=[]).get("iteration", -1) + 1

In [None]:
# Take the first (with a single learning rate, presumably the only) decoder from the container
# built by setup_decoders.
# NOTE(review): `.module` assumes the container is wrapped (e.g. DDP); confirm against setup_decoders.
decoder = next(iter(decoders.module.decoders_dict.values()))
# Backbone wrapper used to extract features. It reuses the autocast_ctx created alongside the model
# rather than redefining an identical one here.
feature_model = DINOV2Encoder(model, autocast_ctx=autocast_ctx)
# Directory the annotated prediction masks are written to.
output_dir = "/home/baharoon/dinov2/results/"

In [None]:
# Render each test prediction as a grayscale mask, stamp its metrics on it, and save it.

# Spread label indices over the full 0-255 range so the highest label index is pure white.
highlight_multipler = 255 // max(num_of_classes - 1, 1)
metric = build_segmentation_metrics(average_type=MetricAveraging.SEGMENTATION_METRICS, num_labels=num_of_classes).cuda()
# FreeMono is not installed on every machine; fall back to PIL's built-in font instead of crashing.
try:
    font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf", 15)
except OSError:
    font = ImageFont.load_default()
num_of_images = 1

# Inference only: put the decoder in eval mode so any dropout/normalization layers
# use their inference behaviour.
decoder.eval()

for image_index in range(num_of_images):
    image, target = test_transformed[image_index]
    # Add a batch dimension of 1 and move to the GPU.
    image = image.cuda(non_blocking=True).unsqueeze(0)
    target = target.cuda(non_blocking=True).unsqueeze(0)

    # Test image loaded without explicit transforms, resized to 448x448, first element kept.
    # Not used below; kept for ad-hoc inspection in later cells.
    untransformed_image = resize(test[image_index][0])[0].cuda()

    # Run backbone AND decoder without autograd. Nothing is trained here, and tracking
    # gradients through the decoder would only waste GPU memory.
    with torch.no_grad():
        features = feature_model(image)
        logits = decoder(features)
        logits = torch.nn.functional.interpolate(logits, size=resize_size, mode="bilinear", align_corners=False)
    prediction = logits.argmax(dim=1)

    results = metric(prediction, target)

    # Drop the batch dimension and scale label indices to intensities (max 255). Use uint8 so the
    # tensor converts directly to an 8-bit image; convert("L") guarantees grayscale.
    mask = (prediction[0] * highlight_multipler).to(torch.uint8).cpu()
    pil_image = torchvision.transforms.ToPILImage()(mask).convert("L")

    draw = ImageDraw.Draw(pil_image)
    result_meta = " ".join(f"{m}: {float(r):.3f}" for m, r in dict(results).items())
    draw.text((0, 0), result_meta, fill=255, font=font)

    # Image names may contain sub-directories; create them so save() does not fail.
    save_path = f"{output_dir}/{test_transformed.images[image_index]}"
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    pil_image.save(save_path)