In [4]:
import re
from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as T

import monai
from monai.networks.nets.unet import UNet
from monai.losses import DiceLoss

from model import SelectionNet
from train import *

from scipy.stats import rankdata

root_dir = "./data/malignant"
regex = re.compile(r"\d+")

def path_list_ranking(path_list):
    return sorted(path_list, key=lambda i: int(regex.findall(i)[0]))

def data_path_list(root_dir):
    path = Path(root_dir)
    all_paths = [str(image_path) for image_path in path.glob("*.png")]
    imgs_list = [img_path for img_path in all_paths if "_mask" not in img_path]
    labels_list = [label_path for label_path in all_paths if "_mask" in label_path]
    return path_list_ranking(imgs_list), path_list_ranking(labels_list)


imgs_list, labels_list = data_path_list(root_dir)


meta_train_list, meta_val_list, meta_train_label_list, meta_val_label_list = train_test_split(imgs_list, labels_list, test_size=120, train_size=90, random_state=25)
meta_val_list, holdout_list, meta_val_label_list, holdout_label_list = train_test_split(meta_val_list, meta_val_label_list, test_size=30, train_size=90, random_state=25)

print(len(holdout_list))

class Dataset(torch.utils.data.Dataset):
    
    def __init__(self, images_list :list, 
                 labels_list :list, 
                 img_transform :T=None ,
                 label_transform :T=None):
        super().__init__()
        
        assert len(images_list) == len(labels_list) # images labels have numbers doesn't match
        self.images_list = images_list
        self.labels_list = labels_list
        self.img_transform = img_transform
        self.label_transform = label_transform
        
        
    def __len__(self):
        return len(self.images_list)
    
    def __getitem__(self, index):
        image = Image.open(self.images_list[index])
        label = Image.open(self.labels_list[index])
        
        return self.img_transform(image), self.label_transform(label)

    
    
basic_transform = T.Compose([
    T.Resize([324,324]),
#     T.Grayscale(1),
#     T.GaussianBlur(kernel_size=(3, 7), sigma=(0.1, 5)),
    T.ToTensor(),
#     T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

basic_label_transform = T.Compose([
    T.Resize([324,324]),
    T.ToTensor()
])

Holdout_test_dataset = Dataset(holdout_list, holdout_label_list, basic_transform, basic_label_transform)

test_loader = DataLoader(Holdout_test_dataset, batch_size=10, shuffle=True, drop_last=True)


device = 'cuda:2'
nick_name = 'justmmd_true'
Seg_model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=(64, 128, 256, 512, 1024),
    strides=(2, 2, 1, 1),
    num_res_units=3,
    dropout = 0.2
)
Seg_model.load_state_dict(torch.load("/home/xiangcen/ipmi2088/before_meeting/Seg_modelfix_v2.pt", map_location=device))
Seg_model.to(device)

Sel_model = SelectionNet(
    DS_dim_list=[3, 64, 256, 512, 1024, 2048], 
    num_resblock=2, 
    transformer_input_dim=2048, 
    num_head=8,
    dropout=0.,
    num_transformer=6
)

Sel_model.load_state_dict(torch.load("/home/xiangcen/ipmi2088/sel_last_exp.pt", map_location=device))
Sel_model.to(device)


def metric_sel_net(
        seg_model, 
        sel_model, 
        holdout_loader,
        device='cpu'
    ):
    # remember the loader should drop the last batch to prevent differenct sequence number in the last batch
    seg_model.eval()
    seg_model.to(device)
    sel_model.eval()
    sel_model.to(device)
    step = 0.
    performance_all = 0.
    selection_list = []
    performance_list = []
    for img, label in holdout_loader:
        img, label = img.to(device), label.to(device)
        # forward pass and calculate the selection
        
        
        with torch.no_grad():
            selection_output = sel_model(img)
            # calculate the performance of the sequence
            o = seg_model(img)
            performance = dice_metric(o, label, sigmoid=True, mean=False)

            # print(performance, selection_output)

            selection_list.append(selection_output)
            performance_list.append(performance)

        
        step += 1
        performance_all += performance.mean().item()
        


    return performance_all/step, selection_list, performance_list
# _, sel, per = metric_sel_net(Seg_model, Sel_model, test_loader)
# _, sel, per

30


In [5]:

list = []
for j in range(10):
    _, sel, per = metric_sel_net(Seg_model, Sel_model, test_loader, device=device)
    for i in range(3):
        batch_index = sel[i]
        performance = per[i]
        print(batch_index)
        i_max = torch.argmax(batch_index).item()
        i_min = torch.argmin(batch_index).item()

        print(i_max, i_min)

        mean_performance = performance.mean()
        selected_performance = (performance[i_max] + performance[i_min]) / 2
        random_performance = performance[torch.randint(0, 10, (2, ))].mean()


        selected_difference, random_difference = (torch.abs(mean_performance - selected_performance), torch.abs(mean_performance - random_performance))

        list.append(torch.tensor([selected_difference, random_difference]))
result = torch.stack(list)

tensor([[0.7544],
        [1.0472],
        [0.2988],
        [0.4311],
        [0.3873],
        [0.7385],
        [0.3653],
        [0.3786],
        [0.7500],
        [4.5321]], device='cuda:2')
9 2
tensor([[2.6252],
        [0.2175],
        [0.1788],
        [3.7890],
        [1.4999],
        [0.2564],
        [0.1762],
        [0.8529],
        [1.3368],
        [0.2880]], device='cuda:2')
3 6
tensor([[0.3313],
        [0.2932],
        [0.1885],
        [0.1366],
        [0.7887],
        [4.8632],
        [3.2193],
        [0.3008],
        [0.2197],
        [0.1521]], device='cuda:2')
5 3
tensor([[3.6423],
        [0.2032],
        [1.6153],
        [0.1969],
        [0.1165],
        [0.2479],
        [2.2339],
        [0.1929],
        [0.1506],
        [3.0581]], device='cuda:2')
0 4
tensor([[0.3106],
        [2.4083],
        [0.6018],
        [0.4594],
        [0.5158],
        [3.3390],
        [0.6873],
        [0.3083],
        [0.4216],
        [0.6776]], device='cud

In [6]:
torch.mean(result, dim=0), torch.std(result, dim=0)

(tensor([0.0563, 0.0872]), tensor([0.0383, 0.0541]))