In [57]:
import os
import sys
import shutil
import argparse
import math
import IPython 
from PIL import Image
from enum import Enum
from typing import Callable, List, Optional, Tuple, Union
from functools import partial

import h5py
import logging
import torch
import torch.nn as nn
import torchio as tio
import torchvision
from torchvision.datasets import VisionDataset
from torchvision.transforms import transforms
import numpy as np
import pandas as pd
import skimage
from scipy import sparse
import scipy.io as sio
import matplotlib.pyplot as plt 
import torchxrayvision as xrv
import nibabel as nib
from typing import Any, Dict, Optional
from collections import OrderedDict
from monai.losses.dice import DiceLoss, DiceCELoss

from torchmetrics import Metric, MetricCollection
from torchmetrics.wrappers import ClasswiseWrapper
from torchmetrics.classification import (MultilabelAUROC, MultilabelF1Score, MultilabelAccuracy, MulticlassF1Score, 
                                        MulticlassAccuracy, MulticlassAUROC, Accuracy, BinaryF1Score, BinaryAUROC,
                                        JaccardIndex, MulticlassJaccardIndex, Dice, BinaryAUROC)

from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
import dinov2.distributed as distributed
from dinov2.models.unet import UNet
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.datasets import NIHChestXray, MC, Shenzhen, SARSCoV2CT, BTCV, BTCVSlice, AMOS, MSDHeart
from dinov2.data.datasets.medical_dataset import MedicalVisionDataset
from dinov2.data.loaders import make_data_loader
from dinov2.data.transforms import (make_segmentation_train_transforms, make_classification_eval_transform, make_segmentation_eval_transforms,
                                    make_classification_train_transform)
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import (is_padded_matrix, ModelWithIntermediateLayers, ModelWithNormalize, evaluate, extract_features, collate_fn_3d,
                               make_datasets, make_data_loaders, apply_method_to_nested_values)
from dinov2.eval.classification.utils import LinearClassifier, create_linear_input, setup_linear_classifiers, AllClassifiers
from dinov2.eval.metrics import build_segmentation_metrics, MetricAveraging, MetricType
from dinov2.eval.segmentation.utils import LinearDecoder, setup_decoders, DINOV2Encoder
from dinov2.utils import show_image_from_tensor

from dinov2.models.unet import UNet

In [58]:
# --- Experiment configuration ------------------------------------------------
# NOTE(review): DATA_PATH is a machine-specific absolute path; point it at the
# local copy of the dataset before running.
DATA_PATH = "/mnt/z/data/Shenzhen/"
DATA = "Shenzhen"
# Directory that receives the JSON result files. It must not be the empty
# string: the cells below build their paths as f'{OUTPUT_DIR}/<name>.json', and
# an empty value would resolve to the filesystem root ("/<name>.json").
OUTPUT_DIR = "."

epochs = 1
# Candidate learning rates for the sweep in the next cell.
learning_rates = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1]
batch_size = 2
sampler_type = None
seed = 0

# Seed the global RNGs so that weight initialisation is reproducible; without
# this `seed` is only forwarded to the data loaders.
torch.manual_seed(seed)
np.random.seed(seed)

image_size = 64
# Dataset descriptors, one per split, in the "<name>:split=<SPLIT>:root=<path>" form.
train_dataset_str = f"{DATA}:split=TRAIN:root={DATA_PATH}"
val_dataset_str   = f"{DATA}:split=VAL:root={DATA_PATH}"
test_dataset_str  = f"{DATA}:split=TEST:root={DATA_PATH}"

# --- Transforms and datasets ---------------------------------------------------
# Each factory returns an (image transform, target transform) pair.
train_image_transform, train_target_transform = make_segmentation_train_transforms(resize_size=image_size)
eval_image_transform, eval_target_transform = make_segmentation_eval_transforms(resize_size=image_size)
train_dataset, val_dataset, test_dataset = make_datasets(
    train_dataset_str=train_dataset_str,
    val_dataset_str=val_dataset_str,
    test_dataset_str=test_dataset_str,
    train_transform=train_image_transform,
    eval_transform=eval_image_transform,
    train_target_transform=train_target_transform,
    eval_target_transform=eval_target_transform,
)

num_of_classes = test_dataset.get_num_classes()

# --- Model and loss ------------------------------------------------------------
model = UNet(n_channels=3, n_classes=num_of_classes).cuda()

# Optimizer steps per epoch on the TRAIN split; used as the scheduler horizon.
epoch_length = math.ceil(len(train_dataset) / batch_size)
# Dice loss applied to softmax-ed logits against one-hot encoded integer targets.
loss_function = DiceLoss(softmax=True, to_onehot_y=True)

# --- Data loaders --------------------------------------------------------------
train_data_loader, val_data_loader, test_data_loader = make_data_loaders(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    val_dataset=val_dataset,
    sampler_type=sampler_type,
    seed=seed,
    start_iter=1,
    batch_size=batch_size,
    num_workers=0,
    collate_fn=None,
)

# Class names forwarded to the metric builder as labels.
classes = list(test_data_loader.dataset.class_names)

In [61]:
import json
from tqdm import tqdm

# Validation metrics for every learning rate of the sweep,
# keyed by "<ModelClass>:lr=<learning rate>".
results = {}

