In [None]:
import os
import sys
import shutil
import argparse
import math
import IPython 
from PIL import Image
from enum import Enum
from typing import Callable, List, Optional, Tuple, Union
from functools import partial

import torch
import torch.nn as nn
import torchvision
from torchvision.datasets import VisionDataset
from torchvision.transforms import transforms
import numpy as np
import pandas as pd
import skimage
from scipy import sparse
import matplotlib.pyplot as plt 
import torchxrayvision as xrv

from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
import dinov2.distributed as distributed
from dinov2.models.unet import UNet
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.datasets import NIHChestXray, MC, Shenzhen, SARSCoV2CT
from dinov2.data.datasets.medical_dataset import MedicalVisionDataset
from dinov2.data.loaders import make_data_loader
from dinov2.data.transforms import (make_segmentation_train_transforms, make_classification_eval_transform, make_segmentation_eval_transforms,
                                    make_classification_train_transform)
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import (is_zero_matrix, ModelWithIntermediateLayers, ModelWithNormalize, evaluate, extract_features, collate_fn_3d,
                               make_datasets)
from dinov2.eval.classification.utils import LinearClassifier, create_linear_input, setup_linear_classifiers, AllClassifiers
from dinov2.eval.metrics import build_segmentation_metrics
from dinov2.eval.segmentation.utils import LinearDecoder, setup_decoders
from dinov2.utils import show_image_from_tensor

In [None]:
# Evaluation settings, spelled out one per line (same values and order as before).
args = argparse.Namespace(
    config_file='dinov2/configs/eval/vits14_pretrain.yaml',
    pretrained_weights='models/dinov2_vits14_pretrain.pth',
    output_dir='results/NIH/dinov2_vits14/knn',
    opts=[],
    train_dataset_str='Shenzhen:split=TRAIN:root=/mnt/z/data/Shenzhen',
    val_dataset_str='Shenzhen:split=VAL:root=/mnt/z/data/Shenzhen',
    test_dataset_str='Shenzhen:split=TEST:root=/mnt/z/data/Shenzhen',
    nb_knn=[5, 20, 50, 100, 200],
    temperature=0.07,
    gather_on_cpu=False,
    batch_size=8,
    n_per_class_list=[-1],
    n_tries=1,
    ngpus=1,
    nodes=1,
    timeout=2800,
    partition='learnlab',
    use_volta32=False,
    comment='',
    exclude='',
)

# Build the backbone and wrap it so that it returns intermediate-layer features.
model, autocast_dtype = setup_and_build_model(args)
autocast_ctx = partial(
    torch.cuda.amp.autocast,
    enabled=True,
    dtype=autocast_dtype,
)
# NOTE(review): the positional `1` is presumably the number of last blocks to return — confirm.
feature_model = ModelWithIntermediateLayers(model, 1, autocast_ctx, is_3d=True)
# model = ModelWithNormalize(model)

In [None]:
# Copy the settings that later cells read out of `args` into notebook-level names.
train_dataset_str = args.train_dataset_str
val_dataset_str = args.val_dataset_str
batch_size = args.batch_size
gather_on_cpu = args.gather_on_cpu
# Fixed here rather than taken from `args`.
# NOTE(review): the data-loader cells below pass literal batch_size/num_workers
# values instead of these names.
num_workers = 1

In [None]:
class _Split(Enum):
    """Dataset partitions; each one knows its hard-coded number of samples."""

    TRAIN = "train"
    VAL = "val"
    TEST = "test"

    @property
    def length(self) -> int:
        """Number of samples in this split (90 / 50 / 70 for train / val / test)."""
        if self is _Split.TRAIN:
            return 90
        if self is _Split.VAL:
            return 50
        return 70

