In [1]:
import os
import sys
import shutil
import argparse
import math
import IPython 
from PIL import Image
from enum import Enum
from typing import Callable, List, Optional, Tuple, Union
from functools import partial

import torch
import torch.nn as nn
import torchvision
from torchvision.datasets import VisionDataset
from torchvision.transforms import transforms
import numpy as np
import pandas as pd
import skimage
from scipy import sparse
import matplotlib.pyplot as plt 
import torchxrayvision as xrv

from dinov2.models.unet import UNet
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.datasets import NIHChestXray, MC, Shenzhen
from dinov2.data.loaders import make_data_loader
from dinov2.data.transforms import make_segmentation_transform, make_xray_classification_eval_transform, make_classification_eval_transform, make_segmentation_target_transform
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import ModelWithIntermediateLayers, ModelWithNormalize, evaluate, extract_features
from dinov2.eval.metrics import build_segmentation_metrics
from dinov2.utils import show_image_from_tensor

In [2]:
args = argparse.Namespace(config_file='dinov2/configs/eval/vits14_pretrain.yaml', pretrained_weights='models/dinov2_vits14_pretrain.pth', output_dir='results/NIH/dinov2_vits14/knn', opts=[], train_dataset_str='Shenzhen:split=TRAIN:root=/mnt/z/data/Shenzhen', val_dataset_str='Shenzhen:split=VAL:root=/mnt/z/data/Shenzhen', test_dataset_strs= 'Shenzhen:split=TEST:root=/mnt/z/data/Shenzhen', nb_knn=[5, 20, 50, 100, 200], temperature=0.07, gather_on_cpu=False, batch_size=8, n_per_class_list=[-1], n_tries=1, ngpus=1, nodes=1, timeout=2800, partition='learnlab', use_volta32=False, comment='', exclude='')
# model, autocast_dtype = setup_and_build_model(args)
# model = ModelWithNormalize(model)

In [50]:
class _Split(Enum):
    TRAIN = "train"
    VAL = "val"
    TEST = "test"

    @property
    def length(self) -> int:
        split_lengths = {
            _Split.TRAIN: 361,
            _Split.VAL: 91,
            _Split.TEST: 114,
        }
        return split_lengths[self]

class Shenzhen(VisionDataset):
    Split = _Split

    def __init__(
        self,
        *,
        split: "Shenzhen.Split",
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)

        self._root = root
        self._masks_path = self._root + os.sep + "masks"
        self._split = split

        self.class_id_mapping = {"background": 0, "lung": 1}
        self.class_names = list(self.class_id_mapping.keys())

        self._define_split_dir() 
        self._check_size()
        self.images = os.listdir(self._split_dir)

    @property
    def split(self) -> "Shenzhen.Split":
        return self._split
    
    def _define_split_dir(self):
        self._split_dir = self._root + os.sep + self._split.value
        if self._split.value not in ["train", "val", "test"]:
            raise ValueError(f'Unsupported split "{self.split}"') 
        
    def _check_size(self):
        num_of_images = len(os.listdir(self._split_dir))
        print(f"{self.split.length - num_of_images} scans are missing from {self._split.value.upper()} set")

    def get_length(self) -> int:
        return self.__len__()

    def get_num_classes(self) -> int:
        return len(self.class_names)

    def get_image_data(self, index: int) -> np.ndarray:
        image_path = self._split_dir + os.sep + self.images[index]
        
        image = skimage.io.imread(image_path)
        image = torch.from_numpy(image).permute(2, 0, 1).float()

        return image
    
    def get_target(self, index: int) -> np.ndarray:

        mask_path = self.images[index].split(".")[0] + "_mask.png"
        mask_path = self._masks_path + os.sep + mask_path
        mask = skimage.io.imread(mask_path).astype(np.int_)

        mask[mask==255] = self.class_id_mapping["lung"]

        target = torch.from_numpy(mask).unsqueeze(0)

        return target
    
    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int):
        image = self.get_image_data(index)
        target = self.get_target(index)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        # Remove channel dim in target
        target = target.squeeze()

        return image, target

