In [1]:
import os
from os import walk

import numpy as np

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import torchvision.transforms as transforms

from PIL import Image

In [2]:
# Preprocess the images
img_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [3]:
# Dataset to store the images (mediumly modified to fit its purpose)
class FaceDataset(Dataset):

    def __init__(self, roots, transform=None):
        self.data = [] #'name0_0001.jpg', 'name0_0002.jpg', ... #save path but not image, save memory
        self.label = [] #0, 1, 2 ...
        self.np_label = []
        self.transform = transform
        self.total_label = 0
        
        # Modified part: to access all images in the selected folders
        for i in range (len(roots)):
            root = roots[i]
            self.classes = os.listdir(root) #get the list fo classes in the directory ['name0', 'name1', ...]

            #get the list of all files in the dataset (load file name)
            for l, c in enumerate(self.classes): #'name0', 'name1' ...
                cls_folder = os.path.join(root, c)
                for f in os.listdir(cls_folder): #'name0_0001.jpg', 'name0_0002.jpg', ...
                    self.data.append(os.path.join(cls_folder, f))
                    self.label.append(self.total_label + l)
            self.total_label = self.label[-1] + 1
        self.np_label = np.array(self.label)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # get the image 
        img_path = self.data[idx]
        image = Image.open(img_path) #PIL image

        # perform transformation
        if self.transform is not None:
            image = self.transform(image)
        
        return image
    
    # Modified part: get the minimum index of a label
    def getMinLabelIndex(self, label):
        return self.label.index(label)
    
    def getMaxLabelIndex(self, label):
        same_c_idx = np.where(self.np_label == label)[0]
        return np.max(same_c_idx)
    
    # Modified part: get number of labels in the dataset
    def getTotalLabel(self):
        return self.label[-1] + 1

In [4]:
testset = FaceDataset(["./dataset/test"], transform=img_transform)
testloader = DataLoader(testset, batch_size=2, shuffle=True, num_workers=0)

In [5]:
# Define class that stores the model
class SiameseNet(nn.Module):
    def __init__(self):
        # call super constructor
        super().__init__()
        # fully connected layer
        self.fc1 = nn.Linear(in_features=128*2, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=1024)
        self.fc3 = nn.Linear(in_features=1024, out_features=1)
        
    def forward(self, x1, x2):
        
        x = torch.cat([x1, x2], dim=1) #concatenate 2 feature vector from 2 images (512D + 512D)
        
        # fc layer
        x = F.relu(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = torch.sigmoid(x)
        return x

In [6]:
# Load models
resNet = torch.load("saved_best_resNet34.pt") 
siameseNet = torch.load("saved_best_siameseNet.pt")

# Check GPU availability and use if available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

siameseNet = siameseNet.to(device)
siameseNet.eval()
resNet = resNet.to(device)
resNet.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
def evaluate(dataset, way):
    classes = dataset.getTotalLabel()
    fold = math.floor(classes/way)
    for i in range (fold):
        fold_features = []
        fold_label = []
        fold_features_truth = []
        correct = 0
        false = 0
        for j in range (way):
            min_index = testset.getMinLabelIndex(i*way+j)
            max_index = testset.getMaxLabelIndex(i*way+j)
            for k in range (min_index, max_index + 1):
                img = testset[k]
                img = img.unsqueeze(0)
                with torch.no_grad():
                    img = resNet(img)
                if k == min_index:
                    fold_features_truth.append(img)
                else:
                    fold_features.append(img)
                    fold_label.append(j)
                    
        for j in range (len(fold_features)):
            similarity = []
            cur_feature = fold_features[j]
            # The feature contains the highest similarity has the same label
            for k in range (len(fold_features_truth)):
                similarity.append(siameseNet(cur_feature, fold_features_truth[k]))
            # label with max similarity is the predicted label
            if (fold_label[j] == similarity.index(max(similarity))):
                correct += 1
            else:
                false += 1
            
        print("Accuracy of fold " + str(i + 1) + " = " + str(correct/(correct + false)) )

In [8]:
# evaluate using 20 way one-shot task (only the test set)
evaluate(testset, 20)

Accuracy of fold 1 = 0.5567117585848075


In [9]:
# evaluate using 10 way one-shot task (only the test set)
evaluate(testset, 10)

Accuracy of fold 1 = 0.654421768707483
Accuracy of fold 2 = 0.48672566371681414


In [10]:
# evaluate using 5 way one-shot task (only the test set)
evaluate(testset, 5)

Accuracy of fold 1 = 0.6666666666666666
Accuracy of fold 2 = 0.8857142857142857
Accuracy of fold 3 = 0.5681818181818182
Accuracy of fold 4 = 0.7659574468085106
