In [1]:
import os
import sys
import shutil
import argparse
import IPython 
from PIL import Image
from enum import Enum
from typing import Callable, List, Optional, Tuple, Union
from functools import partial

import torch
import torch.nn as nn
import torchvision
from torchvision.datasets import VisionDataset
from torchvision.transforms import transforms
import numpy as np
import pandas as pd
import skimage
from scipy import sparse
import matplotlib.pyplot as plt 
import torchxrayvision as xrv

from dinov2.models.unet import UNet
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.datasets import NIHChestXray, MC
from dinov2.data.loaders import make_data_loader
from dinov2.data.transforms import make_xray_classification_eval_transform, make_classification_eval_transform, make_segmentation_target_transform
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import ModelWithIntermediateLayers, ModelWithNormalize, evaluate, extract_features
from dinov2.eval.metrics import build_multiclass_segmentation_metrics
from dinov2.utils import show_image_from_tensor
from dinov2.eval.segmentation import setup_decoders, TransformerEncoder
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer

In [2]:
args = argparse.Namespace(config_file='dinov2/configs/eval/vitl14_pretrain.yaml', pretrained_weights='models/dinov2_vitl14_pretrain.pth', output_dir='results/NIH/dinov2_vits14/knn', opts=[], train_dataset_str='MC:split=TRAIN:root=/mnt/z/data/MC/train', val_dataset_str='MC:split=VAL:root=/mnt/z/data/MC/test', nb_knn=[5, 20, 50, 100, 200], temperature=0.07, gather_on_cpu=False, batch_size=8, n_per_class_list=[-1], n_tries=1, ngpus=1, nodes=1, timeout=2800, partition='learnlab', use_volta32=False, comment='', exclude='')
model, autocast_dtype = setup_and_build_model(args)

I20230828 10:16:57 15044 dinov2 config.py:60] git:
  sha: c83de149d4b9e35c1d9fdfe864e313313b17062b, status: has uncommitted changes, branch: main

I20230828 10:16:57 15044 dinov2 config.py:61] batch_size: 8
comment: 
config_file: dinov2/configs/eval/vitl14_pretrain.yaml
exclude: 
gather_on_cpu: False
n_per_class_list: [-1]
n_tries: 1
nb_knn: [5, 20, 50, 100, 200]
ngpus: 1
nodes: 1
opts: ['train.output_dir=/mnt/c/Users/user/Desktop/dinov2/results/NIH/dinov2_vits14/knn']
output_dir: /mnt/c/Users/user/Desktop/dinov2/results/NIH/dinov2_vits14/knn
partition: learnlab
pretrained_weights: models/dinov2_vitl14_pretrain.pth
temperature: 0.07
timeout: 2800
train_dataset_str: MC:split=TRAIN:root=/mnt/z/data/MC/train
use_volta32: False
val_dataset_str: MC:split=VAL:root=/mnt/z/data/MC/test
I20230828 10:16:57 15044 dinov2 config.py:27] sqrt scaling learning rate; base: 0.004, new: 0.001
I20230828 10:16:57 15044 dinov2 config.py:34] MODEL:
  WEIGHTS: ''
compute_precision:
  grad_scaler: true
  teach

In [3]:
transform = make_classification_eval_transform(resize_size=448, crop_size=448)
target_transform = make_segmentation_target_transform(resize_size=448)

train_dataset = make_dataset(
    dataset_str=args.train_dataset_str,
    transform=transform,
    target_transform=target_transform
)
val_dataset = make_dataset(
    dataset_str=args.val_dataset_str,
    transform=transform,
    target_transform=target_transform,
)

sampler_type = SamplerType.INFINITE
val_loader = torch.utils.data.DataLoader(val_dataset)

I20230828 10:17:19 15044 dinov2 loaders.py:90] using dataset: "MC:split=TRAIN:root=/mnt/z/data/MC/train"
0 scans are missing from TRAIN set
I20230828 10:17:19 15044 dinov2 loaders.py:95] # of dataset samples: 92
I20230828 10:17:19 15044 dinov2 loaders.py:90] using dataset: "MC:split=VAL:root=/mnt/z/data/MC/test"
46 scans are missing from VAL set
I20230828 10:17:19 15044 dinov2 loaders.py:95] # of dataset samples: 46


In [4]:
class LinearDecoder(torch.nn.Module):
    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1):
        super(LinearDecoder, self).__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        self.decoder = torch.nn.Conv2d(in_channels, num_labels, (1,1))
        self.decoder.weight.data.normal_(mean=0.0, std=0.01)
        self.decoder.bias.data.zero_()

    def forward(self, embeddings):
        embeddings = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        embeddings = embeddings.permute(0,3,1,2)

        return self.decoder(embeddings)

In [5]:
embed_dim = model.embed_dim
decoder = LinearDecoder(embed_dim, num_labels=3)

In [6]:
decoders, optim_param_groups = setup_decoders(
    embed_dim,
    [1e-6, 2e-6, 5e-6, 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 5e-2, 1e-1],
    3,
)

In [7]:
output_dir = "models/trained_heads/segmentation-linear/model_final.pth"
optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0)
max_iter = 2401
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)
checkpointer = Checkpointer(decoders, output_dir, optimizer=optimizer, scheduler=scheduler)

In [8]:
autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
decoder = decoders.module.decoders_dict.segmentor_lr_0_1000000000
feature_model = TransformerEncoder(model, autocast_ctx=autocast_ctx)

In [9]:
for image, target in val_loader:
    image, target = image.cuda(non_blocking=True), target.cuda(non_blocking=True)
    with torch.no_grad(): 
        features=model.forward_features(image)['x_norm_patchtokens']
    logits = decoder(features)
    logits = torch.nn.functional.interpolate(logits, size=448, mode="bilinear", align_corners=False)
    prediction = logits.argmax(dim=1)

    show_image_from_tensor((prediction * 100).cpu())
    show_image_from_tensor((target * 100).cpu())
    break

RuntimeError: FIND was unable to find an engine to execute this computation