for learning_rate in learning_rates:
    # Start every run from freshly initialised (and identically seeded) weights.
    # Re-using one model across the sweep would let each learning rate continue
    # from the weights left by the previous one, making the runs incomparable.
    torch.manual_seed(seed)
    model = UNet(n_channels=3, n_classes=num_of_classes).cuda()

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0)
    # The scheduler is stepped once per batch, so its horizon is the total step count.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epoch_length * epochs, eta_min=0)

    # ---- training ----
    model.train()
    for epoch in tqdm(range(epochs)):
        for data, labels in train_data_loader:
            data = data.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True).type(torch.int64)

            output = model(data)

            # unsqueeze(1) inserts a dimension at index 1 of the targets before the
            # loss one-hot encodes them (presumably labels arrive as (B, H, W) - confirm).
            loss = loss_function(output, labels.unsqueeze(1))

            # compute the gradients
            optimizer.zero_grad()
            loss.backward()

            # step
            optimizer.step()
            scheduler.step()

    # ---- validation ----
    metric = build_segmentation_metrics(
        average_type=MetricAveraging.SEGMENTATION_METRICS,
        num_labels=num_of_classes,
        labels=classes,
    ).cuda()

    # Inference mode: switch layers such as batch-norm/dropout to eval behaviour
    # and skip autograd bookkeeping, which is not needed for metrics.
    model.eval()
    with torch.no_grad():
        for data, labels in val_data_loader:
            data = data.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True).type(torch.int64)

            output = model(data)
            # Predicted class index per position = argmax over dimension 1 of the output.
            preds = output.argmax(dim=1)

            metric.update(preds=preds, target=labels)

    # Convert the computed metric tensors to plain Python numbers so they are JSON-serialisable.
    results[f"{model.__class__.__name__}:lr={learning_rate}"] = apply_method_to_nested_values(
        metric.compute(),
        method_name="item",
        nested_types=(dict),
    )

# Build the path with os.path.join so that an empty OUTPUT_DIR means "current
# directory" instead of resolving to the filesystem root ("/val_result.json").
if OUTPUT_DIR:
    os.makedirs(OUTPUT_DIR, exist_ok=True)
with open(os.path.join(OUTPUT_DIR, 'val_result.json'), 'w') as f:
    json.dump(results, f)

  0%|          | 0/1 [00:17<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Settings for the final run on TRAIN + VAL.
epochs = 1
# TODO: hard-coded; this should be the best learning rate found by the validation sweep.
learning_rate = 1e-3

# Validation split re-created with the *training* transforms so it can be used as
# extra training data. It gets its own name so that `val_dataset` (evaluation
# transforms) is not clobbered.
val_dataset_train_aug = make_dataset(
    dataset_str=val_dataset_str,
    transform=train_image_transform,
    target_transform=train_target_transform,
)
# Built from the untouched `train_dataset` under a new name, which keeps this cell
# idempotent: the previous `train_dataset = ConcatDataset([train_dataset, ...])`
# appended the validation split again on every re-run.
trainval_dataset = torch.utils.data.ConcatDataset([train_dataset, val_dataset_train_aug])

# NOTE: rebinds `train_data_loader`, which the final training cell iterates over.
train_data_loader = make_data_loader(
    dataset=trainval_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    seed=seed,
    sampler_type=sampler_type,
    sampler_advance=1,
    drop_last=False,
    persistent_workers=False,
)

# Fresh model for the final run.
model = UNet(n_channels=3, n_classes=num_of_classes).cuda()

In [None]:
# Total number of optimizer steps for this run. It is derived from the dataset the
# loader actually iterates (TRAIN + VAL, drop_last=False) rather than from
# `epoch_length`, which was computed on the TRAIN split only; a too-short horizon
# makes the cosine schedule run past its minimum before training ends.
total_steps = math.ceil(len(train_data_loader.dataset) / batch_size) * epochs

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0)
# The scheduler is stepped once per batch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps, eta_min=0)

# ---- training on TRAIN + VAL ----
model.train()
for epoch in range(epochs):
    for data, labels in train_data_loader:
        data = data.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True).type(torch.int64)

        output = model(data)

        # unsqueeze(1) inserts a dimension at index 1 of the targets before the
        # loss one-hot encodes them (presumably labels arrive as (B, H, W) - confirm).
        loss = loss_function(output, labels.unsqueeze(1))

        # compute the gradients
        optimizer.zero_grad()
        loss.backward()

        # step
        optimizer.step()
        scheduler.step()

# ---- evaluation on the held-out TEST split ----
metric = build_segmentation_metrics(
    average_type=MetricAveraging.SEGMENTATION_METRICS,
    num_labels=num_of_classes,
    labels=classes,
).cuda()

# The model was just trained on the validation data, so scoring it on
# `val_data_loader` would be evaluating on training data; use the test loader.
model.eval()
with torch.no_grad():
    for data, labels in test_data_loader:
        data = data.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True).type(torch.int64)

        output = model(data)
        # Predicted class index per position = argmax over dimension 1 of the output.
        preds = output.argmax(dim=1)

        metric.update(preds=preds, target=labels)

# Kept separate from the sweep's `results` dict: the key format is identical, so
# writing into `results` would overwrite the validation entry for this learning
# rate and mix validation and test numbers in one file.
test_results = {
    f"{model.__class__.__name__}:lr={learning_rate}": apply_method_to_nested_values(
        metric.compute(),
        method_name="item",
        nested_types=(dict),
    )
}

# Build the path with os.path.join so that an empty OUTPUT_DIR means "current
# directory" instead of resolving to the filesystem root ("/result.json").
if OUTPUT_DIR:
    os.makedirs(OUTPUT_DIR, exist_ok=True)
with open(os.path.join(OUTPUT_DIR, 'result.json'), 'w') as f:
    json.dump(test_results, f)