In [1]:
import os
import sys
import shutil
import argparse
import math
import IPython 
from PIL import Image
from enum import Enum
from typing import Callable, List, Optional, Tuple, Union
from functools import partial

import logging
import torch
import torch.nn as nn
import torchio as tio
import torchvision
from torchvision.datasets import VisionDataset
from torchvision.transforms import transforms
import numpy as np
import pandas as pd
import skimage
from scipy import sparse
import matplotlib.pyplot as plt 
import torchxrayvision as xrv
import nibabel as nib

from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
import dinov2.distributed as distributed
from dinov2.models.unet import UNet
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.datasets import NIHChestXray, MC, Shenzhen, SARSCoV2CT
from dinov2.data.datasets.medical_dataset import MedicalVisionDataset
from dinov2.data.loaders import make_data_loader
from dinov2.data.transforms import (make_segmentation_train_transforms, make_classification_eval_transform, make_segmentation_eval_transforms,
                                    make_classification_train_transform)
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import (is_zero_matrix, ModelWithIntermediateLayers, ModelWithNormalize, evaluate, extract_features, collate_fn_3d,
                               make_datasets)
from dinov2.eval.classification.utils import LinearClassifier, create_linear_input, setup_linear_classifiers, AllClassifiers
from dinov2.eval.metrics import build_segmentation_metrics
from dinov2.eval.segmentation.utils import LinearDecoder, setup_decoders, DINOV2Encoder
from dinov2.utils import show_image_from_tensor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# args = argparse.Namespace(config_file='dinov2/configs/eval/vits14_pretrain.yaml', pretrained_weights='models/dinov2_vits14_pretrain.pth', output_dir='results/NIH/dinov2_vits14/knn', opts=[], train_dataset_str='NIHChestXray:split=TRAIN:root=/mnt/d/data/NIH', val_dataset_str='NIHChestXray:split=VAL:root=/mnt/d/data/NIH', test_dataset_str='NIHChestXray:split=TEST:root=/mnt/d/data/NIH', nb_knn=[5, 20, 50, 100, 200], temperature=0.07, gather_on_cpu=False, batch_size=8, n_per_class_list=[-1], n_tries=1, ngpus=1, nodes=1, timeout=2800, partition='learnlab', use_volta32=False, comment='', exclude='')
args = argparse.Namespace(config_file='dinov2/configs/eval/vits14_pretrain.yaml', pretrained_weights='models/dinov2_vits14_pretrain.pth', output_dir='results/NIH/dinov2_vits14/knn', opts=[], train_dataset_str='MC:split=TRAIN:root=/mnt/z/data/MC', val_dataset_str='MC:split=VAL:root=/mnt/z/data/MC', test_dataset_str='MC:split=TEST:root=/mnt/z/data/MC', nb_knn=[5, 20, 50, 100, 200], temperature=0.07, gather_on_cpu=False, batch_size=8, n_per_class_list=[-1], n_tries=1, ngpus=1, nodes=1, timeout=2800, partition='learnlab', use_volta32=False, comment='', exclude='')
model, autocast_dtype = setup_and_build_model(args)
autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
feature_model_with_inter = ModelWithIntermediateLayers(model, 4, autocast_ctx, is_3d=False)
# model = ModelWithNormalize(model)

I20231001 18:01:44 27101 dinov2 config.py:60] git:
  sha: 047962eee435bc0d1094fe0f4324d5033bca773a, status: has uncommitted changes, branch: main

I20231001 18:01:44 27101 dinov2 config.py:61] batch_size: 8
comment: 
config_file: dinov2/configs/eval/vits14_pretrain.yaml
exclude: 
gather_on_cpu: False
n_per_class_list: [-1]
n_tries: 1
nb_knn: [5, 20, 50, 100, 200]
ngpus: 1
nodes: 1
opts: ['train.output_dir=/mnt/c/Users/user/Desktop/dinov2/results/NIH/dinov2_vits14/knn']
output_dir: /mnt/c/Users/user/Desktop/dinov2/results/NIH/dinov2_vits14/knn
partition: learnlab
pretrained_weights: models/dinov2_vits14_pretrain.pth
temperature: 0.07
test_dataset_str: MC:split=TEST:root=/mnt/z/data/MC
timeout: 2800
train_dataset_str: MC:split=TRAIN:root=/mnt/z/data/MC
use_volta32: False
val_dataset_str: MC:split=VAL:root=/mnt/z/data/MC
I20231001 18:01:44 27101 dinov2 config.py:27] sqrt scaling learning rate; base: 0.004, new: 0.001
I20231001 18:01:44 27101 dinov2 config.py:34] MODEL:
  WEIGHTS: ''