In [51]:
d = Shenzhen(root="/mnt/z/data/Shenzhen", split=Shenzhen.Split.TRAIN)

0 scans are missing from TRAIN set


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from skimage.util import invert

# ParametersA.
input_filename = "/mnt/z/data/JSRT/images/JPCLN001.IMG"
shape = (2048, 2048) # matrix size
dtype = np.dtype('>i2') # big-endian unsigned integer (16bit)
# output_filename = "JPCLN001.PNG"

# Reading.
fid = open(input_filename, 'rb')
data = np.fromfile(fid, dtype).astype('int32')
image = invert(data.reshape(shape))

# Display.
plt.imshow(image, cmap = "gray")
plt.show()

In [None]:
# root_location = "/mnt/z/data/MC"
# images_location = root_location + os.sep + "CXR_png"
# images = os.listdir(images_location)
# test_indices = range(0, len(images), 3)
# train_indices = [i for i in range(len(images)) if i not in test_indices]
# for image_index in test_indices:
#     prev_loc = images_location + os.sep + images[image_index]
#     new_loc = root_location + os.sep + "test"
#     shutil.copy(prev_loc, new_loc)

# for image_index in train_indices:
#     prev_loc = images_location + os.sep + images[image_index]
#     new_loc = root_location + os.sep + "train"
#     shutil.copy(prev_loc, new_loc)

In [8]:
train_dataset_str = args.train_dataset_str
val_dataset_str = args.val_dataset_str
batch_size = args.batch_size
gather_on_cpu = args.gather_on_cpu
num_workers = 1

In [3]:
transform = make_segmentation_transform(resize_size=448)
target_transform = make_segmentation_target_transform(resize_size=448)

train_dataset = make_dataset(
    dataset_str=args.train_dataset_str,
    transform=transform,
    target_transform=target_transform
)
val_dataset = make_dataset(
    dataset_str=args.val_dataset_str,
    transform=transform,
    target_transform=target_transform,
)

sampler_type = SamplerType.INFINITE

train_data_loader = make_data_loader(
    dataset=train_dataset,
    batch_size=2,
    num_workers=1,
    shuffle=True,
    seed=0,
    sampler_type=sampler_type,
    sampler_advance=1,
    drop_last=False,
    persistent_workers=True,
)

val_data_loader = make_data_loader(
    dataset=val_dataset,
    batch_size=2,
    num_workers=1,
    shuffle=True,
    seed=0,
    sampler_type=sampler_type,
    sampler_advance=1,
    drop_last=False,
    persistent_workers=True,
)

In [4]:
for i, t in train_dataset:
    print(i)

