In [1]:
from torch.utils.data.dataset import Dataset
import os
import cv2
from pathlib import Path
from torchvision.transforms import transforms
from PIL import Image
import auxiliary
import model
import torch
from d2l import torch as d2l
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
import torchmetrics
from ignite.metrics import Precision
from ignite.metrics import Accuracy
from ignite.metrics import Recall



In [2]:
class IndividualSeedDataset(Dataset):
    def __init__(self, df, transform=None, target_transform=None):
        self.df = df
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        # it can read either from annotated file with bounding box coordinates on an image of multiple seeds or from csv
        # file of cropped seed images
        filenames = item[0].rsplit('.')
        if not os.path.exists(item[0]):
            # check uppercase JPG or lowercase jpg
            filename = filenames[0] + '.' + filenames[1].upper()
            if not os.path.exists(filename):
                filename = filenames[0] + '.' + filenames[1].lower()
                if not os.path.exists(filename):
                    raise Exception('File does not exist!')
        else:
            filename = item[0]
        image = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2RGB)
        if len(item) > 2: # read from bounding box
            label = item[5]
            cropped_image = image[item[2]:item[4], item[1]:item[3]]  # y_min:y_max, x_min:x_max,
        else: # read from cropped seed image directly
            label = item[1]
            cropped_image = image

        if label == 'GOOD':
            label = 1
        elif label == 'BAD':
            label = 0
        else:
            raise Exception("Unrecognised seed label - choose from either 'GOOD' or 'BAD'")
        if not os.path.exists(filename):
            print(filename)
            raise Exception('The seed image does not exist!')
        # normalise the cropped seed image between 0 and 1
        #cropped_image = cv2.normalize(cropped_image, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        cropped_image = cv2.resize(cropped_image, (224,224))
        cropped_image = transforms.ToTensor()(cropped_image).float()
        cropped_image = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(cropped_image).float()
        #### for debug only #####
        #print(image.shape)
        #print(idx, ': \t', cropped_image.shape, '\t', item[2], '\t', item[4])
        #########################

        # if any transformation is needed, e.g., to resize the image
        if self.transform:
            cropped_image = self.transform(cropped_image)
        if self.target_transform:
            label = self.target_transform(label)

        return cropped_image, label


In [3]:
devices = d2l.try_all_gpus()

In [4]:
test_data_dir = 'data/test/'
good_seed = 'GoodSeed/'
bad_seed = 'BadSeed/'
csvfile = r"./data/csv/LightBox_annotation.CSV"
df_batch2 = pd.read_csv(csvfile)
csvfile = r"./data/csv/NormalRoomLight_annotation.csv"
df_batch3 = pd.read_csv(csvfile)

In [5]:
class MyDataset(Dataset):
    def __init__(self, good_seed_root, bad_seed_root, transform=None):
        self.transform = transforms.Compose([
            transforms.ToTensor()      
        ])
        self.transform = transforms.Compose([
            transforms.CenterCrop(180), 
            transforms.Resize(224), 
            transforms.ToTensor(), 
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
        ])
        good_seed_images_path = Path(good_seed_root)
        bad_seed_images_path = Path(bad_seed_root)
        images_list = list(good_seed_images_path.glob('*.png')) + list(bad_seed_images_path.glob('*.png'))
        images_list_str = [ str(x) for x in images_list ]
        self.images = images_list_str

    def __getitem__(self, item):
        image_path = self.images[item]
        image = Image.open(image_path) 
        image = self.transform(image)  
        label = 1 if 'good' in image_path.split('\\')[-1] else 0 
        return image, label

    def __len__(self):
        return len(self.images)

In [6]:
good_seed_test_data_file = Path(test_data_dir, good_seed)
bad_seed_test_data_file = Path(test_data_dir, bad_seed)
test_data_batch1 = MyDataset(good_seed_test_data_file, bad_seed_test_data_file)
test_data_batch2 = IndividualSeedDataset(df_batch2)
test_data_batch3 = IndividualSeedDataset(df_batch3)

test_data_batch1_dataloader = torch.utils.data.DataLoader(test_data_batch1, batch_size=128, shuffle=True)
test_data_batch2_dataloader = torch.utils.data.DataLoader(test_data_batch2, batch_size=128, shuffle=True)
test_data_batch3_dataloader = torch.utils.data.DataLoader(test_data_batch3, batch_size=128, shuffle=True)

In [7]:
pthfile_Dino_resnet = './models/trained_models/net_params_604345505_lr0.0001_epoch60_Dino_Resnet50.pth'
pthfile_resnet = './models/trained_models/net_params_704177844_lr0.0001_epoch60_Resnet50.pth'
pthfile_Dino_Vit = './models/trained_models/net_params_643813538_lr0.0001_epoch40_Dino_vit.pth'
pthfile_Vit = './models/trained_models/net_params_391561151_lr0.0001_epoch40_vit.pth'