compu

In [3]:
train_dataset_str = args.train_dataset_str
val_dataset_str = args.val_dataset_str
batch_size = args.batch_size
gather_on_cpu = args.gather_on_cpu
num_workers = 1

In [19]:
train_image_transform, train_target_transform = make_segmentation_train_transforms(resize_size=224)
eval_image_transform, eval_target_transform  = make_segmentation_eval_transforms(resize_size=224)
# train_image_transform = make_classification_train_transform()
# eval_image_transform = make_classification_eval_transform()
train_target_transform = eval_target_transform = None

# val_dataset_str = args.val_dataset_str
val_dataset_str = None

train_dataset, val_dataset, test_dataset = make_datasets(train_dataset_str=args.train_dataset_str, val_dataset_str=val_dataset_str,
                                                        test_dataset_str=args.test_dataset_str, train_transform=train_image_transform,
                                                        eval_transform=eval_image_transform, train_target_transform=train_target_transform,
                                                        eval_target_transform=eval_target_transform)

# sampler_type = SamplerType.INFINITE
sampler_type = None

is_3d = test_dataset.is_3d()

# train_data_loader = make_data_loader(
#     dataset=train_dataset,
#     batch_size=2,
#     num_workers=0,
#     shuffle=True,
#     seed=0,
#     sampler_type=sampler_type,
#     sampler_advance=0,
#     drop_last=False,
#     persistent_workers=False,
#     collate_fn=collate_fn_3d if is_3d else None
# )

# val_data_loader = make_data_loader(
#     dataset=val_dataset,
#     batch_size=1,
#     num_workers=0,
#     shuffle=True,
#     seed=0,
#     sampler_type=sampler_type,
#     sampler_advance=0,
#     drop_last=False,
#     persistent_workers=False,
#     collate_fn=collate_fn_3d if is_3d else None
# )

I20231001 18:06:41 27101 dinov2 loaders.py:96] using dataset: "MC:split=TRAIN:root=/mnt/z/data/MC"
I20231001 18:06:41 27101 dinov2 medical_dataset.py:36] 0 scans are missing from TRAIN set
I20231001 18:06:41 27101 dinov2 loaders.py:101] # of dataset samples: 69
I20231001 18:06:41 27101 dinov2 loaders.py:96] using dataset: "MC:split=VAL:root=/mnt/z/data/MC"
I20231001 18:06:41 27101 dinov2 medical_dataset.py:36] 0 scans are missing from VAL set
I20231001 18:06:41 27101 dinov2 loaders.py:101] # of dataset samples: 23
I20231001 18:06:41 27101 dinov2 utils.py:338] Train and val datasets have been combined.
I20231001 18:06:41 27101 dinov2 loaders.py:96] using dataset: "MC:split=TEST:root=/mnt/z/data/MC"
I20231001 18:06:41 27101 dinov2 medical_dataset.py:36] 0 scans are missing from TEST set
I20231001 18:06:41 27101 dinov2 loaders.py:101] # of dataset samples: 46


In [20]:
embed_dim = model.embed_dim

In [48]:
class UNetDecoderUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, embed_dim=1024) -> None:
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels*2, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.skip_conv = nn.Sequential(
            nn.Conv2d(embed_dim, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )        

    def forward(self, x1, x2):
        x1 = self.upconv(x1)
        x2 = self.skip_conv(x2)
        scale_factor = (x1.size()[2] / x2.size()[2])
        x2 = nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=True)(x2)
        x = torch.concat([x1, x2], dim=1)
        return self.conv(x)

