In [1]:
import os
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from tqdm import tqdm
import pandas as pd
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Instantiate the classifier architecture (its weights are restored from a checkpoint separately).
from models.resnet_skin import network

num_classes = 7  # lesion classes; matches the 7 HAMtrain class folders
model = network(num_classes=num_classes, name='resnet50')

In [11]:
# Place the network on the selected GPU and restore the trained weights.
gpu = 0
model = model.cuda(gpu)
model_ckpt = './checkpoint/HAM_CE_None_0_cosine/ckpt.pth.tar'
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(model_ckpt, map_location=f'cuda:{gpu}')
# Left as the last expression so the notebook shows the key-matching report.
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [12]:
# Prepare the HAM test split with evaluation-time preprocessing:
# resize the shorter side to 256, convert to a tensor, normalise with the
# standard ImageNet mean/std.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
ResizeTest = transforms.Resize(256)

transform_val = transforms.Compose([
        ResizeTest,
        transforms.ToTensor(),
        normalize,
    ])

data_folder = './skinlesiondatasets/SkinLesionDatasets/'
# In this cell the training folder is only needed for its class -> index mapping.
train_dataset = datasets.ImageFolder(root=os.path.join(data_folder, 'HAMtrain'), transform=transform_val)
val_dataset = datasets.ImageFolder(root=os.path.join(data_folder, 'HAMtest'), transform=transform_val)

# ImageFolder bakes each sample's label in when it scans the folder, so assigning
# `class_to_idx` afterwards does not relabel anything. Translate every label by
# class name into the training indices. This is a no-op when both splits have
# the same class folders; a class unknown to training raises KeyError instead of
# being silently mislabelled.
old_to_train = {old_idx: train_dataset.class_to_idx[name]
                for name, old_idx in val_dataset.class_to_idx.items()}
val_dataset.samples = [(path, old_to_train[label]) for path, label in val_dataset.samples]
val_dataset.imgs = val_dataset.samples
val_dataset.targets = [label for _, label in val_dataset.samples]
val_dataset.classes = train_dataset.classes
val_dataset.class_to_idx = train_dataset.class_to_idx

# shuffle=False keeps batches in dataset order, so output rows line up with val_dataset.samples.
val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False,
        num_workers=0, pin_memory=True)

In [13]:
# Column headers for the per-class softmax scores, raw logits and one-hot targets.
cols_names_classes, cols_names_logits, cols_names_targets = (
    [f'{prefix}_{k}' for k in range(num_classes)]
    for prefix in ('class', 'logit', 'target')
)

In [14]:
# Score every sample from `val_loader` and save softmax probabilities, raw logits
# and one-hot targets side by side (one row per sample, in loader order).
cifarresultsdir = './skinresults/'
os.makedirs(cifarresultsdir, exist_ok=True)  # DataFrame.to_csv does not create directories
model.eval()
logits = []
preds = []
targets = []
# Pure inference: no_grad skips building an autograd graph on every forward pass.
with torch.no_grad():
    for images, target in tqdm(val_loader):
        images = images.cuda(gpu, non_blocking=True)
        logits_test = model(images)
        preds_test = F.softmax(logits_test, dim=1)
        targets_test = F.one_hot(target, num_classes=num_classes)
        logits.append(logits_test.cpu())
        preds.append(preds_test.cpu())
        targets.append(targets_test)

logits = torch.cat(logits, dim=0)
preds = torch.cat(preds, dim=0)
targets = torch.cat(targets, dim=0)

df = pd.DataFrame(data=preds.numpy(), columns=cols_names_classes)
df_logits = pd.DataFrame(data=logits.numpy(), columns=cols_names_logits)
df_targets = pd.DataFrame(data=targets.numpy(), columns=cols_names_targets)
df = pd.concat([df, df_logits, df_targets], axis=1)
df.to_csv(os.path.join(cifarresultsdir, 'predictions_val.csv'), index=False)

100%|████████████████████████| 3005/3005 [05:57<00:00,  8.41it/s]


