In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

In [None]:
class Ensemble(nn.Module):
    """Two-branch ensemble that averages the outputs of two networks.

    Each branch receives its own input batch. Each branch output is flattened
    to ``(batch, -1)``, and the two flattened outputs are averaged elementwise.
    No softmax is applied, so whatever the sub-networks emit (presumably raw
    logits -- TODO confirm) is averaged directly.

    Args:
        modelA: network applied to the first input ``a``.
        modelB: network applied to the second input ``b``.
    """

    def __init__(self, modelA, modelB):
        super().__init__()
        # Assigning modules as attributes registers them as submodules, so
        # .eval(), .parameters() and state_dict() recurse into both branches.
        self.modelA, self.modelB = modelA, modelB

    def forward(self, a, b):
        """Return the elementwise mean of ``modelA(a)`` and ``modelB(b)``, each flattened per sample."""

        def _run_flat(net, batch):
            # Work on a copy so a network that edits its input in place cannot
            # alter the caller's tensor.
            result = net(batch.clone())
            return result.view(result.size(0), -1)

        return (_run_flat(self.modelA, a) + _run_flat(self.modelB, b)) / 2

In [None]:
# Evaluation-time preprocessing shared by every view: resize, centre-crop to
# 224x224, and normalise with the ImageNet channel statistics.
test_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

TEST_BATCH_SIZE = 4


def _make_test_loader(data_dir, batch_size=TEST_BATCH_SIZE):
    """Build the ImageFolder dataset under ``<data_dir>/test`` and an unshuffled DataLoader over it.

    Returns:
        ``(dataset, loader)``.
    """
    dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), test_transforms)
    # shuffle=False is load-bearing: the two-input evaluation code zips loaders
    # together and pairs samples purely by position.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=0)
    return dataset, loader


data_dir_intra = './TCIA_splitted_intra2'
test_datasets_intra, testloaders_intra = _make_test_loader(data_dir_intra)

data_dir_peri = './TCIA_splitted_peri2'
test_datasets_peri, testloaders_peri = _make_test_loader(data_dir_peri)

data_dir_texture = './TCIA_splitted_texture'
test_datasets_texture, testloaders_texture = _make_test_loader(data_dir_texture)

# Positional pairing of the intra and peri loaders only makes sense if both
# folders hold the same number of images with the same labels in the same
# (sorted) order. zip() would otherwise silently truncate or misalign them.
assert test_datasets_intra.targets == test_datasets_peri.targets, (
    'intra and peri test sets are not aligned (%d vs %d images, or differing labels)'
    % (len(test_datasets_intra), len(test_datasets_peri)))

In [None]:
# NOTE(review): torch.load on a fully pickled model runs arbitrary code while
# unpickling -- only load checkpoints from a trusted source.
# map_location='cpu': the evaluation code in this notebook never moves inputs
# to a device, so the models must live on the CPU regardless of where they
# were saved.
alexnet4_intra = torch.load('./models/alexnet_intra4.pt', map_location='cpu')
alexnet4_intra.eval()
# The peri checkpoint was previously loaded twice in this cell; once is enough.
# TODO(review): if the duplicated load was meant to be a different model (e.g. a
# texture model), load it here under its own name.
alexnet4_peri = torch.load('./models/alexnet_peri4.pt', map_location='cpu')
alexnet4_peri.eval()

In [None]:
# Inference only: freeze every backbone parameter so no gradients are tracked.
for param in alexnet4_intra.parameters():
    param.requires_grad_(False)
for param in alexnet4_peri.parameters():
    param.requires_grad_(False)


def _accuracy_counts(model, loader):
    """Return ``(correct, total)`` for a single-input classifier over ``(images, labels)`` batches.

    The predicted class is the arg-max over dim 1 of ``model(images)``.
    """
    model.eval()
    n_correct, n_total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            _, predicted = torch.max(model(images), 1)
            n_total += labels.size(0)
            n_correct += (predicted == labels).sum().item()
    return n_correct, n_total


# Baseline: accuracy of the intra model on its own.
correct, total = _accuracy_counts(alexnet4_intra, testloaders_intra)
if total:
    print('Testset Accuracy(mean): %f %%' % (100 * correct / total))