In [64]:
class UNetDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, image_size=224):
        super(UNetDecoder, self).__init__()
        self.embed_dim = in_channels
        self.image_size = image_size
        self.out_channels = out_channels
        self.up1 = UNetDecoderUpBlock(in_channels=in_channels, out_channels=in_channels//2, embed_dim=embed_dim)
        self.up2 = UNetDecoderUpBlock(in_channels=in_channels//2, out_channels=in_channels//4, embed_dim=embed_dim)
        self.up3 = UNetDecoderUpBlock(in_channels=in_channels//4, out_channels=in_channels//8, embed_dim=embed_dim)
        self.up4 = UNetDecoderUpBlock(in_channels=in_channels//8, out_channels=out_channels, embed_dim=embed_dim)

    def forward(self, x):

        h = w = self.image_size//14

        skip1 = x[3].reshape(-1, h, w, self.embed_dim).permute(0,3,1,2)
        skip2 = x[2].reshape(-1, h, w, self.embed_dim).permute(0,3,1,2)
        skip3 = x[1].reshape(-1, h, w, self.embed_dim).permute(0,3,1,2)
        skip4 = x[0].reshape(-1, h, w, self.embed_dim).permute(0,3,1,2)
        x1    = x[3].reshape(-1, h, w, self.embed_dim).permute(0,3,1,2)
        
        x2 = self.up1(x1, skip1)
        x3 = self.up2(x2, skip2)
        x4 = self.up3(x3, skip3)
        x5 = self.up4(x4, skip4)

        return x5

In [65]:
feature_model = DINOV2Encoder(model, autocast_ctx=autocast_ctx, n_last_blocks=4, is_3d=False).cuda()
decoder = UNetDecoder(in_channels=model.embed_dim, out_channels=3).cuda()
# feature_model_with_inter = ModelWithIntermediateLayers(model, 1, autocast_ctx, is_3d=False)
for i, t in train_dataset:
    i = i.cuda().unsqueeze(0)
    embeddings = feature_model(i)
    output = decoder(embeddings)
    print(output.shape)
    break

False
torch.Size([1, 3, 128, 128])


In [17]:
print(embeddings[3])

tensor([[[ 0.3750, -3.3723, -1.9497,  ..., -1.3268, -0.9333, -0.9981],
         [ 0.4478, -2.9133, -2.2538,  ..., -0.8136, -0.2220,  0.4753],
         [ 0.5773, -3.3643, -2.7862,  ..., -1.4565, -0.9303,  2.9201],
         ...,
         [-0.5332, -2.1409, -1.9454,  ..., -2.2346, -2.5651, -4.9458],
         [-0.7281, -0.2295, -3.1759,  ..., -0.9039, -1.8940, -5.0801],
         [-0.3124,  0.9867, -2.3786,  ...,  0.0680, -1.5946, -4.7298]]],
       device='cuda:0')


In [18]:
print(embeddings2)

[((tensor([[[ 0.3754, -3.3728, -1.9493,  ..., -1.3275, -0.9343, -0.9976],
         [ 0.4481, -2.9115, -2.2523,  ..., -0.8120, -0.2235,  0.4772],
         [ 0.5807, -3.3610, -2.7855,  ..., -1.4568, -0.9304,  2.9267],
         ...,
         [-0.5295, -2.1414, -1.9441,  ..., -2.2333, -2.5661, -4.9458],
         [-0.7259, -0.2327, -3.1773,  ..., -0.9081, -1.8951, -5.0850],
         [-0.3152,  0.9831, -2.3822,  ...,  0.0659, -1.5966, -4.7304]]],
       device='cuda:0', grad_fn=<SliceBackward0>), tensor([[ 5.7421e-01,  2.7907e+00, -1.8493e+00,  4.8371e+00, -3.4320e+00,
         -1.6165e+00, -2.8376e-01,  2.1611e+00,  3.0133e+00, -9.4044e-02,
          5.0435e-01,  2.3822e+00,  1.7914e-01,  3.8141e+00,  3.5105e+00,
          4.8408e-01, -3.6066e+00, -3.7887e+00, -7.2021e-02, -1.6886e+00,
          1.5435e+00, -2.4341e+00,  3.3907e+00, -5.4760e-01,  9.1488e-01,
          2.1417e+00,  1.3259e+00,  2.3891e+00, -1.5123e-01,  6.6424e-01,
          5.6075e+00, -6.7871e-01,  1.7857e+00,  2.4257e+00,

In [None]:
img = test_dataset.get_image_data(0)
lbl = test_dataset.get_target(0)

In [None]:
class DINOV2Encoder(torch.nn.Module):
    def __init__(self, encoder, autocast_ctx, is_3d=False) -> None:
        super(DINOV2Encoder, self).__init__()
        self.encoder = encoder
        self.encoder.eval()
        self.autocast_ctx = autocast_ctx
        self.is_3d = is_3d
    
    def forward_3d(self, x):
        batch_features = [] 
        for batch_scans in x: # calculate the features for every scan in all scans of the batch
            scans = []
            for scan in batch_scans:
                if not is_zero_matrix(scan): scans.append(self.forward_(scan.unsqueeze(0)))
            batch_features.append(scans)
        return batch_features

    def forward_(self, x):
        with torch.no_grad():
            with self.autocast_ctx():
                features = self.encoder.forward_features(x)['x_norm_patchtokens']
        return features

    def forward(self, x):
        if is_3d:
            return self.forward_3d(x)
        return self.forward_(x)

In [None]:
for i, t in test_dataset:
    show_image_from_tensor(i[0] * 100)
    show_image_from_tensor(i[1] * 100)
    show_image_from_tensor(i[2] * 100)


In [None]:
def save_test_results(feature_model, decoder, dataset):
    for i, (img, _) in enumerate(dataset):

        img_name = test_dataset.images[i]
        _, affine_matrix = test_dataset.get_image_data(i, return_affine_matrix=True)

        img = img.cuda(non_blocking=True) 

        features = feature_model(img.unsqueeze(0))
        output = decoder(features, up_size=512)[0]
        output = output.argmax(dim=1)

        nifti_img = nib.Nifti1Image(output
                                    .cpu()
                                    .numpy()
                                    .astype(np.uint8)
                                    .transpose(1, 2, 0), affine_matrix)    
        file_output_dir = test_results_path + os.sep + img_name + ".gz"

        # Save the NIfTI image
        nib.save(nifti_img, file_output_dir)

In [None]:
f = DINOV2Encoder(model, autocast_ctx=autocast_ctx, is_3d=True).cuda()
ld = LinearDecoder(in_channels=model.embed_dim, num_classes=14, is_3d=True).cuda()
save_test_results(f, ld, test_dataset)

In [None]:
f = DINOV2Encoder(model, autocast_ctx=autocast_ctx, is_3d=True).cuda()
ld = LinearDecoder(in_channels=model.embed_dim, num_classes=14, is_3d=True).cuda()
optimizer = torch.optim.SGD(ld.parameters(), lr=3e-4, momentum=0.9, weight_decay=0)

for i, t in train_data_loader:
    i = i.cuda(non_blocking=True) 

    features = f(i)
    output = ld(features)
    
    output = torch.cat(output, dim=0)
    t = torch.cat(t, dim=0)

    loss = nn.CrossEntropyLoss()(output, t.cuda(non_blocking=True).type(torch.int64))

    optimizer.zero_grad()
    loss.backward()

    # step
    optimizer.step()
    # labels = t.view(-1, t.shape[-1], t.shape[-1])
    # losses = nn.CrossEntropyLoss()(output.view(-1, 14, labels.shape[-1], labels.shape[-1]), labels)
        
    print(loss)

In [None]:

ld = LinearDecoder(in_channels=embed_dim, num_classes=3, is_3d=True)
ld = ld.cuda()

o = ld(features)
print(len(o))
print(o.shape)
o = torch.stack([torch.nn.functional.interpolate(batch_output, size=448, mode="bilinear", align_corners=False)
                for batch_output in torch.unbind(o, dim=0)], dim=0)
# ou = torch.nn.functional.interpolate(o[0], size=448, mode="bilinear", align_corners=False)
print(o.shape)

In [None]:
for i, t in train_data_loader:
    i = i.cuda()
    i = feature_model(i)
    print(len(i))
    print(len(i[0]))
    print(len(i[0][0]))
    print(len(i[0][0][0]))
    print(len(i[0][0][0][0]))
    break

In [None]:
class LinearDecoder(torch.nn.Module):
    """Linear decoder head"""
    DECODER_TYPE = "linear"

    def __init__(self, in_channels, tokenW=32, tokenH=32, num_classes=3):
        super().__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        self.decoder = torch.nn.Conv2d(in_channels, num_classes, (1,1))
        self.decoder.weight.data.normal_(mean=0.0, std=0.01)
        self.decoder.bias.data.zero_()

    def forward(self, embeddings):
        print(embeddings.shape)
        embeddings = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        print(embeddings.shape)
        embeddings = embeddings.permute(0,3,1,2)
        print(embeddings.shape)

        return self.decoder(embeddings)

In [None]:
d = LinearDecoder(384, num_classes=2).cuda()

In [None]:
for i, t in train_dataset:
    i = i.cuda().unsqueeze(0)
    a = model(i)
    b = model.forward_features(i)['x_norm_patchtokens']
    z = d(b)
    print(z.shape)
    break

In [None]:
concated = torch.utils.data.ConcatDataset([train_dataset, val_dataset])

In [None]:
len(concated)

In [None]:
concated.get_num_classes()

In [None]:
for i, t in concated:
    print(i)

In [None]:
data_dir = "/mnt/d/data/NIH/"
train_val = pd.read_csv(data_dir + os.sep + "train_val_list.txt", names=["Image Index"])
val_list = [i for i in range(len(train_val)-10_002, len(train_val))]
val_set = train_val.iloc[val_list]
train_set = train_val.drop(val_list)

train_dir = data_dir + os.sep + "train"
val_dir = data_dir + os.sep + "val"
for image in val_set["Image Index"]:
    source = train_dir + os.sep + image
    dest = val_dir + os.sep + image
    shutil.move(source, dest)

val_set.to_csv(data_dir + os.sep + "val_list.txt", index=False, header=False)
train_set.to_csv(data_dir + os.sep + "train_list.txt", index=False, header=False)

In [None]:
class LinearDecoder(torch.nn.Module):
    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1):
        super(LinearDecoder, self).__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        self.decoder = torch.nn.Conv2d(in_channels, num_labels, (1,1))
        self.decoder.weight.data.normal_(mean=0.0, std=0.01)
        self.decoder.bias.data.zero_()

    def forward(self, embeddings):
        embeddings = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        embeddings = embeddings.permute(0,3,1,2)

        return self.decoder(embeddings)

In [None]:
decoder = LinearDecoder(384, num_labels=3).cuda()
optimizer = torch.optim.SGD(params=decoder.parameters(), lr=0.0005, momentum=0.9, weight_decay=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 69, eta_min=0)

In [None]:
class MetricAveraging(Enum):
    MEAN_ACCURACY = "micro"
    MEAN_PER_CLASS_ACCURACY = "macro"
    MULTILABEL_ACCURACY = "macro"
    MULTILABEL_AUROC = "macro"
    MULTILABEL_JACCARD = "macro"
    PER_CLASS_ACCURACY = "none"

    def __str__(self):
        return self.value

metric = build_segmentation_metrics(average_type=MetricAveraging.MULTILABEL_JACCARD,num_labels=3)
metric.cuda()

In [None]:
i = 0
for image, target in train_data_loader:
    i+=1
    image, target = image.cuda(non_blocking=True), target.cuda(non_blocking=True)
    with torch.no_grad(): 
        features=model.forward_features(image)['x_norm_patchtokens']
    logits = decoder(features)
    logits = torch.nn.functional.interpolate(logits, size=448, mode="bilinear", align_corners=False)
    prediction = logits.argmax(dim=1)

    loss_fct = torch.nn.CrossEntropyLoss()
    loss = loss_fct(logits, target)
    
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()
    scheduler.step()

    metric(prediction, target)
    print(metric.compute())
    print(loss.item())

    # if i % 50 == 0:
    show_image_from_tensor((prediction * 100).cpu())
    show_image_from_tensor((target * 100).cpu())