In [1]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

In [56]:
class Ensemble(nn.Module):
    def __init__(self, modelA, modelB,modelC):
        super(Ensemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        self.modelC = modelC
        
    def forward(self, a,b,c,label):
        x1 = self.modelA(a.clone())  
        x1 = x1.view(x1.size(0), -1)
        x2 = self.modelB(a.clone())
        x2 = x2.view(x2.size(0), -1)
        x3 = self.modelC(c.clone())
        x3 = x3.view(x3.size(0), -1)
       
        ###hard voting###
        _,x1_pred=torch.max(x1.data,1)
        _,x2_pred=torch.max(x2.data,1)
        _,x3_pred=torch.max(x3.data,1)

        x=x1_pred+x2_pred+x3_pred
        for i in range(0,len(label)):
            if x[i]>=2:
                x[i]=1
            else:
                x[i]=0
                
        '''
        print('- X1')
        print(x1)
        print(x1_pred)
        print('- X2')
        print(x2)
        print(x2_pred)
        print('- X3')
        print(x3)
        print(x3_pred)
        print('- Result')
        print(x)
        '''

        return x

In [44]:
def getConfusionMatrix(model):
    model.eval()
    confusion_matrix=np.zeros((2,2),dtype=int)
    num_images=len(test_datasets_intra)
    
    with torch.no_grad():
        for intra, peri, texture in zip(testloaders_intra,testloaders_peri,testloaders_texture):
            #inputs = inputs.to(device)
            #labels = labels.to(device)
            img_intra,label_intra=intra
            img_peri,label_peri=peri
            img_texture,label_texture=texture
            preds = model(img_intra,img_peri,img_texture, label_intra)
            
            for j in range(len(preds)):
                if preds[j]==1 and label_intra[j]==1:
                    term='TP'
                    confusion_matrix[0][0]+=1
                elif preds[j]==1 and label_intra[j]==0:
                    term='FP'
                    confusion_matrix[1][0]+=1
                elif preds[j]==0 and label_intra[j]==1:
                    term='FN'
                    confusion_matrix[0][1]+=1
                elif preds[j]==0 and label_intra[j]==0:
                    term='TN'
                    confusion_matrix[1][1]+=1

        print('Confusion Matrix: ')
        print(confusion_matrix)
        print()
        print('Sensitivity: ', 100*confusion_matrix[0][0]/(confusion_matrix[0][0]+confusion_matrix[0][1]))
        print('Specificity: ', 100*confusion_matrix[1][1]/(confusion_matrix[1][1]+confusion_matrix[1][0]))
        print('PPV: ', 100*confusion_matrix[0][0]/(confusion_matrix[0][0]+confusion_matrix[1][0]))
        print('NPV: ', 100*confusion_matrix[1][1]/(confusion_matrix[1][1]+confusion_matrix[0][1]))
        
        return confusion_matrix

In [4]:
test_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

data_dir_intra = './TCIA_splitted_intra10'
test_datasets_intra = datasets.ImageFolder(os.path.join(data_dir_intra, 'test'),
                                          test_transforms)
testloaders_intra = torch.utils.data.DataLoader(test_datasets_intra, batch_size=4,
                                             shuffle=False, num_workers=0)

data_dir_peri = './TCIA_splitted_peri10'
test_datasets_peri = datasets.ImageFolder(os.path.join(data_dir_peri, 'test'),
                                          test_transforms)
testloaders_peri = torch.utils.data.DataLoader(test_datasets_peri, batch_size=4,
                                             shuffle=False, num_workers=0)

data_dir_texture = './TCIA_splitted_texture10'
test_datasets_texture = datasets.ImageFolder(os.path.join(data_dir_texture, 'test'),
                                          test_transforms)
testloaders_texture = torch.utils.data.DataLoader(test_datasets_texture, batch_size=4,
                                             shuffle=False, num_workers=0)


In [51]:
print(len(test_datasets_intra))
print(len(test_datasets_peri))
print(len(test_datasets_texture))

58
58
58


In [52]:
alexnet4_intra= torch.load('./models/alexnet_intra10.pt')
alexnet4_intra.eval()
alexnet4_peri= torch.load('./models/alexnet_peri10.pt')
alexnet4_peri.eval()
alexnet4_texture= torch.load('./models/alexnet_texture10.pt')
alexnet4_texture.eval()

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [53]:
for param in alexnet4_intra.parameters():
    param.requires_grad_(False)
for param in alexnet4_peri.parameters():
    param.requires_grad_(False)
for param in alexnet4_texture.parameters():
    param.requires_grad_(False)

correct = 0
total = 0

with torch.no_grad():
    for data in testloaders_peri:
        img,labels=data
        print('***********')
        print(labels)
        output=alexnet4_peri(img)
        _, predicted = torch.max(output.data, 1)
        print(predicted)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    print('Testset Accuracy(mean): %f %%' % (100 * correct / total))


In [57]:
ensemble=Ensemble(alexnet4_intra,alexnet4_peri,alexnet4_texture)

In [58]:
####Accuracy####
correct = 0
total = 0


for intra,peri,texture in zip(testloaders_intra,testloaders_peri,testloaders_texture):
    img_intra,label_intra=intra
    img_peri,label_peri=peri
    img_texture,label_texture=texture

    print('***********')
    print('- Labels')
    print(label_intra)
    
    predicted=ensemble(img_intra,img_peri,img_texture,label_texture)
    
    print('- Predicted')
    print(predicted)
    
    total += label_intra.size(0)
    correct += (predicted == label_intra).sum().item()

print('Testset Accuracy(mean): %f %%' % (100 * correct / total))

***********
Labels
tensor([0, 0, 0, 0])
predicted
tensor([1, 1, 0, 1])
***********
Labels
tensor([0, 0, 0, 0])
predicted
tensor([0, 1, 1, 1])
***********
Labels
tensor([0, 0, 0, 0])
predicted
tensor([1, 1, 0, 0])
***********
Labels
tensor([0, 0, 0, 0])
predicted
tensor([0, 1, 1, 1])
***********
Labels
tensor([0, 0, 1, 1])
predicted
tensor([1, 1, 1, 1])
***********
Labels
tensor([1, 1, 1, 1])
predicted
tensor([1, 1, 1, 1])
***********
Labels
tensor([1, 1, 1, 1])
predicted
tensor([1, 1, 1, 1])
***********
Labels
tensor([1, 1, 1, 1])
predicted
tensor([1, 1, 1, 1])
***********
Labels
tensor([1, 1, 1, 1])
predicted
tensor([1, 1, 1, 1])
***********
Labels
tensor([1, 1, 1, 1])
predicted
tensor([1, 1, 1, 1])
***********
Labels
tensor([1, 1, 1, 1])
predicted
tensor([1, 1, 1, 1])
***********
Labels
tensor([1, 1, 1, 1])
predicted
tensor([1, 1, 1, 1])
***********
Labels
tensor([1, 1, 1, 1])
predicted
tensor([1, 1, 1, 1])
***********
Labels
tensor([1, 1, 1, 1])
predicted
tensor([1, 0, 1, 1])
******

In [59]:
getConfusionMatrix(ensemble)

Confusion Matrix: 
[[39  1]
 [13  5]]

Sensitivity:  97.5
Specificity:  27.77777777777778
PPV:  75.0
NPV:  83.33333333333333