In [15]:
# Build a loader over the HAM *training* split, using the same eval-time
# preprocessing and no shuffling, so the next inference pass scores training images.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
ResizeTest = transforms.Resize(256)

transform_val = transforms.Compose([
        ResizeTest,
        transforms.ToTensor(),
        normalize,
    ])

data_folder = './skinlesiondatasets/SkinLesionDatasets/'
train_dataset = datasets.ImageFolder(root=os.path.join(data_folder, 'HAMtrain'), transform=transform_val)

# NOTE: this loader yields TRAINING samples. The name `val_loader` is kept only
# so that downstream cells iterating `val_loader` keep working unchanged.
val_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=False,
        num_workers=0, pin_memory=True)

In [16]:
# Re-declared so this cell is self-contained: headers for the softmax scores,
# logits and one-hot targets, numbered 0 .. num_classes-1.
cols_names_classes = [f'class_{k}' for k in range(num_classes)]
cols_names_logits = [f'logit_{k}' for k in range(num_classes)]
cols_names_targets = [f'target_{k}' for k in range(num_classes)]

In [17]:
# Score every sample from `val_loader` and save softmax probabilities, raw logits
# and one-hot targets side by side (one row per sample, in loader order).
cifarresultsdir = './skinresults/'
os.makedirs(cifarresultsdir, exist_ok=True)  # DataFrame.to_csv does not create directories
model.eval()
logits = []
preds = []
targets = []
# Pure inference: no_grad skips building an autograd graph on every forward pass.
with torch.no_grad():
    for images, target in tqdm(val_loader):
        images = images.cuda(gpu, non_blocking=True)
        logits_test = model(images)
        preds_test = F.softmax(logits_test, dim=1)
        targets_test = F.one_hot(target, num_classes=num_classes)
        logits.append(logits_test.cpu())
        preds.append(preds_test.cpu())
        targets.append(targets_test)

logits = torch.cat(logits, dim=0)
preds = torch.cat(preds, dim=0)
targets = torch.cat(targets, dim=0)

df = pd.DataFrame(data=preds.numpy(), columns=cols_names_classes)
df_logits = pd.DataFrame(data=logits.numpy(), columns=cols_names_logits)
df_targets = pd.DataFrame(data=targets.numpy(), columns=cols_names_targets)
df = pd.concat([df, df_logits, df_targets], axis=1)
df.to_csv(os.path.join(cifarresultsdir, 'predictions_train.csv'), index=False)

100%|████████████████████████| 7010/7010 [07:48<00:00, 14.95it/s]


In [None]:
# Score the model on each external skin-lesion dataset and write one CSV per dataset.
# `data_folder` is set explicitly because the corrupted-data cell rebinds it; this
# keeps the cell correct regardless of execution order.
data_folder = './skinlesiondatasets/SkinLesionDatasets/'
datasetlist = ['BCN', 'D7P', 'MSK', 'PH2', 'SON', 'UDA', 'VIE']
model.eval()
cifarresultsdir = './skinresults/'
os.makedirs(cifarresultsdir, exist_ok=True)  # DataFrame.to_csv does not create directories
for datasetn in datasetlist:

    print('Processing dataset ' + datasetn)

    csvsavename = 'predictions_test_' + datasetn +  '.csv'

    val_dataset = datasets.ImageFolder(root=os.path.join(data_folder, datasetn), transform=transform_val)

    # ImageFolder numbers only the class folders it finds (UDA, for example, has
    # 5 of the 7 training classes, so its indices differ from the training ones).
    # Reassigning `class_to_idx` does not touch the labels already stored in
    # `samples`, so translate every label by class name. A class unknown to
    # training raises KeyError.
    old_to_train = {old_idx: train_dataset.class_to_idx[name]
                    for name, old_idx in val_dataset.class_to_idx.items()}
    val_dataset.samples = [(path, old_to_train[label]) for path, label in val_dataset.samples]
    val_dataset.imgs = val_dataset.samples
    val_dataset.targets = [label for _, label in val_dataset.samples]
    val_dataset.classes = train_dataset.classes
    val_dataset.class_to_idx = train_dataset.class_to_idx

    val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=1, shuffle=False,
            num_workers=0, pin_memory=True)

    logits = []
    preds = []
    targets = []
    # Pure inference: no_grad skips building an autograd graph on every forward pass.
    with torch.no_grad():
        for images, target in tqdm(val_loader):
            images = images.cuda(gpu, non_blocking=True)
            logits_test = model(images)
            preds_test = F.softmax(logits_test, dim=1)
            targets_test = F.one_hot(target, num_classes=num_classes)
            logits.append(logits_test.cpu())
            preds.append(preds_test.cpu())
            targets.append(targets_test)

    logits = torch.cat(logits, dim=0)
    preds = torch.cat(preds, dim=0)
    targets = torch.cat(targets, dim=0)

    df = pd.DataFrame(data=preds.numpy(), columns=cols_names_classes)
    df_logits = pd.DataFrame(data=logits.numpy(), columns=cols_names_logits)
    df_targets = pd.DataFrame(data=targets.numpy(), columns=cols_names_targets)
    df = pd.concat([df, df_logits, df_targets], axis=1)
    df.to_csv(os.path.join(cifarresultsdir, csvsavename), index=False)
    

