In [1]:
import os
import torch

import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import time
import copy
import numpy as np
import random
import pydicom
import nibabel as nib
import matplotlib.pyplot as plt
import pandas as pd
import datetime

In [2]:
from torch.optim import lr_scheduler
from torchvision import datasets
from efficientnet_pytorch import EfficientNet
from sklearn.metrics import roc_auc_score, confusion_matrix, accuracy_score
from sklearn.metrics import classification_report, roc_curve
import logging

from pathlib import Path
from PIL import Image
import SimpleITK as sitk
from numpy import asarray
from math import sqrt
from scipy.special import ndtri
from numpy import argmax

In [3]:
class Medical_dataset(torch.utils.data.Dataset):

    def __init__(self, first_layer_files, second_layer_files, third_layer_files, labels, transforms):

        self.first_layer_files = first_layer_files
        self.second_layer_files = second_layer_files
        self.third_layer_files = third_layer_files
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return min(len(self.first_layer_files), len(self.second_layer_files), len(self.third_layer_files))
    
    def __getitem__(self, idx):

        # first layer will be dcm files
        img_path_dcm = self.first_layer_files[idx]
        ds_A = pydicom.read_file(img_path_dcm)
        img_A = np.array(ds_A.pixel_array, dtype=np.float32)

        min, max = np.min(img_A), np.max(img_A)
        img_A = 255 - ((img_A - min) / (max - min)) * 255
        PIL_image_A = Image.fromarray(img_A.astype(np.uint8))

        # second layer will be nii files
        img_path = self.second_layer_files[idx]
        ds_B = nib.load(img_path)
        img_B = np.array(ds_B.get_fdata(), dtype=np.float32)

        if len(img_B.shape) == 4:
            img_B = img_B[:,:,0,0]
        else:
            min, max = np.min(img_B), np.max(img_B)
            img_B = ((img_B - min) / (max - min)) * 255
            PIL_image_B = Image.fromarray(np.transpose(img_B).astype(np.uint8))

        # third layer will be nii files
        img_path = self.third_layer_files[idx]
        ds_C = nib.load(img_path)
        img_C = np.array(ds_C.get_fdata(), dtype=np.float32)

        if len(img_C.shape) == 4:
            img_C = img_C[:,:,0,0]
        else:
            min, max = np.min(img_C), np.max(img_C)
            img_C = ((img_C - min) / (max - min)) * 255
            PIL_image_C = Image.fromarray(np.transpose(img_C).astype(np.uint8))

        return self.transforms(PIL_image_A), self.transforms(PIL_image_B), self.transforms(PIL_image_C), self.labels[idx]

In [4]:
resol = 1024

