In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# import seaborn as sns
import os
import json
from dataset import XRayDataset, XRayDataset_path, XRayDataset_gray
from torch.utils.data import DataLoader, Subset
import torch
import albumentations as A
import random
from tqdm.auto import tqdm
import torch.nn.functional as F
import torch.nn as nn

%matplotlib inline

In [2]:
# Locations of the training images and their JSON annotations.
IMAGE_ROOT = "/opt/ml/input/data/train/DCM/"
LABEL_ROOT = "/opt/ml/input/data/train/outputs_json/"

# Segmentation class names, in channel order: 'finger-1' .. 'finger-19'
# followed by the individually named bones.
CLASSES = [f"finger-{number}" for number in range(1, 20)] + [
    "Trapezium",
    "Trapezoid",
    "Capitate",
    "Hamate",
    "Scaphoid",
    "Lunate",
    "Triquetrum",
    "Pisiform",
    "Radius",
    "Ulna",
]

# Batch size used by the training DataLoaders.
BATCH_SIZE = 8

In [3]:
# One RGB colour per class channel, in channel order.
PALETTE = [
    (220, 20, 60),
    (119, 11, 32),
    (0, 0, 142),
    (0, 0, 230),
    (106, 0, 228),
    (0, 60, 100),
    (0, 80, 100),
    (0, 0, 70),
    (0, 0, 192),
    (250, 170, 30),
    (100, 170, 30),
    (220, 220, 0),
    (175, 116, 175),
    (250, 0, 30),
    (165, 42, 42),
    (255, 77, 255),
    (0, 226, 252),
    (182, 182, 255),
    (0, 82, 0),
    (120, 166, 157),
    (110, 76, 0),
    (174, 57, 255),
    (199, 100, 0),
    (72, 0, 118),
    (255, 179, 240),
    (0, 125, 92),
    (209, 0, 151),
    (188, 208, 182),
    (0, 220, 176),
]


def label2rgb(label):
    """Render a stack of per-class binary masks as a single RGB image.

    Overlapping masks are not blended: channels are painted in order, so
    where several channels equal 1 the highest-indexed channel's colour wins.

    Args:
        label: Array whose first axis indexes the class channels; entries
            equal to 1 mark the pixels belonging to that class.

    Returns:
        np.ndarray: ``uint8`` array shaped ``label.shape[1:] + (3,)``,
        black wherever no channel equals 1.
    """
    canvas = np.zeros((*label.shape[1:], 3), dtype=np.uint8)

    for channel in range(len(label)):
        canvas[label[channel] == 1] = PALETTE[channel]

    return canvas

In [4]:
# Collect every PNG below IMAGE_ROOT as a path relative to IMAGE_ROOT.
# The extension check is case-insensitive; the result is an unordered set.
pngs = set(
    os.path.relpath(os.path.join(dirpath, filename), start=IMAGE_ROOT)
    for dirpath, _subdirs, filenames in os.walk(IMAGE_ROOT)
    for filename in filenames
    if os.path.splitext(filename)[-1].lower() == ".png"
)

In [5]:
def dice_coef(y_true, y_pred):
    """Compute the smoothed Dice coefficient over all trailing dimensions.

    Both inputs are flattened from dimension 2 onward, so the result keeps
    the first two dimensions of the inputs.

    Args:
        y_true: Reference tensor with at least three dimensions.
        y_pred: Tensor to compare, same shape as ``y_true``.

    Returns:
        Tensor holding ``(2 * |A∩B| + eps) / (|A| + |B| + eps)`` per entry
        of the first two dimensions, with ``eps = 0.0001`` so that two
        empty inputs score 1 instead of dividing by zero.
    """
    smooth = 0.0001

    truth_flat = y_true.flatten(2)
    pred_flat = y_pred.flatten(2)

    overlap = (truth_flat * pred_flat).sum(-1)
    denominator = truth_flat.sum(-1) + pred_flat.sum(-1) + smooth

    return (2. * overlap + smooth) / denominator