In [99]:
# Diagnostic: compare the training class indices with the indices ImageFolder
# assigns on its own to one external set (datasetlist[5] == 'UDA').
# Uses whatever `data_folder` currently holds -- assumed to be the clean-dataset root.
datasetlist = ['BCN', 'D7P', 'MSK', 'PH2', 'SON', 'UDA', 'VIE']
val_dataset = datasets.ImageFolder(
    root=os.path.join(data_folder, datasetlist[5]),
    transform=transform_val,
)
print(train_dataset.class_to_idx, val_dataset.class_to_idx, sep='\n')

{'actinic keratosis': 0, 'basal cell carcinoma': 1, 'dermatofibroma': 2, 'melanoma': 3, 'nevus': 4, 'pigmented benign keratosis': 5, 'vascular lesion': 6}
{'basal cell carcinoma': 0, 'dermatofibroma': 1, 'melanoma': 2, 'nevus': 3, 'pigmented benign keratosis': 4}


In [94]:
# corrupted datasets here
cifarresultsdir = './skinresults/'
data_folder = './skinlesiondatasets/SkinLeisionDatasets_C/'

datasetlist = ['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 
               'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 
               'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise', 
               'snow', 'spatter', 'speckle_noise', 'zoom_blur']
severitylist = [1, 2, 3, 4, 5]

model.eval()
for datasetn in datasetlist:
    
    for severityindex in severitylist:
    
        print('Processing dataset ' + datasetn + ' severity ' + str(severityindex))

        csvsavename = 'predictions_test_' + datasetn + '_' + str(severityindex) + '.csv'

        val_dataset = datasets.ImageFolder(root=os.path.join(data_folder, datasetn, str(severityindex)), transform=transform_val)
        
        val_dataset.class_to_idx = train_dataset.class_to_idx

        val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=1, shuffle=False,
                num_workers=0, pin_memory=True)

        logits = []
        preds = []
        targets = []
        for i, (input, target) in enumerate(tqdm(val_loader)):
            input = input.cuda(gpu, non_blocking = True)
            logits_test = model(input)
            preds_test = F.softmax(logits_test, dim = 1)
            targets_test = F.one_hot(target, num_classes = num_classes)
            logits.append(logits_test.cpu().detach())
            preds.append(preds_test.cpu().detach())
            targets.append(targets_test)

        logits = torch.cat(logits, dim=0)
        preds = torch.cat(preds, dim=0)
        targets = torch.cat(targets, dim=0)

        df = pd.DataFrame(data=preds.numpy(), columns=cols_names_classes)
        df_logits = pd.DataFrame(data=logits.numpy(), columns=cols_names_logits)
        df_targets = pd.DataFrame(data=targets.numpy(), columns=cols_names_targets)
        df = pd.concat([df, df_logits, df_targets], axis=1)
        df.to_csv(os.path.join(cifarresultsdir, csvsavename), index=False)
    