class SARSCoV2CT(MedicalVisionDataset):
    """SARS-CoV-2 CT classification dataset where one sample is a whole scan.

    Every entry of ``self.images`` names a directory below ``self._split_dir``
    that holds the 2D slices of one scan. ``__getitem__`` returns all slices of
    that scan together with a binary label.

    NOTE(review): ``self.images``, ``self._split_dir``, ``self._split`` and
    ``self.transform`` are not assigned in this class; they are presumably set
    by ``MedicalVisionDataset.__init__`` — confirm.
    """

    Split = _Split

    def __init__(
        self,
        *,
        split: "SARSCoV2CT.Split",
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Forward all arguments to the base class and register the class names."""
        super().__init__(split, root, transforms, transform, target_transform)

        self.class_names = [
            "Negative",
            "Positive"
        ]

    @property
    def split(self) -> "SARSCoV2CT.Split":
        """The split this dataset instance was created for."""
        return self._split

    def get_length(self) -> int:
        """Number of scans; same as ``len(self)``."""
        return self.__len__()

    def get_num_classes(self) -> int:
        """Binary task: negative / positive."""
        return 2

    def is_3d(self) -> bool:
        """Samples are stacks of slices."""
        return True

    def is_multilabel(self) -> bool:
        """Each scan has exactly one label."""
        return False

    def get_image_data(self, index: int) -> List[torch.Tensor]:
        """Load every slice of scan ``index`` as a float ``(C, H, W)`` tensor.

        Slices are ordered numerically when all file stems are decimal numbers
        and alphabetically otherwise. The files are opened under the names
        returned by ``os.listdir``; rebuilding names from the parsed integer
        would lose zero padding ("001.png" -> "1.png") and fail to open them.

        Raises:
            IndexError: if the scan directory contains no files.
        """
        scans_path = os.path.join(self._split_dir, self.images[index])
        file_names = os.listdir(scans_path)
        if not file_names:
            raise IndexError(f"no slice files found in {scans_path}")

        stems = [os.path.splitext(name)[0] for name in file_names]
        if all(stem.isdecimal() for stem in stems):
            numbered = sorted(zip((int(stem) for stem in stems), file_names))
            ordered_names = [name for _, name in numbered]
        else:
            ordered_names = sorted(file_names)

        scans = []
        for name in ordered_names:
            scan = skimage.io.imread(os.path.join(scans_path, name))
            if scan.ndim == 2:
                # grayscale slice: replicate to three channels so the slicing
                # and permute below see an (H, W, C) array
                scan = np.stack([scan] * 3, axis=-1)
            scan = scan[:, :, :3]  # drop an alpha channel if present
            scans.append(torch.from_numpy(scan).permute(2, 0, 1).float())

        return scans

    def get_target(self, index: int) -> int:
        """Return 1 for a positive scan and 0 otherwise, derived from the directory ID."""
        return int(int(self.images[index]) <= 79) # IDs 0-79 are positive

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int):
        """Return ``(slices, target)`` for scan ``index``.

        With a per-image ``transform`` configured, it is applied to every slice
        under the same random seed, so random augmentations are identical across
        the scan, and the slices are stacked into one tensor. Without one, the
        untransformed list of slice tensors is returned. ``target_transform`` and
        the joint ``transforms`` are not applied here.
        """
        images = self.get_image_data(index)
        target = self.get_target(index)

        seed = np.random.randint(2147483647) # make a seed with numpy generator
        # Guard on the callable that is actually invoked. The previous check on
        # `self.transforms` could pass while `self.transform` was None.
        if self.transform is not None:
            for i in range(len(images)):
                np.random.seed(seed)
                torch.manual_seed(seed)
                images[i] = self.transform(images[i])
            images = torch.stack(images, dim=0)

        return images, target

In [None]:
# Build the segmentation training transforms at resize_size=224. The call returns a pair;
# a later cell unpacks the same factory's result as (image transform, target transform).
# Only `a` is used by the visible cells; `l` is not referenced again.
a, l = make_segmentation_train_transforms(resize_size=224)

In [None]:
# `transform` and `target_transform` are bound to the same classification transform object.
transform = target_transform = make_classification_train_transform()
# NOTE(review): the dataset is given `a` (from make_segmentation_train_transforms), not the
# classification `transform` created on the line above — confirm this is intended.
dataset = SARSCoV2CT(split=SARSCoV2CT.Split.TRAIN,
                root="/mnt/z/data/SARS-CoV-2-CT",
                transform=a)

In [None]:
# Preview the first four slices of every scan in the dataset.
# The earlier `i.cuda()` call was removed because its result was discarded
# (Tensor.cuda() returns a copy), so the slices were always displayed from the
# original tensor anyway.
# `i` and `t` stay bound after the loop; a later cell reads `i`.
for i, t in dataset:
    for volume_slice in i[:4]:
        show_image_from_tensor(volume_slice)

In [None]:
# Applies the classification transform to the slices of `i` concatenated along dim 0.
# NOTE(review): `i` is whatever the preview loop in the previous cell left bound (its last
# sample), so this cell only works after that loop has run.
transform(torch.concat(i))

In [None]:
# Loader over the SARS-CoV-2 CT dataset, batching with the 3D collate function.
# NOTE(review): batch_size and num_workers are literals here rather than the
# `batch_size` / `num_workers` names defined in the settings cell.
train_data_loader = make_data_loader(
    dataset=dataset,
    collate_fn=collate_fn_3d,
    batch_size=4,
    num_workers=1,
    shuffle=True,
    seed=0,
    sampler_type=None,
    sampler_advance=1,
    drop_last=False,
    persistent_workers=True,
)

In [None]:
class LinearClassifier(nn.Module):
    """Linear probe trained on top of frozen backbone features.

    Per-image features are built with ``create_linear_input`` and averaged over
    the first dimension of the input before the linear layer; with ``is_3d`` the
    same pooling is done separately for every element of the batch.
    """

    def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000, is_3d=False):
        super().__init__()
        self.is_3d = is_3d
        self.out_dim = out_dim
        self.use_n_blocks = use_n_blocks
        self.use_avgpool = use_avgpool
        self.num_classes = num_classes

        # Small-variance weights and zero bias for the probe.
        self.linear = nn.Linear(out_dim, num_classes)
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward_3d(self, inputs):
        """Pool every batch element on its own and stack the pooled vectors."""
        pooled = [self.forward_(volume) for volume in inputs]
        return torch.stack(pooled).squeeze()

    def forward_(self, inputs):
        """Build the linear input for each item of ``inputs`` and average them."""
        per_item = []
        for image in inputs:
            per_item.append(create_linear_input(image, self.use_n_blocks, self.use_avgpool))
        # If 3D, this is the average over all slices.
        averaged = torch.stack(per_item).mean(dim=0)
        return averaged.squeeze()

    def forward(self, images):
        """Pool the features (per volume when ``is_3d``) and apply the linear layer."""
        pool = self.forward_3d if self.is_3d else self.forward_
        return self.linear(pool(images)).squeeze()

# Single-logit probe over 384-dimensional features, with 3D (per-volume) pooling enabled.
lc = LinearClassifier(
    out_dim=384,
    use_n_blocks=1,
    use_avgpool=False,
    num_classes=1,
    is_3d=True,
).cuda()

In [None]:
# Run the feature model's `forward_` on the first slice of the first sample (with a leading
# dimension added) to get an example output for sizing the linear classifiers.
sample_output = feature_model.forward_(dataset[0][0][0].unsqueeze(0).cuda())

In [None]:
def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, avgpools=(True, False), num_classes=14, is_3d=False):
    """Build one linear classifier per (n_blocks, avgpool, learning rate) combination.

    Args:
        sample_output: example backbone output, used only to infer the input
            width of each classifier through ``create_linear_input``.
        n_last_blocks_list: values of ``use_n_blocks`` to try.
        learning_rates: learning rates to try; used as given (no batch-size scaling).
        avgpools: values of ``use_avgpool`` to try.
        num_classes: output size of every classifier.
        is_3d: forwarded to every ``LinearClassifier``.

    Returns:
        ``(linear_classifiers, optim_param_groups)``: the classifiers wrapped in
        ``AllClassifiers`` (and in ``DistributedDataParallel`` when distributed
        mode is enabled), and one optimiser parameter group per classifier
        carrying its learning rate.
    """
    linear_classifiers_dict = nn.ModuleDict()
    optim_param_groups = []
    for n in n_last_blocks_list:
        for avgpool in avgpools:
            # The feature width depends only on (n, avgpool), so it is computed
            # once per pair instead of once per learning rate.
            out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1]
            for lr in learning_rates:
                linear_classifier = LinearClassifier(
                    out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes, is_3d=is_3d
                )
                linear_classifier = linear_classifier.cuda()
                # ModuleDict keys may not contain ".", hence the replacement.
                linear_classifiers_dict[
                    f"linear:blocks={n}:avgpool={avgpool}:lr={lr:.10f}".replace(".", "_")
                ] = linear_classifier
                optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr})

    linear_classifiers = AllClassifiers(linear_classifiers_dict)
    if distributed.is_enabled():
        linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers)

    return linear_classifiers, optim_param_groups

In [None]:
# Grid of single-logit 3D classifiers: 2 block settings x 2 avgpool settings x 2 learning
# rates = 8 classifiers, each with its own optimiser parameter group.
# NOTE(review): this uses whichever `AllClassifiers` is bound when the cell runs; the
# notebook redefines that class in a later cell.
linear_classifiers, optim_param_groups = setup_linear_classifiers(
    sample_output=sample_output,
    n_last_blocks_list=[1, 4],
    learning_rates=[1e-2, 1e-4],
    avgpools=[False, True],
    num_classes=1,
    is_3d=True
)

In [None]:
class AllClassifiers(nn.Module):
    """Container that runs the same input through every wrapped classifier."""

    def __init__(self, classifiers_dict):
        """Copy ``classifiers_dict`` (name -> module) into an owned ``nn.ModuleDict``."""
        super().__init__()
        self.classifiers_dict = nn.ModuleDict()
        self.classifiers_dict.update(classifiers_dict)

    def forward(self, inputs):
        """Return ``{name: classifier(inputs)}`` for every wrapped classifier."""
        # Call the module itself rather than `.forward` so registered hooks run.
        # The debug prints that dumped `inputs` on every call were removed.
        return {name: classifier(inputs) for name, classifier in self.classifiers_dict.items()}

    def __len__(self):
        """Number of wrapped classifiers."""
        return len(self.classifiers_dict)

In [None]:
# Smoke test: push one batch through the feature model and the classifier grid.
# Everything that followed the first `break` was unreachable (and referenced the
# undefined name `batch_features`), so it has been removed.
# `z` and `f` are kept only so the notebook namespace is unchanged.
z = 0
f = 0
for i, t in train_data_loader:
    i = i.cuda()
    t = t.cuda()

    features = feature_model(i)  # batch then slices
    output = linear_classifiers(features)
    break

In [None]:
# Segmentation transforms: an (image, target) pair for training and one for evaluation.
train_image_transform, train_target_transform = make_segmentation_train_transforms()
eval_image_transform, eval_target_transform = make_segmentation_eval_transforms()

# Datasets for the three splits named in the settings.
train_dataset, val_dataset, test_dataset = make_datasets(
    train_dataset_str=train_dataset_str,
    val_dataset_str=val_dataset_str,
    test_dataset_str=args.test_dataset_str,
    train_transform=train_image_transform,
    eval_transform=eval_image_transform,
    train_target_transform=train_target_transform,
    eval_target_transform=eval_target_transform,
)

sampler_type = SamplerType.INFINITE

# Training loader (replaces the `train_data_loader` created in an earlier cell).
train_data_loader = make_data_loader(
    dataset=train_dataset, batch_size=8, num_workers=1,
    shuffle=True, seed=0,
    sampler_type=sampler_type, sampler_advance=1,
    drop_last=False, persistent_workers=True,
)

# Validation loader; same sampler settings, smaller batch.
val_data_loader = make_data_loader(
    dataset=val_dataset, batch_size=4, num_workers=1,
    shuffle=True, seed=0,
    sampler_type=sampler_type, sampler_advance=1,
    drop_last=False, persistent_workers=True,
)

In [None]:
# Show the first training image and its mask (mask values scaled by 100 for display).
# The bare `i.cuda()` / `t.cuda()` calls were removed: Tensor.cuda() returns a copy,
# so their results were discarded and the tensors shown were the originals anyway.
for i, t in train_dataset:
    print(i.shape)

    show_image_from_tensor(i)
    show_image_from_tensor(t.unsqueeze(0) * 100)
    break

In [None]:
# Inspect the nesting of the feature model's output for one batch by printing the
# length at five successive levels (always following the first element).
for i, t in train_data_loader:
    i = i.cuda()
    i = feature_model(i)  # `i` now holds the features, not the images
    print(len(i))
    print(len(i[0]))
    print(len(i[0][0]))
    print(len(i[0][0][0]))
    print(len(i[0][0][0][0]))
    break

In [None]:
class LinearDecoder(torch.nn.Module):
    """Linear decoder head: a 1x1 convolution over the patch-token grid."""
    DECODER_TYPE = "linear"

    def __init__(self, in_channels, tokenW=32, tokenH=32, num_classes=3):
        """
        Args:
            in_channels: embedding size of each token.
            tokenW: width of the token grid.
            tokenH: height of the token grid.
            num_classes: number of output channels.
        """
        super().__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        self.decoder = torch.nn.Conv2d(in_channels, num_classes, (1,1))
        self.decoder.weight.data.normal_(mean=0.0, std=0.01)
        self.decoder.bias.data.zero_()

    def forward(self, embeddings):
        """Map token embeddings to per-token logits of shape (N, num_classes, tokenH, tokenW).

        ``embeddings`` must hold tokenH * tokenW tokens of size ``in_channels``
        per sample. The shape-debugging prints were removed.
        """
        # Lay the tokens out as an (N, H, W, C) grid, then move channels first for Conv2d.
        embeddings = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        embeddings = embeddings.permute(0,3,1,2)

        return self.decoder(embeddings)

In [None]:
# Two-class decoder over 384-channel tokens, using the default 32x32 token grid.
# NOTE(review): 384 is presumably the backbone's embedding width — confirm.
d = LinearDecoder(384, num_classes=2).cuda()

In [None]:
# Smoke test on the first training sample: run the backbone, decode its normalised patch
# tokens with `d`, and print the decoder output shape.
# NOTE(review): this rebinds `a` (previously the transform passed to the SARSCoV2CT dataset)
# and `z`; `a` is not used again in the visible cells.
for i, t in train_dataset:
    i = i.cuda().unsqueeze(0)  # add a batch dimension
    a = model(i)
    b = model.forward_features(i)['x_norm_patchtokens']
    z = d(b)
    print(z.shape)
    break

In [None]:
# Chain the training and validation datasets into a single dataset.
concated = torch.utils.data.ConcatDataset([train_dataset, val_dataset])

In [None]:
# Number of samples in the concatenated dataset.
len(concated)

In [None]:
# ConcatDataset does not forward custom methods of the datasets it wraps, so the
# class count is read from the first underlying dataset instead.
# NOTE(review): assumes both parts share the same class set and that the dataset
# returned by make_datasets exposes get_num_classes() — confirm.
concated.datasets[0].get_num_classes()

In [None]:
# Iterate the whole concatenated dataset and print every image tensor.
# NOTE(review): there is no `break`, so this prints one full tensor per sample.
for i, t in concated:
    print(i)

In [None]:
# Carve a validation split out of the NIH train/val list: the last 10,002 listed
# images are moved from train/ to val/ and the two file lists are written out.
data_dir = "/mnt/d/data/NIH/"
train_val = pd.read_csv(os.path.join(data_dir, "train_val_list.txt"), names=["Image Index"])
val_list = list(range(len(train_val) - 10_002, len(train_val)))
val_set = train_val.iloc[val_list]
train_set = train_val.drop(val_list)

train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "val")
# shutil.move does not create the destination directory.
os.makedirs(val_dir, exist_ok=True)
for image in val_set["Image Index"]:
    source = os.path.join(train_dir, image)
    dest = os.path.join(val_dir, image)
    if not os.path.exists(source) and os.path.exists(dest):
        # Already moved by an earlier (possibly interrupted) run; skipping keeps
        # the cell re-runnable.
        continue
    shutil.move(source, dest)

val_set.to_csv(os.path.join(data_dir, "val_list.txt"), index=False, header=False)
train_set.to_csv(os.path.join(data_dir, "train_list.txt"), index=False, header=False)

In [None]:
class LinearDecoder(torch.nn.Module):
    """1x1-convolution head that turns a grid of token embeddings into label logits."""

    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1):
        super().__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH

        # Small-variance weights and zero bias for the 1x1 convolution.
        head = torch.nn.Conv2d(in_channels, num_labels, (1, 1))
        torch.nn.init.normal_(head.weight, mean=0.0, std=0.01)
        torch.nn.init.zeros_(head.bias)
        self.decoder = head

    def forward(self, embeddings):
        """Return logits of shape (N, num_labels, tokenH, tokenW)."""
        grid_shape = (-1, self.height, self.width, self.in_channels)
        # (N, H, W, C) token grid -> channels-first layout expected by Conv2d
        channels_first = embeddings.reshape(*grid_shape).permute(0, 3, 1, 2)
        return self.decoder(channels_first)

In [None]:
# Three-label decoder head, trained with SGD under a cosine learning-rate schedule.
decoder = LinearDecoder(in_channels=384, num_labels=3).cuda()
optimizer = torch.optim.SGD(
    decoder.parameters(),
    lr=0.0005,
    momentum=0.9,
    weight_decay=0,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=69, eta_min=0)

In [None]:
class MetricAveraging(Enum):
    # Averaging modes; `str(member)` yields the raw mode string.
    #
    # Enum members that repeat an earlier value become aliases of the first member
    # with that value, so MULTILABEL_ACCURACY, MULTILABEL_AUROC and MULTILABEL_JACCARD
    # are all the same object as MEAN_PER_CLASS_ACCURACY (value "macro").
    MEAN_ACCURACY = "micro"
    MEAN_PER_CLASS_ACCURACY = "macro"
    MULTILABEL_ACCURACY = "macro"
    MULTILABEL_AUROC = "macro"
    MULTILABEL_JACCARD = "macro"
    PER_CLASS_ACCURACY = "none"

    def __str__(self):
        # Return the averaging mode string itself.
        return self.value

# Build the segmentation metric for 3 labels and move it to the GPU.
# NOTE(review): MULTILABEL_JACCARD is an alias of the "macro" member of the notebook-local
# enum above; confirm build_segmentation_metrics accepts this local enum rather than
# requiring its own module's type.
metric = build_segmentation_metrics(average_type=MetricAveraging.MULTILABEL_JACCARD,num_labels=3)
metric.cuda()

In [None]:
# Train the linear decoder on frozen backbone patch tokens.
# NOTE(review): the loop has no stopping condition of its own; it runs for as long
# as `train_data_loader` yields batches.
loss_fct = torch.nn.CrossEntropyLoss()  # built once instead of on every step
i = 0
for i, (image, target) in enumerate(train_data_loader, start=1):
    image, target = image.cuda(non_blocking=True), target.cuda(non_blocking=True)

    # The backbone is frozen: no gradients are tracked for feature extraction.
    with torch.no_grad():
        features = model.forward_features(image)['x_norm_patchtokens']

    logits = decoder(features)
    # Upsample the token-grid logits to the mask resolution. Taking the size from
    # `target` (previously a fixed 448) keeps logits and target aligned for the loss.
    logits = torch.nn.functional.interpolate(
        logits, size=target.shape[-2:], mode="bilinear", align_corners=False
    )
    prediction = logits.argmax(dim=1)

    loss = loss_fct(logits, target)

    optimizer.zero_grad()
    loss.backward()

    optimizer.step()
    scheduler.step()

    metric(prediction, target)
    print(metric.compute())
    print(loss.item())

    # if i % 50 == 0:
    show_image_from_tensor((prediction * 100).cpu())
    show_image_from_tensor((target * 100).cpu())