In [8]:
def evaluate_net(pthfile, test_data, net, devices):
    sum = 0
    if(net == 'Dino_resnet'):
        net = model.get_Dino_net(devices)
    elif(net == 'ResNet'):
        net = model.get_ResNet_net(devices)
    elif(net == 'Dino_vit'):
        net = model.get_Dino_Vit_net(devices)
    elif(net == 'vit'):
        net = model.get_Vit_net(devices)
    net_data = torch.load(pthfile)
    net.load_state_dict(net_data)
    net.eval()
    net.to(devices[0])
    for i in range(len(test_data)):
        x = test_data[i][0].unsqueeze(0).to(devices[0])
        y_hat = net(x).argmax(axis=1)
        if(y_hat == test_data[i][1]):
            sum = sum + 1
    print("test accuracy is ", sum/len(test_data))

In [9]:
def evaluate_net1(pthfile, dataloader, net, devices):
    test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
    test_recall = torchmetrics.Recall(task="multiclass",average='macro', num_classes=2)
    test_precision = torchmetrics.Precision(task="multiclass",average='macro', num_classes=2)
    test_auc = torchmetrics.AUROC(task="multiclass",average="macro", num_classes=2)

    if(net == 'Dino_resnet'):
        net = model.get_Dino_net(devices)
    elif(net == 'ResNet'):
        net = model.get_ResNet_net(devices)
    elif(net == 'Dino_vit'):
        net = model.get_Dino_Vit_net(devices)
    elif(net == 'vit'):
        net = model.get_Vit_net(devices)
    
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    
    net_data = torch.load(pthfile)
    net.load_state_dict(net_data)
    net.eval()
    net.to('cpu')

    with torch.no_grad():
        for X, y in dataloader:
            pred = net(X)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            # 一个batch进行计算迭代
            test_acc(pred.argmax(1), y)
            test_auc.update(pred, y)
            test_recall(pred.argmax(1), y)
            test_precision(pred.argmax(1), y)
    
    correct /= size

    # 计算一个epoch的accuray、recall、precision、AUC
    total_acc = test_acc.compute()
    total_recall = test_recall.compute()
    total_precision = test_precision.compute()
    total_auc = test_auc.compute()
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, "
          f"torch metrics acc: {(100 * total_acc):>0.1f}%\n")
    print("recall of every test dataset class: ", total_recall)
    print("precision of every test dataset class: ", total_precision)
    print("auc:", total_auc.item())

    # 清空计算对象
    test_precision.reset()
    test_acc.reset()
    test_recall.reset()
    test_auc.reset()

In [10]:
def evaluate_f1(pthfile, test_data, net, devices):
    print("f1:")
    y_true = []
    y_pred = []
    if(net == 'Dino_resnet'):
        net = model.get_Dino_net(devices)
    elif(net == 'ResNet'):
        net = model.get_ResNet_net(devices)
    elif(net == 'Dino_vit'):
        net = model.get_Dino_Vit_net(devices)
    elif(net == 'vit'):
        net = model.get_Vit_net(devices)
    net_data = torch.load(pthfile)
    net.load_state_dict(net_data)
    net.eval()
    net.to(devices[0])
    for i in range(len(test_data)):
        y_true.append(test_data[i][1])
        x = test_data[i][0].unsqueeze(0).to(devices[0])
        y_pred.append(int(net(x).argmax(axis=1)))
    f1 = f1_score(y_true, y_pred)
    print(f1)

In [24]:
net = 'Dino_resnet'
print('Dino_resnet_batch1:')
evaluate_net(pthfile_Dino_resnet, test_data_batch1, net, devices)

Dino_resnet_batch1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


test accuracy is  0.9451371571072319


In [13]:
net = 'ResNet'
print('ResNet_batch1:')
evaluate_net(pthfile_resnet, test_data_batch1, net)

ResNet_batch1:
test accuracy is  0.9102244389027432


In [14]:
net = 'Dino_vit'
print('Dino_vit_batch1:')
evaluate_net(pthfile_Dino_Vit, test_data_batch1, net)

Dino_vit_batch1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


test accuracy is  0.9675810473815462


In [15]:
net = 'vit'
print('vit_batch1:')
evaluate_net(pthfile_Vit, test_data_batch1, net)

vit_batch1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_deit_main


test accuracy is  0.9177057356608479


In [16]:
net = 'Dino_resnet'
print('Dino_resnet_batch2:')
evaluate_net(pthfile_Dino_resnet, test_data_batch2, net)

Dino_resnet_batch2:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


test accuracy is  0.6669449081803005