tensor([[[ 393.5693,  316.7791,  339.6675,  ...,  444.3792,  505.5626,
            -2.1179],
         [ 361.4622,  329.4110,  303.6210,  ...,  441.8511,  468.3035,
            -2.1179],
         [ 364.3664,  313.6723,  265.7392,  ...,  435.1652,  424.3092,
            -2.1179],
         ...,
         [  -2.1179,   -2.1179,   -2.1179,  ...,  744.5286,  767.7523,
           816.7434],
         [  -2.1179,   -2.1179,   -2.1179,  ...,  867.0585,  875.8549,
           908.7823],
         [  -2.1179,   -2.1179,   -2.1179,  ...,  972.6772,  986.1700,
          1009.8542]],

        [[ 402.4838,  323.9795,  347.3788,  ...,  454.4278,  516.9769,
            -2.0357],
         [ 369.6600,  336.8934,  310.5277,  ...,  451.8433,  478.8862,
            -2.0357],
         [ 372.6290,  320.8033,  271.8004,  ...,  445.0081,  433.9099,
            -2.0357],
         ...,
         [  -2.0357,   -2.0357,   -2.0357,  ...,  761.2770,  785.0190,
           835.1038],
         [  -2.0357,   -2.0357,   -2.035

KeyboardInterrupt: 

In [18]:
concated = torch.utils.data.ConcatDataset([train_dataset, val_dataset])

In [24]:
len(concated)

92

In [22]:
concated.get_num_classes()

AttributeError: 'ConcatDataset' object has no attribute 'dataset'

In [20]:
for i, t in concated:
    print(i)

tensor([[[ 2.2489,  2.2489,  2.2489,  ..., -2.5997,  2.2489, -2.1179],
         [ 2.2489,  2.2489,  2.2489,  ..., -2.1179,  2.2489, -2.1179],
         [ 2.2489,  2.2489,  2.2489,  ..., -2.1179,  2.2489, -2.1179],
         ...,
         [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
         [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
         [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179]],

        [[ 2.4286,  2.4286,  2.4286,  ..., -2.5282,  2.4286, -2.0357],
         [ 2.4286,  2.4286,  2.4286,  ..., -2.0357,  2.4286, -2.0357],
         [ 2.4286,  2.4286,  2.4286,  ..., -2.0357,  2.4286, -2.0357],
         ...,
         [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
         [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
         [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357]],

        [[ 2.6400,  2.6400,  2.6400,  ..., -2.2948,  2.6400, -1.8044],
         [ 2.6400,  2.6400,  2.6400,  ..., -1

In [None]:
data_dir = "/mnt/d/data/NIH/"
train_val = pd.read_csv(data_dir + os.sep + "train_val_list.txt", names=["Image Index"])
val_list = [i for i in range(len(train_val)-10_002, len(train_val))]
val_set = train_val.iloc[val_list]
train_set = train_val.drop(val_list)

train_dir = data_dir + os.sep + "train"
val_dir = data_dir + os.sep + "val"
for image in val_set["Image Index"]:
    source = train_dir + os.sep + image
    dest = val_dir + os.sep + image
    shutil.move(source, dest)

val_set.to_csv(data_dir + os.sep + "val_list.txt", index=False, header=False)
train_set.to_csv(data_dir + os.sep + "train_list.txt", index=False, header=False)

In [None]:
class LinearDecoder(torch.nn.Module):
    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1):
        super(LinearDecoder, self).__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        self.decoder = torch.nn.Conv2d(in_channels, num_labels, (1,1))
        self.decoder.weight.data.normal_(mean=0.0, std=0.01)
        self.decoder.bias.data.zero_()

    def forward(self, embeddings):
        embeddings = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        embeddings = embeddings.permute(0,3,1,2)

        return self.decoder(embeddings)

In [None]:
decoder = LinearDecoder(384, num_labels=3).cuda()
optimizer = torch.optim.SGD(params=decoder.parameters(), lr=0.0005, momentum=0.9, weight_decay=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 69, eta_min=0)

In [None]:
class MetricAveraging(Enum):
    MEAN_ACCURACY = "micro"
    MEAN_PER_CLASS_ACCURACY = "macro"
    MULTILABEL_ACCURACY = "macro"
    MULTILABEL_AUROC = "macro"
    MULTILABEL_JACCARD = "macro"
    PER_CLASS_ACCURACY = "none"

    def __str__(self):
        return self.value

metric = build_segmentation_metrics(average_type=MetricAveraging.MULTILABEL_JACCARD,num_labels=3)
metric.cuda()

In [None]:
i = 0
for image, target in train_data_loader:
    i+=1
    image, target = image.cuda(non_blocking=True), target.cuda(non_blocking=True)
    with torch.no_grad(): 
        features=model.forward_features(image)['x_norm_patchtokens']
    logits = decoder(features)
    logits = torch.nn.functional.interpolate(logits, size=448, mode="bilinear", align_corners=False)
    prediction = logits.argmax(dim=1)

    loss_fct = torch.nn.CrossEntropyLoss()
    loss = loss_fct(logits, target)
    
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()
    scheduler.step()

    metric(prediction, target)
    print(metric.compute())
    print(loss.item())

    # if i % 50 == 0:
    show_image_from_tensor((prediction * 100).cpu())
    show_image_from_tensor((target * 100).cpu())