train_transforms = transforms.Compose([
    transforms.Resize((resol, resol)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

valid_transforms = transforms.Compose([
    transforms.Resize((resol, resol)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

In [5]:
# make dataset for train and valid based on class above
first_layer_files_list = []
fl_dir = '/data4/vindr-cxr/train/'
files = sorted(os.listdir(fl_dir))
for file in files:
    first_layer_files_list.append(fl_dir + file)

second_layer_files_list = []
sl_dir = '/data4/20231130_tisepx_vinbig/train_lung/'
files = sorted(os.listdir(sl_dir))
for file in files:
    second_layer_files_list.append(sl_dir + file)

third_layer_files_list = []
tl_dir = '/data4/20231130_tisepx_vinbig/train_pulmonaryLung/'
files = sorted(os.listdir(tl_dir))
for file in files:
    third_layer_files_list.append(tl_dir + file)

img_list = sorted(os.listdir(fl_dir))
id_list = []
for img in img_list:
    id_list.append(img.split('.')[0])

labels = []
metadata_dir = '/data4/vindr-cxr/train.csv'
df = pd.read_csv(metadata_dir)

for i in id_list:
    labels.append(df.loc[df['image_id'] == i, 'class_id'].iloc[0])

# count numver of values in each class
from collections import Counter
print(Counter(labels))
print(len(labels))

Counter({14: 10606, 0: 1127, 3: 878, 13: 522, 11: 496, 7: 251, 10: 241, 9: 224, 8: 221, 6: 134, 5: 111, 2: 91, 4: 46, 1: 30, 12: 22})
15000


In [6]:
# split train and valid
length = len(first_layer_files_list)
np.random.seed(555)
indices = np.arange(length)
np.random.shuffle(indices)

test_split = int(np.floor(0.2 * length))
val_split = int(np.floor(0.1 * length))
test_indices, val_indices, train_indices = indices[:test_split], indices[test_split:test_split+val_split], indices[test_split+val_split:]

train_x1 = [first_layer_files_list[i] for i in train_indices]
train_x2 = [second_layer_files_list[i] for i in train_indices]
train_x3 = [third_layer_files_list[i] for i in train_indices]
train_y = [labels[i] for i in train_indices]

val_x1 = [first_layer_files_list[i] for i in val_indices]
val_x2 = [second_layer_files_list[i] for i in val_indices]
val_x3 = [third_layer_files_list[i] for i in val_indices]
val_y = [labels[i] for i in val_indices]

test_x1 = [first_layer_files_list[i] for i in test_indices]
test_x2 = [second_layer_files_list[i] for i in test_indices]
test_x3 = [third_layer_files_list[i] for i in test_indices]
test_y = [labels[i] for i in test_indices]

print(f'num of train dataset: {len(train_x1)}, valid dataset: {len(val_x2)}, test dataset: {len(test_x3)}')

num of train dataset: 10500, valid dataset: 1500, test dataset: 3000


In [7]:
# Setting up training enviornment
from efficientnet_pytorch import EfficientNet
model_name = 'efficientnet-b5'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EfficientNet.from_pretrained(model_name,in_channels=3, num_classes=15)
# use multiple gpu
model = nn.DataParallel(model)
model.to(device)

loss_function = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), 
                         lr = 0.05,
                         momentum=0.9,
                         weight_decay=1e-4)
lmbda = lambda epoch: 0.98739
scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)

batch_size  = 32
random_seed = 555
random.seed(random_seed)
torch.manual_seed(random_seed)

Loaded pretrained weights for efficientnet-b5


<torch._C.Generator at 0x7f17b8b505f0>

In [8]:
# train and valid dataset
train_dataset = Medical_dataset(train_x1, train_x2, train_x3, train_y, train_transforms)
valid_dataset = Medical_dataset(val_x1, val_x2, val_x3, val_y, valid_transforms)
test_dataset = Medical_dataset(test_x1, test_x2, test_x3, test_y, valid_transforms)

# dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

In [29]:
# train and valid for 3 channels EfficientNet
best_metric = 100000

epoch_loss_values = []
metric_values = []
auroc_values = []
acc_values = []

for epoch in range(1, 100):
    # train
    start_time = time.time()
    epoch_loss = 0
    epoch_samples = 0
    step = 0
    model.train()
    for batch_data in train_loader:
        step += 1
        # unsquezee each channel
        x1, x2, x3, y = batch_data[0].cpu(), batch_data[1].cpu(), batch_data[2].cpu(), batch_data[3].cuda()
        inputs = torch.cat([x1, x2, x3], axis=1)

        # make y to one hot encoding
        y = F.one_hot(y, num_classes=15).cuda().float()

        optimizer.zero_grad()
        y_hat = model(inputs)
        
        loss = loss_function(y_hat, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * y.size(0)
        epoch_samples += y.size(0)
    scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f'epoch {epoch} loss is {epoch_loss:.4f}')

    # make checkpoint
    if epoch % 10 == 0:
        torch.save(model.state_dict(), '/data4/vin_model_weights/' + '_epoch_' + str(epoch) + '_checkpoint.pth')
        print('saved model for checkpoint')

    # validation
    model.eval()
    val_loss_values = []
    val_loss = 0
    with torch.no_grad():
        y_pred = torch.tensor([], dtype=torch.float32, device=device)
        y_labels = torch.tensor([], dtype=torch.float32, device=device)

        for val_data in valid_loader:
            x1, x2, x3, y = val_data[0].cpu(), val_data[1].cpu(), val_data[2].cpu(), val_data[3].cuda()
            val_images = torch.cat((x1, x2, x3), 1).float()

            val_outputs = model(val_images)
            val_labels = F.one_hot(y, num_classes=15).cuda().float()
            
            val_loss = loss_function(val_outputs, val_labels)

            y_pred = torch.cat([y_pred, torch.sigmoid(val_outputs)], dim=0)
            y_labels = torch.cat([y_labels, val_labels], dim=0)

            val_loss_values.append(val_loss.item())

        val_loss = np.sum(val_loss_values) / len(val_loss_values)
        y_pred, y_labels = y_pred.cpu(), y_labels.cpu()
        result = roc_auc_score(y_labels, y_pred)
        acc_metric = accuracy_score(y_labels, (y_pred>0.5).float())

        val_result = np.mean(val_loss_values)
        metric_values.append(val_result)
        
        auroc_values.append(result)
        acc_values.append(acc_metric)
        
        if val_result < best_metric:
            best_metric = val_result
            torch.save(model.state_dict(), '/data4/vin_model_weights/' + '_epoch_' + str(epoch) + '_best_metric_model.pth')
            print('saved new best model')

        print(f'epoch {epoch} val loss is {val_result:.4f}')
        print(f'epoch {epoch} val auroc is {result:.4f}')
        print(f'epoch {epoch} val acc is {acc_metric:.4f}')

    print(f'epoch {epoch} took {(time.time()-start_time)/60:.2f} min')

epoch 1 loss is 2.8673
saved new best model
epoch 1 val loss is 0.0906
epoch 1 val auroc is 0.8493
epoch 1 val acc is 0.6440
epoch 1 took 117.28 min




epoch 2 loss is 2.7088
saved new best model
epoch 2 val loss is 0.0870
epoch 2 val auroc is 0.8701
epoch 2 val acc is 0.6553
epoch 2 took 117.09 min




epoch 3 loss is 2.5585
epoch 3 val loss is 0.0969
epoch 3 val auroc is 0.8747
epoch 3 val acc is 0.5800
epoch 3 took 117.82 min




epoch 4 loss is 2.4403
saved new best model
epoch 4 val loss is 0.0795
epoch 4 val auroc is 0.8791
epoch 4 val acc is 0.6627
epoch 4 took 116.77 min




epoch 5 loss is 2.3352
epoch 5 val loss is 0.0801
epoch 5 val auroc is 0.8801
epoch 5 val acc is 0.6680
epoch 5 took 117.30 min




epoch 6 loss is 2.2277
epoch 6 val loss is 0.0812
epoch 6 val auroc is 0.8770
epoch 6 val acc is 0.6627
epoch 6 took 116.37 min




epoch 7 loss is 2.1300
epoch 7 val loss is 0.0829
epoch 7 val auroc is 0.8866
epoch 7 val acc is 0.6480
epoch 7 took 116.47 min




epoch 8 loss is 2.0270
epoch 8 val loss is 0.0811
epoch 8 val auroc is 0.8888
epoch 8 val acc is 0.6680
epoch 8 took 117.65 min




epoch 9 loss is 1.9772
epoch 9 val loss is 0.0829
epoch 9 val auroc is 0.8934
epoch 9 val acc is 0.6640
epoch 9 took 116.70 min




epoch 10 loss is 1.9070
epoch 10 val loss is 0.0822
epoch 10 val auroc is 0.8990
epoch 10 val acc is 0.6853
epoch 10 took 116.60 min




epoch 11 loss is 1.8679
epoch 11 val loss is 0.0811
epoch 11 val auroc is 0.8948
epoch 11 val acc is 0.6787
epoch 11 took 116.57 min




epoch 12 loss is 1.8018
epoch 12 val loss is 0.0809
epoch 12 val auroc is 0.8973
epoch 12 val acc is 0.6973
epoch 12 took 116.99 min




epoch 13 loss is 1.7616
epoch 13 val loss is 0.0813
epoch 13 val auroc is 0.8945
epoch 13 val acc is 0.6913
epoch 13 took 116.34 min




epoch 14 loss is 1.7377
epoch 14 val loss is 0.0905
epoch 14 val auroc is 0.8964
epoch 14 val acc is 0.6900
epoch 14 took 116.77 min




epoch 15 loss is 1.6612
epoch 15 val loss is 0.0857
epoch 15 val auroc is 0.8928
epoch 15 val acc is 0.6927
epoch 15 took 116.51 min




epoch 16 loss is 1.6609
epoch 16 val loss is 0.0871
epoch 16 val auroc is 0.8953
epoch 16 val acc is 0.6993
epoch 16 took 116.26 min




epoch 17 loss is 1.5672
epoch 17 val loss is 0.0868
epoch 17 val auroc is 0.8962
epoch 17 val acc is 0.7140
epoch 17 took 115.96 min




epoch 18 loss is 1.4928
epoch 18 val loss is 0.1110
epoch 18 val auroc is 0.8860
epoch 18 val acc is 0.6980
epoch 18 took 117.47 min




epoch 19 loss is 1.4694
epoch 19 val loss is 0.0890
epoch 19 val auroc is 0.8929
epoch 19 val acc is 0.7147
epoch 19 took 116.55 min




epoch 20 loss is 1.3593
epoch 20 val loss is 0.0969
epoch 20 val auroc is 0.8907
epoch 20 val acc is 0.7093
epoch 20 took 117.12 min




epoch 21 loss is 1.2973
epoch 21 val loss is 0.1033
epoch 21 val auroc is 0.8815
epoch 21 val acc is 0.6700
epoch 21 took 116.15 min




epoch 22 loss is 1.2769
epoch 22 val loss is 0.1036
epoch 22 val auroc is 0.8614
epoch 22 val acc is 0.7007
epoch 22 took 116.60 min




epoch 23 loss is 1.2024
epoch 23 val loss is 0.1081
epoch 23 val auroc is 0.8818
epoch 23 val acc is 0.6580
epoch 23 took 116.32 min




epoch 24 loss is 1.1501
epoch 24 val loss is 0.0992
epoch 24 val auroc is 0.8880
epoch 24 val acc is 0.6973
epoch 24 took 116.44 min




epoch 25 loss is 1.1463
epoch 25 val loss is 0.1050
epoch 25 val auroc is 0.8845
epoch 25 val acc is 0.7133
epoch 25 took 116.54 min




epoch 26 loss is 1.0644
epoch 26 val loss is 0.1087
epoch 26 val auroc is 0.8779
epoch 26 val acc is 0.7020
epoch 26 took 115.87 min




epoch 27 loss is 1.0366
epoch 27 val loss is 0.1288
epoch 27 val auroc is 0.8618
epoch 27 val acc is 0.6300
epoch 27 took 116.39 min




epoch 28 loss is 0.9966
epoch 28 val loss is 0.1051
epoch 28 val auroc is 0.8709
epoch 28 val acc is 0.7167
epoch 28 took 116.87 min




epoch 29 loss is 0.9406
epoch 29 val loss is 0.1114
epoch 29 val auroc is 0.8719
epoch 29 val acc is 0.6987
epoch 29 took 116.20 min




epoch 30 loss is 0.9210
epoch 30 val loss is 0.1309
epoch 30 val auroc is 0.8325
epoch 30 val acc is 0.6393
epoch 30 took 116.56 min




epoch 31 loss is 0.9569
epoch 31 val loss is 0.1079
epoch 31 val auroc is 0.8761
epoch 31 val acc is 0.6967
epoch 31 took 116.00 min




epoch 32 loss is 0.8498
epoch 32 val loss is 0.1118
epoch 32 val auroc is 0.8738
epoch 32 val acc is 0.6767
epoch 32 took 116.02 min




epoch 33 loss is 0.8509
epoch 33 val loss is 0.1477
epoch 33 val auroc is 0.8512
epoch 33 val acc is 0.6160
epoch 33 took 115.90 min




epoch 34 loss is 0.8416
epoch 34 val loss is 0.1084
epoch 34 val auroc is 0.8807
epoch 34 val acc is 0.7040
epoch 34 took 116.12 min




epoch 35 loss is 0.7751
epoch 35 val loss is 0.1161
epoch 35 val auroc is 0.8723
epoch 35 val acc is 0.6973
epoch 35 took 116.21 min




epoch 36 loss is 0.7661
epoch 36 val loss is 0.1163
epoch 36 val auroc is 0.8572
epoch 36 val acc is 0.6720
epoch 36 took 116.30 min




epoch 37 loss is 0.7311
epoch 37 val loss is 0.1097
epoch 37 val auroc is 0.8798
epoch 37 val acc is 0.7053
epoch 37 took 115.84 min




epoch 38 loss is 0.6905
epoch 38 val loss is 0.1117
epoch 38 val auroc is 0.8683
epoch 38 val acc is 0.7053
epoch 38 took 116.06 min




epoch 39 loss is 0.6786
epoch 39 val loss is 0.1231
epoch 39 val auroc is 0.8518
epoch 39 val acc is 0.6747
epoch 39 took 116.04 min




epoch 40 loss is 0.6559
epoch 40 val loss is 0.1199
epoch 40 val auroc is 0.8772
epoch 40 val acc is 0.7060
epoch 40 took 116.32 min




epoch 41 loss is 0.7167
epoch 41 val loss is 0.1216
epoch 41 val auroc is 0.8537
epoch 41 val acc is 0.6647
epoch 41 took 115.79 min




epoch 42 loss is 0.6430
epoch 42 val loss is 0.1136
epoch 42 val auroc is 0.8762
epoch 42 val acc is 0.6933
epoch 42 took 116.38 min




epoch 43 loss is 0.6071
epoch 43 val loss is 0.1337
epoch 43 val auroc is 0.8439
epoch 43 val acc is 0.6600
epoch 43 took 116.25 min




epoch 44 loss is 0.6055
epoch 44 val loss is 0.1229
epoch 44 val auroc is 0.8636
epoch 44 val acc is 0.6580
epoch 44 took 115.93 min




epoch 45 loss is 0.6644
epoch 45 val loss is 0.1280
epoch 45 val auroc is 0.8202
epoch 45 val acc is 0.6140
epoch 45 took 116.21 min




epoch 46 loss is 0.6026
epoch 46 val loss is 0.1373
epoch 46 val auroc is 0.7982
epoch 46 val acc is 0.6100
epoch 46 took 116.32 min




epoch 47 loss is 0.5570
epoch 47 val loss is 0.1192
epoch 47 val auroc is 0.8575
epoch 47 val acc is 0.6687
epoch 47 took 115.87 min




epoch 48 loss is 0.5307
epoch 48 val loss is 0.1223
epoch 48 val auroc is 0.8710
epoch 48 val acc is 0.6920
epoch 48 took 116.13 min




epoch 49 loss is 0.4971
epoch 49 val loss is 0.1250
epoch 49 val auroc is 0.8381
epoch 49 val acc is 0.6287
epoch 49 took 116.14 min




epoch 50 loss is 0.5602
epoch 50 val loss is 0.1249
epoch 50 val auroc is 0.8137
epoch 50 val acc is 0.6320
epoch 50 took 116.34 min




epoch 51 loss is 0.5737
epoch 51 val loss is 0.1320
epoch 51 val auroc is 0.8191
epoch 51 val acc is 0.5933
epoch 51 took 117.46 min




epoch 52 loss is 0.5141
epoch 52 val loss is 0.1349
epoch 52 val auroc is 0.8215
epoch 52 val acc is 0.6013
epoch 52 took 117.48 min




epoch 53 loss is 0.5317
epoch 53 val loss is 0.1177
epoch 53 val auroc is 0.8627
epoch 53 val acc is 0.6733
epoch 53 took 116.91 min




epoch 54 loss is 0.4730
epoch 54 val loss is 0.1253
epoch 54 val auroc is 0.8589
epoch 54 val acc is 0.6467
epoch 54 took 116.74 min




epoch 55 loss is 0.4674
epoch 55 val loss is 0.1282
epoch 55 val auroc is 0.8521
epoch 55 val acc is 0.6547
epoch 55 took 117.61 min




epoch 56 loss is 0.4432
epoch 56 val loss is 0.1243
epoch 56 val auroc is 0.8538
epoch 56 val acc is 0.6500
epoch 56 took 116.90 min




epoch 57 loss is 0.4487
epoch 57 val loss is 0.1276
epoch 57 val auroc is 0.8298
epoch 57 val acc is 0.6613
epoch 57 took 117.30 min




epoch 58 loss is 0.4586
epoch 58 val loss is 0.1299
epoch 58 val auroc is 0.8350
epoch 58 val acc is 0.6160
epoch 58 took 117.30 min




epoch 59 loss is 0.4771
epoch 59 val loss is 0.1306
epoch 59 val auroc is 0.8185
epoch 59 val acc is 0.6020
epoch 59 took 117.28 min




epoch 60 loss is 0.5102
epoch 60 val loss is 0.1145
epoch 60 val auroc is 0.8500
epoch 60 val acc is 0.6627
epoch 60 took 118.20 min




epoch 61 loss is 0.4615
epoch 61 val loss is 0.1306
epoch 61 val auroc is 0.8316
epoch 61 val acc is 0.6360
epoch 61 took 117.83 min




epoch 62 loss is 0.4366
epoch 62 val loss is 0.1103
epoch 62 val auroc is 0.8745
epoch 62 val acc is 0.6647
epoch 62 took 118.29 min




epoch 63 loss is 0.3948
epoch 63 val loss is 0.1150
epoch 63 val auroc is 0.8610
epoch 63 val acc is 0.6593
epoch 63 took 117.22 min




epoch 64 loss is 0.3767
epoch 64 val loss is 0.1141
epoch 64 val auroc is 0.8720
epoch 64 val acc is 0.6560
epoch 64 took 117.00 min




epoch 65 loss is 0.3852
epoch 65 val loss is 0.1139
epoch 65 val auroc is 0.8530
epoch 65 val acc is 0.6660
epoch 65 took 117.44 min




epoch 66 loss is 0.4134
epoch 66 val loss is 0.1322
epoch 66 val auroc is 0.8505
epoch 66 val acc is 0.6047
epoch 66 took 117.47 min




epoch 67 loss is 0.5631
epoch 67 val loss is 0.1323
epoch 67 val auroc is 0.8113
epoch 67 val acc is 0.5800
epoch 67 took 117.86 min




epoch 68 loss is 0.4564
epoch 68 val loss is 0.1214
epoch 68 val auroc is 0.8310
epoch 68 val acc is 0.6260
epoch 68 took 117.42 min




epoch 69 loss is 0.4295
epoch 69 val loss is 0.1167
epoch 69 val auroc is 0.8284
epoch 69 val acc is 0.6687
epoch 69 took 117.40 min




epoch 70 loss is 0.3647
epoch 70 val loss is 0.1179
epoch 70 val auroc is 0.8518
epoch 70 val acc is 0.6720
epoch 70 took 117.10 min




epoch 71 loss is 0.3281
epoch 71 val loss is 0.1207
epoch 71 val auroc is 0.8457
epoch 71 val acc is 0.6140
epoch 71 took 117.56 min




epoch 72 loss is 0.3283
epoch 72 val loss is 0.1268
epoch 72 val auroc is 0.8483
epoch 72 val acc is 0.6800
epoch 72 took 117.80 min




epoch 73 loss is 0.3456
epoch 73 val loss is 0.1304
epoch 73 val auroc is 0.8194
epoch 73 val acc is 0.6280
epoch 73 took 117.64 min




epoch 74 loss is 0.3097
epoch 74 val loss is 0.1324
epoch 74 val auroc is 0.8266
epoch 74 val acc is 0.6613
epoch 74 took 116.05 min




epoch 75 loss is 0.2897
epoch 75 val loss is 0.1253
epoch 75 val auroc is 0.8046
epoch 75 val acc is 0.6620
epoch 75 took 116.37 min




epoch 76 loss is 0.3244
epoch 76 val loss is 0.1176
epoch 76 val auroc is 0.8280
epoch 76 val acc is 0.6627
epoch 76 took 116.05 min




epoch 77 loss is 0.3051
epoch 77 val loss is 0.1478
epoch 77 val auroc is 0.8099
epoch 77 val acc is 0.5813
epoch 77 took 116.30 min




epoch 78 loss is 0.3364
epoch 78 val loss is 0.1606
epoch 78 val auroc is 0.8024
epoch 78 val acc is 0.5700
epoch 78 took 116.57 min




epoch 79 loss is 0.3595
epoch 79 val loss is 0.1217
epoch 79 val auroc is 0.8459
epoch 79 val acc is 0.6727
epoch 79 took 116.19 min




epoch 80 loss is 0.2944
epoch 80 val loss is 0.1251
epoch 80 val auroc is 0.8288
epoch 80 val acc is 0.6353
epoch 80 took 116.03 min




epoch 81 loss is 0.3172
epoch 81 val loss is 0.1209
epoch 81 val auroc is 0.8462
epoch 81 val acc is 0.6767
epoch 81 took 116.17 min




epoch 82 loss is 0.3338
epoch 82 val loss is 0.1528
epoch 82 val auroc is 0.7914
epoch 82 val acc is 0.5687
epoch 82 took 116.86 min




epoch 83 loss is 0.2770
epoch 83 val loss is 0.1308
epoch 83 val auroc is 0.8345
epoch 83 val acc is 0.6733
epoch 83 took 116.00 min




epoch 84 loss is 0.2516
epoch 84 val loss is 0.1281
epoch 84 val auroc is 0.8189
epoch 84 val acc is 0.6613
epoch 84 took 118.08 min




epoch 85 loss is 0.3468
epoch 85 val loss is 0.1573
epoch 85 val auroc is 0.7972
epoch 85 val acc is 0.5880
epoch 85 took 116.80 min




epoch 86 loss is 0.3601
epoch 86 val loss is 0.1154
epoch 86 val auroc is 0.8522
epoch 86 val acc is 0.6840
epoch 86 took 116.79 min




epoch 87 loss is 0.3129
epoch 87 val loss is 0.1175
epoch 87 val auroc is 0.8418
epoch 87 val acc is 0.6653
epoch 87 took 116.27 min




epoch 88 loss is 0.2665
epoch 88 val loss is 0.1259
epoch 88 val auroc is 0.8348
epoch 88 val acc is 0.6780
epoch 88 took 116.42 min




epoch 89 loss is 0.2683
epoch 89 val loss is 0.1385
epoch 89 val auroc is 0.8226
epoch 89 val acc is 0.6473
epoch 89 took 117.00 min




epoch 90 loss is 0.2708
epoch 90 val loss is 0.1202
epoch 90 val auroc is 0.8308
epoch 90 val acc is 0.6853
epoch 90 took 115.97 min




epoch 91 loss is 0.2522
epoch 91 val loss is 0.1296
epoch 91 val auroc is 0.8056
epoch 91 val acc is 0.6267
epoch 91 took 116.86 min




epoch 92 loss is 0.3428
epoch 92 val loss is 0.1213
epoch 92 val auroc is 0.7979
epoch 92 val acc is 0.6720
epoch 92 took 117.21 min




epoch 93 loss is 0.2923
epoch 93 val loss is 0.1165
epoch 93 val auroc is 0.8307
epoch 93 val acc is 0.6607
epoch 93 took 116.41 min




epoch 94 loss is 0.2633
epoch 94 val loss is 0.1284
epoch 94 val auroc is 0.8424
epoch 94 val acc is 0.6573
epoch 94 took 116.28 min




In [12]:
metric_values = []
auroc_values = []
acc_values = []

model = EfficientNet.from_pretrained(model_name,in_channels=3, num_classes=15)
model = nn.DataParallel(model)
model.to(device)

model.load_state_dict(torch.load('/data4/vin_model_weights/_epoch_5_best_metric_model.pth'))
model.eval()

test_loss_values = []
test_loss = 0

with torch.no_grad():

    y_pred = torch.tensor([], dtype=torch.float32, device=device)
    y_labels = torch.tensor([], dtype=torch.float32, device=device)

    for test_data in test_loader:

        x1, x2, x3, y = test_data[0].cpu(), test_data[1].cpu(), test_data[2].cpu(), test_data[3].cuda()
        outputs = torch.cat([x1, x2, x3], axis=1)

        # make y to one hot encoding
        test_outputs = model(outputs)
        test_labels = F.one_hot(y, num_classes=15).cuda().float()

        test_loss = loss_function(test_outputs, test_labels)

        y_pred = torch.cat([y_pred, torch.sigmoid(test_outputs)], dim=0)
        y_labels = torch.cat([y_labels, test_labels], dim=0)

        test_loss_values.append(test_loss.item())

    test_loss = np.sum(test_loss_values) / len(test_loss_values)

    y_pred, y_labels = y_pred.cpu(), y_labels.cpu()
    result = roc_auc_score(y_labels, y_pred)
    acc_metric = accuracy_score(y_labels, (y_pred>0.5).float())

    test_result = np.mean(test_loss_values)
    metric_values.append(test_result)

    auroc_values.append(result)
    acc_values.append(acc_metric)

    print(f'test loss is {test_loss:.4f}')
    print(f'test auroc is {result:.4f}')
    print(f'test acc is {acc_metric:.4f}')

Loaded pretrained weights for efficientnet-b5




test loss is 0.0729
test auroc is 0.8853
test acc is 0.6927


In [87]:
# add grad-cam to model
import cv2
from torch.autograd import Function
from torchvision import models
from torchvision import utils
from PIL import Image
import matplotlib.cm as cm

class FeatureExtractor():
    """ Class for extracting activations and 
    registering gradients from targetted intermediate layers """

    def __init__(self, model, target_layers):
        self.model = model
        self.target_layers = target_layers
        self.gradients = []

    def save_gradient(self, grad):
        self.gradients.append(grad)

    def __call__(self, x):

        outputs = []
        self.gradients = []
        for name, module in self.model._modules.items():
            if name == 'fc':
                x = x.view(x.size(0), -1)
            x = module(x)
            if name in self.target_layers:
                x.register_hook(self.save_gradient)
                outputs += [x]
        return outputs, x


class ModelOutputs():
    """ Class for making a forward pass, and getting:
    1. The network output.
    2. Activations from intermeddiate targetted layers.
    3. Gradients from intermeddiate targetted layers. """

    def __init__(self, model, target_layers):
        self.model = model
        self.feature_extractor = FeatureExtractor(self.model, target_layers)

    def get_gradients(self):
        return self.feature_extractor.gradients
    
    def __call__(self, x):

        target_activations, output  = self.feature_extractor(x)
        output = torch.sigmoid(output)
        return target_activations, output
    
def preprocess_image(img):
    means=[0.485, 0.456, 0.406]
    stds=[0.229, 0.224, 0.225]

    preprocessed_img = img.copy()[:, :, ::-1]
    for i in range(3):

        preprocessed_img[:, :, i] = preprocessed_img[:, :, i] - means[i]
        preprocessed_img[:, :, i] = preprocessed_img[:, :, i] / stds[i]

    preprocessed_img = \
        np.ascontiguousarray(np.transpose(preprocessed_img, (2, 0, 1)))
    preprocessed_img = torch.from_numpy(preprocessed_img)

    preprocessed_img.unsqueeze_(0)
    input = preprocessed_img.requires_grad_(True)
    return input

def show_cam_on_image(img, mask, file_name):
    
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
    
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
    
        cv2.imwrite(file_name, np.uint8(255 * cam))

def get_cam(model, input_image, class_idx, file_name):

    model.eval()
    output = model(input_image.cuda())
    # output shape is (batch_size, num_classes)
    output = output[:, class_idx]

    # one of the differentiated Tensors does not require grad
    # so we have to manually set it to true
    print(type(input_image))
    input_image.requires_grad_(True)
    grads = torch.autograd.grad(output, input_image, allow_unused= True)
    print(grads)



model = EfficientNet.from_pretrained(model_name,in_channels=1, num_classes=15)
model.to(device)

# get grad-cam for each class
for i in range(15):
    get_cam(model, test_dataset[i][0].unsqueeze(0), i, '/data4/vin_model_weights/' + str(i) + '_class_grad_cam.jpg')

Loaded pretrained weights for efficientnet-b5




<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)




<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)
<class 'torch.Tensor'>
(None,)


In [71]:
print(test_dataset[0])

(tensor([[[-0.9843, -0.9843, -0.9843,  ..., -0.9137, -0.9137, -0.9059],
         [-0.9843, -0.9843, -0.9922,  ..., -0.9137, -0.9137, -0.9137],
         [-0.9843, -0.9922, -0.9922,  ..., -0.9216, -0.9216, -0.9216],
         ...,
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000]]]), tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]]]), tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
        