Processing dataset brightness severity 1


100%|███████████████| 3005/3005 [00:41<00:00, 72.97it/s]


Processing dataset brightness severity 2


100%|███████████████| 3005/3005 [02:24<00:00, 20.81it/s]


Processing dataset brightness severity 3


100%|███████████████| 3005/3005 [07:02<00:00,  7.11it/s]


Processing dataset brightness severity 4


100%|███████████████| 3005/3005 [05:56<00:00,  8.44it/s]


Processing dataset brightness severity 5


100%|███████████████| 3005/3005 [06:02<00:00,  8.29it/s]


Processing dataset contrast severity 1


100%|███████████████| 3005/3005 [05:30<00:00,  9.10it/s]


Processing dataset contrast severity 2


100%|███████████████| 3005/3005 [04:37<00:00, 10.84it/s]


Processing dataset contrast severity 3


100%|███████████████| 3005/3005 [04:40<00:00, 10.71it/s]


Processing dataset contrast severity 4


100%|███████████████| 3005/3005 [04:30<00:00, 11.11it/s]


Processing dataset contrast severity 5


100%|███████████████| 3005/3005 [03:39<00:00, 13.68it/s]


Processing dataset defocus_blur severity 1


100%|███████████████| 3005/3005 [05:49<00:00,  8.60it/s]


Processing dataset defocus_blur severity 2


100%|███████████████| 3005/3005 [06:49<00:00,  7.33it/s]


Processing dataset defocus_blur severity 3


100%|███████████████| 3005/3005 [06:51<00:00,  7.30it/s]


Processing dataset defocus_blur severity 4


100%|███████████████| 3005/3005 [05:38<00:00,  8.87it/s]


Processing dataset defocus_blur severity 5


100%|███████████████| 3005/3005 [05:29<00:00,  9.11it/s]


Processing dataset elastic_transform severity 1


100%|███████████████| 3005/3005 [07:48<00:00,  6.42it/s]


Processing dataset elastic_transform severity 2


100%|███████████████| 3005/3005 [06:28<00:00,  7.73it/s]


Processing dataset elastic_transform severity 3


100%|███████████████| 3005/3005 [05:47<00:00,  8.64it/s]


Processing dataset elastic_transform severity 4


100%|███████████████| 3005/3005 [05:35<00:00,  8.97it/s]


Processing dataset elastic_transform severity 5


100%|███████████████| 3005/3005 [05:41<00:00,  8.79it/s]


Processing dataset fog severity 1


100%|███████████████| 3005/3005 [05:23<00:00,  9.30it/s]


Processing dataset fog severity 2


100%|███████████████| 3005/3005 [05:44<00:00,  8.73it/s]


Processing dataset fog severity 3


100%|███████████████| 3005/3005 [05:20<00:00,  9.37it/s]


Processing dataset fog severity 4


100%|███████████████| 3005/3005 [05:01<00:00,  9.96it/s]


Processing dataset fog severity 5


100%|███████████████| 3005/3005 [05:16<00:00,  9.48it/s]


Processing dataset frost severity 1


100%|███████████████| 3005/3005 [06:16<00:00,  7.98it/s]


Processing dataset frost severity 2


100%|███████████████| 3005/3005 [06:08<00:00,  8.16it/s]


Processing dataset frost severity 3


100%|███████████████| 3005/3005 [06:18<00:00,  7.94it/s]


Processing dataset frost severity 4


100%|███████████████| 3005/3005 [06:03<00:00,  8.27it/s]


Processing dataset frost severity 5


100%|███████████████| 3005/3005 [06:25<00:00,  7.79it/s]


Processing dataset gaussian_blur severity 1


100%|███████████████| 3005/3005 [05:34<00:00,  8.98it/s]