In [6]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Python and NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU, current GPU, and every GPU (multi-GPU setups)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: request deterministic kernels and disable autotuning
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [7]:
def make_dataset(debug="False"):
    """Build the train, validation and path DataLoaders for the main model.

    The train and validation datasets are resized to 1024x1024 via
    albumentations; the ``XRayDataset_path`` dataset is created without
    transforms.

    Args:
        debug: Enables debug mode when it is the boolean ``True`` or the
            string ``"True"`` (the string form is kept for backward
            compatibility). In debug mode every dataset is reduced to its
            first 10% of samples so a full pass finishes quickly.

    Returns:
        list: ``[train_loader, valid_loader, original_loader]``.
    """
    # Accept a real boolean as well as the legacy string flag; previously
    # ``debug=True`` was silently ignored.
    debug_enabled = debug is True or debug == "True"

    # dataset load
    tf = A.Resize(1024, 1024)
    train_dataset = XRayDataset(is_train=True, transforms=tf)
    valid_dataset = XRayDataset(is_train=False, transforms=tf)
    original_dataset = XRayDataset_path(is_train=False)

    if debug_enabled:
        # Keep the FIRST 10% of the training samples (deterministic, not random).
        train_subset_size = int(len(train_dataset) * 0.1)
        train_subset_indices = range(len(train_dataset))[:train_subset_size]
        train_dataset = Subset(train_dataset, train_subset_indices)

        # Same for validation: the first 10% of the samples.
        valid_subset_size = int(len(valid_dataset) * 0.1)
        valid_subset_indices = range(len(valid_dataset))[:valid_subset_size]
        print(valid_subset_indices)
        valid_dataset = Subset(valid_dataset, valid_subset_indices)

        # The path dataset reuses the validation indices.
        # NOTE(review): assumes XRayDataset_path(is_train=False) lists the
        # same samples in the same order as the validation dataset - confirm.
        original_dataset = Subset(original_dataset, valid_subset_indices)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8,
        drop_last=True,
    )

    # Validation and path loaders are unshuffled with identical batch size,
    # so their batches are produced in the same order.
    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )

    original_loader = DataLoader(
        dataset=original_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )

    return [train_loader, valid_loader, original_loader]

In [68]:
def make_graydataset(debug="False"):
    """Build the train, validation and path DataLoaders for the gray model.

    The ``XRayDataset_gray`` datasets are created with ``transforms=None``
    (no resize is applied here); the ``XRayDataset_path`` dataset is created
    without transforms.

    Args:
        debug: Enables debug mode when it is the boolean ``True`` or the
            string ``"True"`` (the string form is kept for backward
            compatibility). In debug mode every dataset is reduced to its
            first 10% of samples so a full pass finishes quickly.

    Returns:
        list: ``[train_loader, valid_loader, original_loader]``.
    """
    # Accept a real boolean as well as the legacy string flag; previously
    # ``debug=True`` was silently ignored.
    debug_enabled = debug is True or debug == "True"

    # dataset load (no augmentation / resize for the gray datasets)
    tf = None
    train_dataset = XRayDataset_gray(is_train=True, transforms=tf)
    valid_dataset = XRayDataset_gray(is_train=False, transforms=tf)
    original_dataset = XRayDataset_path(is_train=False)

    if debug_enabled:
        # Keep the FIRST 10% of the training samples (deterministic, not random).
        train_subset_size = int(len(train_dataset) * 0.1)
        train_subset_indices = range(len(train_dataset))[:train_subset_size]
        train_dataset = Subset(train_dataset, train_subset_indices)

        # Same for validation: the first 10% of the samples.
        valid_subset_size = int(len(valid_dataset) * 0.1)
        valid_subset_indices = range(len(valid_dataset))[:valid_subset_size]
        print(valid_subset_indices)
        valid_dataset = Subset(valid_dataset, valid_subset_indices)

        # The path dataset reuses the validation indices.
        # NOTE(review): assumes XRayDataset_path(is_train=False) lists the
        # same samples in the same order as the validation dataset - confirm.
        original_dataset = Subset(original_dataset, valid_subset_indices)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8,
        drop_last=True,
    )

    # Validation and path loaders are unshuffled with identical batch size,
    # so their batches are produced in the same order.
    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )

    original_loader = DataLoader(
        dataset=original_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )

    return [train_loader, valid_loader, original_loader]

# Visualizing images