In [17]:
net = 'ResNet'
print('ResNet_batch2:')
evaluate_net(pthfile_resnet, test_data_batch2, net)

ResNet_batch2:
test accuracy is  0.7045075125208681


In [18]:
net = 'Dino_vit'
print('vit_batch2:')
evaluate_net(pthfile_Dino_Vit, test_data_batch2, net)

vit_batch2:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


test accuracy is  0.7212020033388982


In [19]:
net = 'vit'
print('vit_batch2:')
evaluate_net(pthfile_Vit, test_data_batch2, net)

vit_batch2:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_deit_main


test accuracy is  0.6928213689482471


In [20]:
net = 'Dino_resnet'
print('Dino_resnet_batch3:')
evaluate_net(pthfile_Dino_resnet, test_data_batch3, net)

Dino_resnet_batch3:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


test accuracy is  0.5644444444444444


In [21]:
net = 'ResNet'
print('ResNet_batch3:')
evaluate_net(pthfile_resnet, test_data_batch3, net)

ResNet_batch3:
test accuracy is  0.7811111111111111


In [22]:
net = 'Dino_vit'
print('Dino_vit_batch3:')
evaluate_net(pthfile_Dino_Vit, test_data_batch3, net)

Dino_vit_batch3:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


test accuracy is  0.8922222222222222


In [23]:
net = 'vit'
print('vit_batch3:')
evaluate_net(pthfile_Vit, test_data_batch3, net)

vit_batch3:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_deit_main


test accuracy is  0.8211111111111111


In [36]:
net = 'Dino_resnet'
print('Dino_resnet_batch1:')
evaluate_f1(pthfile_Dino_resnet, test_data_batch1, net)

Dino_resnet_batch1:
f1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main
Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


0.9444444444444444


In [15]:
net = 'ResNet'
print('ResNet_batch1:')
evaluate_f1(pthfile_resnet, test_data_batch1, net)

ResNet_batch1:
f1:
0.9072164948453608


In [16]:
net = 'Dino_vit'
print('Dino_vit_batch1:')
evaluate_f1(pthfile_Dino_Vit, test_data_batch1, net)

Dino_vit_batch1:
f1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


0.9675810473815462


In [17]:
net = 'vit'
print('vit_batch1:')
evaluate_f1(pthfile_Vit, test_data_batch1, net)

vit_batch1:
f1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_deit_main


0.9129287598944591


In [18]:
net = 'Dino_resnet'
print('Dino_resnet_batch2:')
evaluate_f1(pthfile_Dino_resnet, test_data_batch2, net)
net = 'ResNet'
print('ResNet_batch2:')
evaluate_f1(pthfile_resnet, test_data_batch2, net)
net = 'Dino_vit'
print('Dino_vit_batch2:')
evaluate_f1(pthfile_Dino_Vit, test_data_batch2, net)
net = 'vit'
print('vit_batch2:')
evaluate_f1(pthfile_Vit, test_data_batch2, net)

Dino_resnet_batch2:
f1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


0.6586826347305389
ResNet_batch2:
f1:
0.7500000000000001
Dino_vit_batch2:
f1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


0.6812977099236641
vit_batch2:
f1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_deit_main


0.7186544342507646


In [19]:
net = 'Dino_resnet'
print('Dino_resnet_batch3:')
evaluate_f1(pthfile_Dino_resnet, test_data_batch3, net)
net = 'ResNet'
print('ResNet_batch3：')
evaluate_f1(pthfile_resnet, test_data_batch3, net)
net = 'Dino_vit'
print('Dino_vit_batch3:')
evaluate_f1(pthfile_Dino_Vit, test_data_batch3, net)
net = 'vit'
print('vit_batch3:')
evaluate_f1(pthfile_Vit, test_data_batch3, net)

Dino_resnet_batch3:
f1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


0.6965944272445821
ResNet_batch3：
f1:
0.770663562281723
Dino_vit_batch3:
f1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


0.8951351351351352
vit_batch3:
f1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_deit_main


0.8432327166504381


In [11]:

net = 'Dino_resnet'
print('Dino_resnet_batch1:')
evaluate_net1(pthfile_Dino_resnet, test_data_batch1_dataloader, net, devices)


print("***************************************************************")


net = 'ResNet'
print('ResNet_batch1:')
evaluate_net1(pthfile_resnet, test_data_batch1_dataloader, net, devices)


print("***************************************************************")


net = 'Dino_vit'
print('Dino_vit_batch1:')
evaluate_net1(pthfile_Dino_Vit, test_data_batch1_dataloader, net, devices)


print("***************************************************************")


net = 'vit'
print('vit_batch1:')
evaluate_net1(pthfile_Vit, test_data_batch1_dataloader, net, devices)


print("***************************************************************")