Processing dataset gaussian_blur severity 2


100%|███████████████| 3005/3005 [04:51<00:00, 10.29it/s]


Processing dataset gaussian_blur severity 3


100%|███████████████| 3005/3005 [04:42<00:00, 10.64it/s]


Processing dataset gaussian_blur severity 4


100%|███████████████| 3005/3005 [04:56<00:00, 10.15it/s]


Processing dataset gaussian_blur severity 5


100%|███████████████| 3005/3005 [04:19<00:00, 11.58it/s]


Processing dataset gaussian_noise severity 1


100%|███████████████| 3005/3005 [06:04<00:00,  8.25it/s]


Processing dataset gaussian_noise severity 2


100%|███████████████| 3005/3005 [06:39<00:00,  7.52it/s]


Processing dataset gaussian_noise severity 3


100%|███████████████| 3005/3005 [09:17<00:00,  5.39it/s]


Processing dataset gaussian_noise severity 4


100%|███████████████| 3005/3005 [08:04<00:00,  6.20it/s]


Processing dataset gaussian_noise severity 5


100%|███████████████| 3005/3005 [08:59<00:00,  5.57it/s]


Processing dataset glass_blur severity 1


100%|███████████████| 3005/3005 [04:46<00:00, 10.48it/s]


Processing dataset glass_blur severity 2


100%|███████████████| 3005/3005 [04:54<00:00, 10.20it/s]


Processing dataset glass_blur severity 3


100%|███████████████| 3005/3005 [03:32<00:00, 14.12it/s]


Processing dataset glass_blur severity 4


100%|███████████████| 3005/3005 [03:01<00:00, 16.59it/s]


Processing dataset glass_blur severity 5


100%|███████████████| 3005/3005 [02:59<00:00, 16.71it/s]


Processing dataset impulse_noise severity 1


100%|███████████████| 3005/3005 [02:59<00:00, 16.73it/s]


Processing dataset impulse_noise severity 2


100%|███████████████| 3005/3005 [04:26<00:00, 11.27it/s]


Processing dataset impulse_noise severity 3


100%|███████████████| 3005/3005 [04:34<00:00, 10.96it/s]


Processing dataset impulse_noise severity 4


100%|███████████████| 3005/3005 [04:37<00:00, 10.83it/s]


Processing dataset impulse_noise severity 5


100%|███████████████| 3005/3005 [05:25<00:00,  9.24it/s]


Processing dataset jpeg_compression severity 1


100%|███████████████| 3005/3005 [03:00<00:00, 16.64it/s]


Processing dataset jpeg_compression severity 2


100%|███████████████| 3005/3005 [02:19<00:00, 21.58it/s]


Processing dataset jpeg_compression severity 3


100%|███████████████| 3005/3005 [01:41<00:00, 29.50it/s]


Processing dataset jpeg_compression severity 4


100%|███████████████| 3005/3005 [01:45<00:00, 28.47it/s]


Processing dataset jpeg_compression severity 5


100%|███████████████| 3005/3005 [01:43<00:00, 28.91it/s]


Processing dataset motion_blur severity 1


100%|███████████████| 3005/3005 [02:06<00:00, 23.83it/s]


Processing dataset motion_blur severity 2


100%|███████████████| 3005/3005 [01:57<00:00, 25.51it/s]


Processing dataset motion_blur severity 3


100%|███████████████| 3005/3005 [01:53<00:00, 26.51it/s]


Processing dataset motion_blur severity 4


100%|███████████████| 3005/3005 [01:41<00:00, 29.51it/s]


Processing dataset motion_blur severity 5


100%|███████████████| 3005/3005 [01:53<00:00, 26.55it/s]


Processing dataset pixelate severity 1


100%|███████████████| 3005/3005 [02:02<00:00, 24.59it/s]


Processing dataset pixelate severity 2


100%|███████████████| 3005/3005 [01:55<00:00, 26.10it/s]


Processing dataset pixelate severity 3