In [37]:
def validation(epoch, model, gray_model, data_loader, gray_loader, thr=0.5):
    """Compare Dice scores of the main model with and without gray-model filtering.

    Both loaders are iterated in lockstep (iteration stops at the shorter
    one). For each batch the main model's thresholded prediction is scored
    against the masks twice: as-is, and after a logical AND with the gray
    model's thresholded prediction.

    Args:
        epoch: Number shown in the start message only.
        model: Main segmentation model; moved to CUDA and put in eval mode.
        gray_model: Filtering model; moved to CUDA and put in eval mode.
        data_loader: Yields ``(images, masks)`` batches for ``model``.
        gray_loader: Yields 2-tuples whose first item is the input batch for
            ``gray_model``; the second item is ignored.
        thr: Threshold applied to the sigmoid of both models' outputs.

    Returns:
        list: ``[dices, filtered_dices]`` - the per-batch ``dice_coef``
        results concatenated along dimension 0 (one row per sample).
    """
    print(f'Start validation #{epoch:2d}')
    model = model.cuda()
    model.eval()

    gray_model = gray_model.cuda()
    gray_model.eval()

    dices = []
    filtered_dices = []
    with torch.no_grad():
        paired_batches = zip(data_loader, gray_loader)
        for (images, masks), (gray_images, _unused) in tqdm(paired_batches, total=len(data_loader)):
            images, masks = images.cuda(), masks.cuda()
            gray_images = gray_images.cuda()

            outputs = model(images)
            gray_outputs = gray_model(gray_images)

            mask_h, mask_w = masks.size(-2), masks.size(-1)

            # Bring the main logits to the mask resolution.
            if outputs.size(-2) != mask_h or outputs.size(-1) != mask_w:
                outputs = F.interpolate(outputs, size=(mask_h, mask_w), mode="bilinear")

            # Bring the gray logits to the mask resolution too; without this
            # the logical AND below fails whenever the gray model works at a
            # different resolution than the masks.
            # NOTE(review): assumes gray_outputs is a 4-D (N, C, H, W) float
            # tensor like ``outputs`` - confirm.
            if gray_outputs.size(-2) != mask_h or gray_outputs.size(-1) != mask_w:
                gray_outputs = F.interpolate(gray_outputs, size=(mask_h, mask_w), mode="bilinear")

            # Logits -> probabilities -> boolean masks.
            outputs = torch.sigmoid(outputs) > thr
            gray_outputs = torch.sigmoid(gray_outputs) > thr

            # Keep only the main-model pixels that the gray model also marks.
            filtered_outputs = torch.logical_and(outputs, gray_outputs)

            # Score on the CPU.
            masks = masks.detach().cpu()
            dices.append(dice_coef(masks, outputs.detach().cpu()))
            filtered_dices.append(dice_coef(masks, filtered_outputs.detach().cpu()))

    dices = torch.cat(dices, 0)
    filtered_dices = torch.cat(filtered_dices, 0)

    return [dices, filtered_dices]

In [13]:
# Load the two checkpoints that are later passed to validation(), which
# calls .cuda() and .eval() on both objects.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints from a trusted source.
model_gray = torch.load("/opt/ml/input/weights/gray_FPN_gray_resnet101_True_comb_loss_100/Final_oneclass.pt")
model = torch.load("/opt/ml/input/weights/temp/Pretrained_smp_resnet101_comb_loss_tf=True_cln=True_e=100_sd=up.pt")

In [69]:
# Fix all RNG seeds before the loaders are created.
set_seed(21)
# Build both loader triples in debug mode (first 10% of each dataset).
# The flag is the string "True", which is what both factory functions compare against.
train_loader, valid_loader, original_loader = make_dataset(debug="True")
train_loader_gray, valid_loader_gray, original_loader_gray = make_graydataset(debug="True")

range(0, 16)
range(0, 16)


In [70]:
# Due to memory constraints, only half of the images (40) are evaluated.
# NOTE(review): the loaders above are built with debug="True" (first 10% of
# the validation set), so the "40" figure may be stale - confirm.
dices, filtered_dices = validation(1, model, model_gray, valid_loader, valid_loader_gray)

Start validation # 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=8.0), HTML(value='')))




In [73]:
# Average over the sample axis to get one Dice score per class.
# The names are rebound in place, so the dimensionality guard keeps this cell
# idempotent: re-running it would otherwise average the per-class vector
# again into a single scalar and break the per-class loop that follows.
if dices.dim() > 1:
    dices = dices.mean(dim=0)
if filtered_dices.dim() > 1:
    filtered_dices = filtered_dices.mean(dim=0)

In [78]:
# Print one line per class: name, plain Dice, filtered Dice, and which is larger.
# Ties are reported as "filtered is big".
for class_idx in range(len(dices)):
    plain_score = dices[class_idx]
    filtered_score = filtered_dices[class_idx]
    verdict = "dice is big" if plain_score > filtered_score else "filtered is big"
    print(f"{CLASSES[class_idx]}, {plain_score}, {filtered_score}", verdict, sep=", ")

finger-1, 0.9713789224624634, 0.9739236831665039, filtered is big
finger-2, 0.9838693737983704, 0.9859314560890198, filtered is big
finger-3, 0.9880481958389282, 0.9891602993011475, filtered is big
finger-4, 0.9741648435592651, 0.9777947664260864, filtered is big
finger-5, 0.9740586876869202, 0.9769718647003174, filtered is big
finger-6, 0.9849164485931396, 0.9866048097610474, filtered is big
finger-7, 0.9827626943588257, 0.9841089248657227, filtered is big
finger-8, 0.9767773151397705, 0.9790391325950623, filtered is big
finger-9, 0.9747393131256104, 0.9768016338348389, filtered is big
finger-10, 0.9866336584091187, 0.9879833459854126, filtered is big
finger-11, 0.9775463342666626, 0.9793636798858643, filtered is big
finger-12, 0.9765084385871887, 0.9782482385635376, filtered is big
finger-13, 0.9775259494781494, 0.9790356159210205, filtered is big
finger-14, 0.9837022423744202, 0.9858810901641846, filtered is big
finger-15, 0.9805372953414917, 0.9821506142616272, filtered is big
fing