In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import random_split
import torch.nn as nn
import torchvision
import torch.optim as optim


import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.metrics import top_k_accuracy_score
from main import Xception, print_test_details, train_transform, test_transform
from main import print_test_details, ResNet, BasicBlock, BottleNeck, train_transform, test_transform #custom function in main.py



import json, os

In [2]:
# HYPERPARAMETERS
# Train/test fractions handed to random_split, plus the compute device.
train_size, test_size = 0.8, 0.2
cuda, cpu = 'cuda', 'cpu'

# Prefer the GPU when one is visible; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = cuda
else:
    device = cpu


In [3]:

class myDataset(torch.utils.data.Dataset):
    """Dataset view over an indexable of ``(sample, target)`` pairs.

    Lets each subset returned by ``random_split`` get its own per-sample
    transform. The transform, when set (truthy), is applied to the sample
    only; the target is returned unchanged.
    """

    def __init__(self, data, transform=None):
        # data: anything supporting len() and integer indexing that yields 2-tuples.
        self.data, self.transform = data, transform

    def __len__(self):
        size = len(self.data)
        return size

    def __getitem__(self, index):
        sample, target = self.data[index]
        transformed = self.transform(sample) if self.transform else sample
        return transformed, target
   


In [4]:

# Read the image folder untransformed; each split gets its own transform via myDataset.
full_dataset = torchvision.datasets.ImageFolder(root='./dataset', transform=None)

# Seeded 80/20 split so the held-out subset is identical on every run
# (presumably the same split the checkpoints were trained against — confirm).
train, test = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(0),
)

# Only the evaluation subset is wrapped; the training pipeline is disabled here.
# train_dataset = myDataset(train, transform=train_transform)
test_dataset = myDataset(test, transform=test_transform)

# Batch size for the (disabled) training loader; the evaluation loader below uses 1.
batch_size = 64

# train_loader = torch.utils.data.DataLoader(
#     dataset=train_dataset,
#     batch_size=batch_size,
#     shuffle=True,
# )

# One sample per step in a fixed order; print_test_details asserts batch_size == 1.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
)

In [5]:

# Build both ensemble members and restore their trained weights.
# map_location=device lets checkpoints saved from a GPU load on a CPU-only
# machine (device falls back to 'cpu' in the hyperparameter cell); without it
# torch.load tries to restore tensors onto their original CUDA device and fails.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
resnet = ResNet(
    block=BottleNeck,
    layers=[2, 2, 2, 2]
)

resnet.load_state_dict(torch.load("resnet28_selfimplementation_kaggle.pth", map_location=device))
resnet.to(device)
resnet.eval()


# NOTE(review): the meaning of the argument 5 is defined by main.Xception —
# presumably the number of middle-flow blocks, per the checkpoint name; confirm.
xception = Xception(5)
xception.load_state_dict(torch.load("xception_middle5_5.9M.pth", map_location=device))
xception.to(device)
xception.eval()

# Last statement returns None, so the notebook does not echo the model repr
# that xception.eval() would otherwise display.
print()






In [6]:
# Pull the first evaluation item (images, label) to inspect its shape.
(x, y) = next(iter(test_loader))


In [7]:
# Recorded output: torch.Size([1, 10, 3, 299, 299]) — batch of 1, 10 views each.
x.size()

torch.Size([1, 10, 3, 299, 299])

