In [10]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader 
from torchvision import transforms
import os
import numpy as np
from tqdm import tqdm
from PIL import Image

In [11]:
class FoodDataset(Dataset):
    """Image dataset backed by a text file listing one sample per line.

    In ``mode='train'`` each line is ``<image_path>\\t<int_label>`` and items are
    ``(image, label)`` pairs; in any other mode each line is just ``<image_path>``
    and items are the image alone.

    Args:
        file: path to the list file (read as UTF-8).
        transform: optional callable applied to the RGB PIL image. When None,
            the PIL image itself is returned.
        mode: ``'train'`` to parse and return labels; anything else for unlabeled data.
    """

    def __init__(self, file, transform=None, mode='train'):
        self.transforms = transform
        self.mode = mode
        with open(file, 'r', encoding='utf-8') as f:
            # Drop the newline up front and skip blank lines (e.g. a trailing empty
            # line) so that every index in __len__ maps to a real sample.
            self.image_list = [line.rstrip('\n') for line in f if line.strip()]

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        entry = self.image_list[index]
        if self.mode == 'train':
            path, label = entry.split('\t')
            label = int(label)
        else:
            path = entry
        # Context manager guarantees the file handle is released; convert() returns
        # a new, fully loaded image that stays valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)
        if self.mode == 'train':
            return image, label
        return image

In [12]:
# Standard ImageNet per-channel statistics, matching the ImageNet-pretrained ResNet family.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation preprocessing: fixed 224x224 resize (no augmentation), tensor conversion,
# then channel-wise normalization.
_eval_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transforms_test = transforms.Compose(_eval_steps)

In [13]:
def evaluate(prediction, ground_truth):
    """Return classification accuracy: the fraction of positions where labels match.

    Args:
        prediction: sequence (list / array) of predicted class labels.
        ground_truth: sequence of true class labels, same length as ``prediction``.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if the sequences differ in length or are empty.
    """
    if len(prediction) != len(ground_truth):
        # Without this check numpy would broadcast or fail with an obscure error.
        raise ValueError(
            f'prediction has {len(prediction)} items but ground_truth has {len(ground_truth)}'
        )
    if len(prediction) == 0:
        raise ValueError('cannot compute accuracy of an empty prediction list')
    num_correct = (np.asarray(prediction) == np.asarray(ground_truth)).sum()
    return float(num_correct / len(prediction))

In [14]:
# Location of the dataset list files; change this one constant to run elsewhere.
DATA_ROOT = '/media/ntu/volume2/home/s121md302_07/food/data'

# test.txt carries labels, and the evaluation loop unpacks (img, label), so use the
# label-returning 'train' parsing mode (the default), stated explicitly here.
test_ds = FoodDataset(os.path.join(DATA_ROOT, 'test.txt'), transform=transforms_test, mode='train')

# No shuffling for evaluation: accuracy does not depend on order, and a fixed order
# keeps per-sample outputs reproducible across runs.
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=8)

In [16]:
num_classes = 5
CHECKPOINT_PATH = '/media/ntu/volume2/home/s121md302_07/food/checkpoint_resnet34/resnet34_50.pth'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# No ImageNet weights are needed: load_state_dict below is strict by default, so every
# parameter is overwritten by the checkpoint (a missing key would raise).
train_model = models.resnet34()
# Replace the classifier head; in_features avoids hard-coding the backbone's width.
train_model.fc = nn.Linear(train_model.fc.in_features, num_classes)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
train_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

In [20]:
train_model.to(device)
train_model.eval()  # switch BatchNorm/Dropout to inference behaviour

output_list = []
ground_truth_list = []
# Gradients are never needed during evaluation; a single no_grad context covers the loop.
with torch.no_grad():
    for img, label in tqdm(test_dl):
        img = img.to(device)
        output = train_model(img)
        prediction = torch.argmax(output, dim=1)
        # Store plain Python ints rather than 0-d tensors so the numpy comparison in
        # evaluate() is cheap; labels stay on the CPU and never need to visit the device.
        output_list.extend(prediction.cpu().tolist())
        ground_truth_list.extend(label.tolist())

accuracy = evaluate(output_list, ground_truth_list)
print(f'Accuracy: {accuracy}')

100%|██████████| 32/32 [00:01<00:00, 21.57it/s]

Accuracy: 0.892