net = 'Dino_resnet'
print('Dino_resnet_batch2:')
evaluate_net1(pthfile_Dino_resnet, test_data_batch2_dataloader, net, devices)


print("***************************************************************")


net = 'ResNet'
print('ResNet_batch2:')
evaluate_net1(pthfile_resnet, test_data_batch2_dataloader, net, devices)


print("***************************************************************")


net = 'Dino_vit'
print('Dino_vit_batch2:')
evaluate_net1(pthfile_Dino_Vit, test_data_batch2_dataloader, net, devices)


print("***************************************************************")


net = 'vit'
print('vit_batch2:')
evaluate_net1(pthfile_Vit, test_data_batch2_dataloader, net, devices)


print("***************************************************************")


net = 'Dino_resnet'
print('Dino_resnet_batch3:')
evaluate_net1(pthfile_Dino_resnet, test_data_batch3_dataloader, net, devices)


print("***************************************************************")


net = 'ResNet'
print('ResNet_batch3:')
evaluate_net1(pthfile_resnet, test_data_batch3_dataloader, net, devices)


print("***************************************************************")


net = 'Dino_vit'
print('Dino_vit_batch3:')
evaluate_net1(pthfile_Dino_Vit, test_data_batch3_dataloader, net, devices)


print("***************************************************************")


net = 'vit'
print('vit_batch3:')
evaluate_net1(pthfile_Vit, test_data_batch3_dataloader, net, devices)

Dino_resnet_batch1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


Test Error: 
 Accuracy: 94.5%, torch metrics acc: 94.5%

recall of every test dataset class:  tensor(0.9452)
precision of every test dataset class:  tensor(0.9455)
auc: 0.9877861142158508
***************************************************************
ResNet_batch1:
Test Error: 
 Accuracy: 91.0%, torch metrics acc: 91.0%

recall of every test dataset class:  tensor(0.9103)
precision of every test dataset class:  tensor(0.9122)
auc: 0.9750746488571167
***************************************************************
Dino_vit_batch1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


Test Error: 
 Accuracy: 96.8%, torch metrics acc: 96.8%

recall of every test dataset class:  tensor(0.9676)
precision of every test dataset class:  tensor(0.9676)
auc: 0.9956467151641846
***************************************************************
vit_batch1:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_deit_main


Test Error: 
 Accuracy: 91.8%, torch metrics acc: 91.8%

recall of every test dataset class:  tensor(0.9178)
precision of every test dataset class:  tensor(0.9232)
auc: 0.9788308143615723
***************************************************************
Dino_resnet_batch2:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


Test Error: 
 Accuracy: 66.7%, torch metrics acc: 66.7%

recall of every test dataset class:  tensor(0.6673)
precision of every test dataset class:  tensor(0.6678)
auc: 0.7259125709533691
***************************************************************
ResNet_batch2:
Test Error: 
 Accuracy: 70.5%, torch metrics acc: 70.5%

recall of every test dataset class:  tensor(0.7028)
precision of every test dataset class:  tensor(0.7318)
auc: 0.7860298156738281
***************************************************************
vit_batch2:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


Test Error: 
 Accuracy: 72.1%, torch metrics acc: 72.1%

recall of every test dataset class:  tensor(0.7225)
precision of every test dataset class:  tensor(0.7387)
auc: 0.8398735523223877
***************************************************************
vit_batch2:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_deit_main


Test Error: 
 Accuracy: 69.3%, torch metrics acc: 69.3%

recall of every test dataset class:  tensor(0.6920)
precision of every test dataset class:  tensor(0.6979)
auc: 0.747533917427063
***************************************************************
Dino_resnet_batch3:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


Test Error: 
 Accuracy: 56.4%, torch metrics acc: 56.4%

recall of every test dataset class:  tensor(0.5644)
precision of every test dataset class:  tensor(0.7672)
auc: 0.9254716634750366
***************************************************************
ResNet_batch3:
Test Error: 
 Accuracy: 78.1%, torch metrics acc: 78.1%

recall of every test dataset class:  tensor(0.7811)
precision of every test dataset class:  tensor(0.7835)
auc: 0.8616493940353394
***************************************************************
Dino_vit_batch3:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_dino_main


Test Error: 
 Accuracy: 89.2%, torch metrics acc: 89.2%

recall of every test dataset class:  tensor(0.8922)
precision of every test dataset class:  tensor(0.8934)
auc: 0.9674518704414368
***************************************************************
vit_batch3:


Using cache found in C:\Users\Ivan/.cache\torch\hub\facebookresearch_deit_main


Test Error: 
 Accuracy: 82.1%, torch metrics acc: 82.1%

recall of every test dataset class:  tensor(0.8211)
precision of every test dataset class:  tensor(0.8489)
auc: 0.9272839426994324