In [8]:
@torch.no_grad()
def print_test_details(num_samples:int=None, fraction_sampleS:float=1.0):
    """Evaluate the averaged Xception + ResNet ensemble on ``test_loader`` and print accuracy.

    Uses the module-level ``test_loader``, ``xception``, ``resnet`` and ``device``.
    This definition shadows the ``print_test_details`` imported from ``main``.

    Each loader item must be a single sample (``batch_size == 1``). Its leading
    dimension of size 1 is squeezed away, so the next dimension is fed to both
    models as a batch of views. The two models' outputs are averaged, then
    averaged over those views into one score row. The recorded sample shape in
    this notebook is ``[1, 10, 3, 299, 299]``, i.e. 10 views per sample.

    Args:
        num_samples: Maximum number of loader items to evaluate. If ``None``,
            ``len(test_loader.dataset) * fraction_sampleS`` is used.
        fraction_sampleS: Fraction of the dataset to evaluate when
            ``num_samples`` is ``None``.

    Prints a timing line every 100 evaluated items. It then prints top-1 and
    top-5 accuracy computed over the items actually evaluated.
    """
    assert test_loader.batch_size == 1

    if num_samples is None:
        num_samples = len(test_loader.dataset)*fraction_sampleS

    # Built once instead of per item. Only the ResNet input is resized (to 224x224);
    # Xception receives the images at their loaded resolution.
    resize_for_resnet = torchvision.transforms.Resize((224, 224))

    top_1_correct = 0
    top_5_correct = 0
    # Denominator = items actually scored; num_samples may be fractional or
    # larger than the dataset, so it cannot be used directly.
    evaluated = 0
    iter_start = time.time()
    for i, (images, label) in enumerate(test_loader):

        # Check the limit first so no progress line is printed for an item
        # that will not be evaluated.
        if i >= num_samples:
            break

        if i % 100 == 0:
            print(f"{i}- time taken={time.time()-iter_start}")
            iter_start = time.time()

        # Drop the size-1 batch dim; the remaining leading dim is the views.
        images = torch.squeeze(images.to(device), 0)

        xception_output = xception(images)
        resnet_output = resnet(resize_for_resnet(images))

        assert resnet_output.shape == xception_output.shape

        # Average the two models, then average over views -> one (1, C) score row.
        output = (xception_output + resnet_output)/2.
        output = torch.mean(output, 0).reshape(1, -1)

        cpu_labels = label.cpu().numpy()
        cpu_averaged_outputs = output.cpu().numpy()
        # One label per score column, derived from the model output instead of a
        # hard-coded class count (sklearn requires len(labels) == n_columns).
        all_possible_labels = np.arange(cpu_averaged_outputs.shape[1])

        top_1_correct += top_k_accuracy_score(cpu_labels, cpu_averaged_outputs, k=1, labels=all_possible_labels, normalize=False)
        top_5_correct += top_k_accuracy_score(cpu_labels, cpu_averaged_outputs, k=5, labels=all_possible_labels, normalize=False)
        evaluated += 1

    if evaluated == 0:
        print("no samples evaluated")
        return

    print(f"top-1 accuracy={top_1_correct/evaluated:.4f}, top 5-accuracy={top_5_correct/evaluated:.4f}")



        

In [9]:
# Evaluate the ensemble on the entire held-out split (defaults spelled out).
print_test_details(num_samples=None, fraction_sampleS=1.0)

0- time taken=0.03200697898864746
100- time taken=9.25429630279541
200- time taken=8.803548574447632
300- time taken=9.648844718933105
400- time taken=11.251490354537964
500- time taken=10.07700514793396
600- time taken=10.720651865005493
700- time taken=11.954827070236206
800- time taken=12.55631947517395
900- time taken=13.546783924102783
1000- time taken=13.606950998306274
1100- time taken=13.541260957717896
1200- time taken=13.727826833724976
1300- time taken=13.852112054824829
1400- time taken=13.016998529434204
1500- time taken=13.03869366645813
1600- time taken=13.280539751052856
1700- time taken=12.681609392166138
1800- time taken=12.718914031982422
1900- time taken=13.639759540557861
2000- time taken=12.906078815460205
2100- time taken=12.808745384216309
2200- time taken=12.93986988067627
2300- time taken=13.513139009475708
2400- time taken=13.105189323425293
2500- time taken=12.74743390083313
2600- time taken=13.184627294540405
2700- time taken=13.685672044754028
2800- time t