else:
    print('Testset is empty; accuracy is undefined.')


In [None]:
# Two-branch model: the intra network sees the intra view, the peri network the peri view.
ensemble = Ensemble(modelA=alexnet4_intra, modelB=alexnet4_peri)

In [None]:
#### Ensemble accuracy ####
correct = 0
total = 0

ensemble.eval()
with torch.no_grad():
    for (img_intra, label_intra), (img_peri, label_peri) in zip(testloaders_intra, testloaders_peri):
        # The loaders are paired purely by position; if their labels ever
        # disagree, the two views no longer refer to the same cases.
        if not torch.equal(label_intra, label_peri):
            raise ValueError('intra/peri batches are misaligned: %s vs %s'
                             % (label_intra.tolist(), label_peri.tolist()))
        output = ensemble(img_intra, img_peri)
        _, predicted = torch.max(output, 1)
        total += label_intra.size(0)
        correct += (predicted == label_intra).sum().item()

if total:
    print('Testset Accuracy(mean): %f %%' % (100 * correct / total))
else:
    print('Testset is empty; accuracy is undefined.')

In [None]:
def _percent(numerator, denominator):
    """Return ``100 * numerator / denominator``, or ``nan`` when the denominator is 0."""
    return 100.0 * numerator / denominator if denominator else float('nan')


def getConfusionMatrix(model, loader_a=None, loader_b=None):
    """Evaluate a two-input binary classifier, print its confusion matrix and metrics, and return the matrix.

    ``model`` is called as ``model(images_a, images_b)`` on batches paired
    positionally via ``zip(loader_a, loader_b)``. The prediction is the arg-max
    over dim 1 of its output. Class index 1 is treated as positive and 0 as
    negative; a sample whose label or prediction is any other index is not
    counted.

    Args:
        model: module taking two image batches (e.g. an ``Ensemble``).
        loader_a: iterable of ``(images, labels)`` batches for the first input.
            Defaults to ``testloaders_intra``.
        loader_b: iterable of ``(images, labels)`` batches for the second input.
            Defaults to ``testloaders_peri``.

    Returns:
        numpy int array ``[[TP, FN], [FP, TN]]``: rows are true class 1/0,
        columns are predicted class 1/0.

    Raises:
        ValueError: if a pair of batches carries different labels, meaning the
            two loaders are not aligned.
    """
    # Resolve the defaults at call time so rebuilt loaders are picked up.
    if loader_a is None:
        loader_a = testloaders_intra
    if loader_b is None:
        loader_b = testloaders_peri

    model.eval()
    confusion_matrix = np.zeros((2, 2), dtype=int)

    with torch.no_grad():
        for (img_a, label_a), (img_b, label_b) in zip(loader_a, loader_b):
            if not torch.equal(label_a, label_b):
                raise ValueError('paired batches are misaligned: %s vs %s'
                                 % (label_a.tolist(), label_b.tolist()))
            _, preds = torch.max(model(img_a, img_b), 1)
            pred_pos, pred_neg = preds == 1, preds == 0
            true_pos, true_neg = label_a == 1, label_a == 0
            confusion_matrix[0][0] += int((pred_pos & true_pos).sum())  # TP
            confusion_matrix[0][1] += int((pred_neg & true_pos).sum())  # FN
            confusion_matrix[1][0] += int((pred_pos & true_neg).sum())  # FP
            confusion_matrix[1][1] += int((pred_neg & true_neg).sum())  # TN

    tp, fn = confusion_matrix[0]
    fp, tn = confusion_matrix[1]
    print('Confusion Matrix: ')
    print(confusion_matrix)
    print()
    # Each metric is nan when its denominator is empty (e.g. no positives).
    print('Sensitivity: ', _percent(tp, tp + fn))
    print('Specificity: ', _percent(tn, tn + fp))
    print('PPV: ', _percent(tp, tp + fp))
    print('NPV: ', _percent(tn, tn + fn))

    return confusion_matrix

In [None]:
getConfusionMatrix(model=ensemble);