100%|███████████████| 3005/3005 [01:55<00:00, 26.13it/s]


Processing dataset pixelate severity 4


100%|███████████████| 3005/3005 [01:55<00:00, 25.95it/s]


Processing dataset pixelate severity 5


100%|███████████████| 3005/3005 [01:56<00:00, 25.86it/s]


Processing dataset saturate severity 1


100%|███████████████| 3005/3005 [02:16<00:00, 21.94it/s]


Processing dataset saturate severity 2


100%|███████████████| 3005/3005 [02:05<00:00, 23.86it/s]


Processing dataset saturate severity 3


100%|███████████████| 3005/3005 [02:03<00:00, 24.34it/s]


Processing dataset saturate severity 4


100%|███████████████| 3005/3005 [01:57<00:00, 25.55it/s]


Processing dataset saturate severity 5


100%|███████████████| 3005/3005 [02:02<00:00, 24.43it/s]


Processing dataset shot_noise severity 1


100%|███████████████| 3005/3005 [02:00<00:00, 24.87it/s]


Processing dataset shot_noise severity 2


100%|███████████████| 3005/3005 [02:28<00:00, 20.20it/s]


Processing dataset shot_noise severity 3


100%|███████████████| 3005/3005 [02:47<00:00, 17.93it/s]


Processing dataset shot_noise severity 4


100%|███████████████| 3005/3005 [02:53<00:00, 17.34it/s]


Processing dataset shot_noise severity 5


100%|███████████████| 3005/3005 [02:37<00:00, 19.05it/s]


Processing dataset snow severity 1


100%|███████████████| 3005/3005 [02:06<00:00, 23.83it/s]


Processing dataset snow severity 2


100%|███████████████| 3005/3005 [02:11<00:00, 22.85it/s]


Processing dataset snow severity 3


100%|███████████████| 3005/3005 [02:12<00:00, 22.70it/s]


Processing dataset snow severity 4


100%|███████████████| 3005/3005 [02:10<00:00, 23.08it/s]


Processing dataset snow severity 5


100%|███████████████| 3005/3005 [02:00<00:00, 25.02it/s]


Processing dataset spatter severity 1


100%|███████████████| 3005/3005 [01:49<00:00, 27.37it/s]


Processing dataset spatter severity 2


100%|███████████████| 3005/3005 [01:55<00:00, 26.13it/s]


Processing dataset spatter severity 3


100%|███████████████| 3005/3005 [01:52<00:00, 26.62it/s]


Processing dataset spatter severity 4


100%|███████████████| 3005/3005 [02:03<00:00, 24.34it/s]


Processing dataset spatter severity 5


100%|███████████████| 3005/3005 [02:12<00:00, 22.62it/s]


Processing dataset speckle_noise severity 1


100%|███████████████| 3005/3005 [01:57<00:00, 25.58it/s]


Processing dataset speckle_noise severity 2


100%|███████████████| 3005/3005 [02:08<00:00, 23.47it/s]


Processing dataset speckle_noise severity 3


100%|███████████████| 3005/3005 [02:48<00:00, 17.78it/s]


Processing dataset speckle_noise severity 4


100%|███████████████| 3005/3005 [02:33<00:00, 19.57it/s]


Processing dataset speckle_noise severity 5


100%|███████████████| 3005/3005 [02:55<00:00, 17.13it/s]


Processing dataset zoom_blur severity 1


100%|███████████████| 3005/3005 [01:47<00:00, 27.99it/s]


Processing dataset zoom_blur severity 2


100%|███████████████| 3005/3005 [01:49<00:00, 27.44it/s]


Processing dataset zoom_blur severity 3


100%|███████████████| 3005/3005 [01:38<00:00, 30.54it/s]


Processing dataset zoom_blur severity 4


100%|███████████████| 3005/3005 [01:38<00:00, 30.52it/s]


Processing dataset zoom_blur severity 5


100%|███████████████| 3005/3005 [01:42<00:00, 29.22it/s]
