In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image


# Rebuild the fine-tuned ResNet-34 and restore its checkpoint for evaluation.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Must match the size of the classifier head stored in ./cnn.pth.
# NOTE(review): CustomDataset only assigns labels 0-6 (7 classes); confirm the
# 8th output unit is intentional.
NUM_CLASSES = 8

# weights=None: load_state_dict (strict by default) overwrites every parameter
# and buffer below, so downloading ImageNet weights is wasted work. It also
# avoids the deprecated `weights=True` legacy form.
cnn_model = torchvision.models.resnet34(weights=None)
num_features = cnn_model.fc.in_features
cnn_model.fc = nn.Linear(num_features, NUM_CLASSES)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
cnn_model.load_state_dict(torch.load('./cnn.pth', map_location=device))
cnn_model = cnn_model.to(device)
# Inference mode for the BatchNorm layers; call cnn_model.train() before any
# further fine-tuning with the optimizer below.
cnn_model.eval()

criterion = nn.CrossEntropyLoss()
# Created after .to(device) so it references the parameters on the target device.
optimizer = torch.optim.SGD(cnn_model.parameters(), lr=0.001, momentum=0.9)



# Directory holding the unlabeled test images.
test_dir = './dataset/AI-DATASET/data/test'

# Evaluation preprocessing:
# - resize to exactly 224x224 (a tuple size ignores aspect ratio),
# - convert the PIL image to a float tensor in [0, 1],
# - normalize each channel with the standard ImageNet mean/std that torchvision's
#   pretrained ResNets use.
transforms_test = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=(224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
class CustomDataset(torch.utils.data.Dataset):
    """Image dataset over a list of file names that all share one class label.

    The label is chosen once, from the FIRST file name, by substring matching
    against ``_LABEL_KEYWORDS``. Every item then carries that same label, so
    each instance is expected to hold images of a single class.

    Args:
        files: Sequence of image file names, relative to ``root``.
        root: Directory containing the images.
        mode: ``'train'`` -> items are ``(image, np.array([label]))``;
            any other value -> items are ``(image, file_name)``.
        transform: Optional callable applied to the RGB PIL image.
    """

    # (keyword, label) pairs checked in order; the first keyword contained in
    # the file name decides the label.
    _LABEL_KEYWORDS = (
        ('beverage_can', 0),
        ('beverage_bottle', 1),
        ('noodle_bag', 2),
        ('noodle_cup', 3),
        ('beverage_milk', 4),
        ('icecream', 5),
    )
    # Label assigned when no keyword matches (and to an empty dataset).
    _DEFAULT_LABEL = 6

    def __init__(self, files, root, mode='train', transform=None):
        self.files = files
        self.root = root
        self.mode = mode
        self.transform = transform
        # Guard the empty case: indexing files[0] on an empty list would raise
        # IndexError, even though an empty dataset never needs a label.
        if len(files) > 0:
            self.label = self._label_for(files[0])
        else:
            self.label = self._DEFAULT_LABEL

    @classmethod
    def _label_for(cls, file_name):
        """Return the label of the first keyword found in ``file_name``."""
        for keyword, label in cls._LABEL_KEYWORDS:
            if keyword in file_name:
                return label
        return cls._DEFAULT_LABEL

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.files)

    def __getitem__(self, index):
        """Load the image at ``index``, convert it to RGB, and apply ``transform``.

        Returns ``(image, np.array([label]))`` in train mode, otherwise
        ``(image, file_name)``.
        """
        path = os.path.join(self.root, self.files[index])
        # Use PIL directly (``plt.Image`` is not a documented matplotlib API).
        # Close the file handle promptly, and force 3-channel RGB so grayscale,
        # palette, or RGBA files yield the same channel count as colour JPEGs.
        with Image.open(path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        if self.mode == 'train':
            return img, np.array([self.label])
        return img, self.files[index]
        
# Test images are named 1.jpg ... 20315.jpg; list them in numeric order.
test_files = list(map('{}.jpg'.format, range(1, 20315 + 1)))

# In non-train mode each item is (transformed image, file name), so predictions
# can be traced back to their source file. shuffle=False keeps batches in
# test_files order.
test_datasets = CustomDataset(
    files=test_files,
    root=test_dir,
    mode='test',
    transform=transforms_test,
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_datasets,
    batch_size=32,
    shuffle=False,
)



In [7]:
# Pull one batch to sanity-check the input pipeline. DataLoader iterators have
# no `.next()` method in current PyTorch (a Python 2 leftover), so use the
# built-in next().
imgs, files = next(iter(test_dataloader))
imgs.shape