In [36]:
sample = test_dataset[0][0]
print(sample.shape)
uns_sample = sample.unsqueeze(0)
print(uns_sample.shape)

torch.Size([1, 1024, 1024])
torch.Size([1, 1, 1024, 1024])


In [38]:
sample_feature = model.extract_features(uns_sample.cuda())

In [39]:
print(sample_feature.shape)

torch.Size([1, 2048, 32, 32])


In [47]:
sample_output = model(uns_sample.cuda())

In [49]:
print(sample_output)

tensor([[-0.0151, -0.0895, -0.0134, -0.0242, -0.1249,  0.0150, -0.0504, -0.0077,
         -0.1312,  0.1254, -0.0041,  0.0808, -0.1009, -0.0046,  0.0095]],
       device='cuda:0', grad_fn=<AddmmBackward0>)


In [78]:
print(sample_output.shape)

torch.Size([1, 15])


In [50]:
sig_output = torch.sigmoid(sample_output)
print(sig_output)

tensor([[0.4962, 0.4776, 0.4966, 0.4939, 0.4688, 0.5038, 0.4874, 0.4981, 0.4673,
         0.5313, 0.4990, 0.5202, 0.4748, 0.4988, 0.5024]], device='cuda:0',
       grad_fn=<SigmoidBackward0>)
