# HW3 Image Classification
**因需跑超過colab限制時間，因此將此份code轉成py檔，放到cluster上跑。**
## We strongly recommend that you run with Kaggle for this homework
https://www.kaggle.com/c/ml2022spring-hw3b/code?competitionId=34954&sortBy=dateCreated

# Get Data
Notes: if the links are dead, you can download the data directly from Kaggle and upload it to the workspace, or you can use the Kaggle API to directly download the data into colab.

In [1]:
#! wget https://www.dropbox.com/s/6l2vcvxl54b0b6w/food11.zip
! wget -O food11.zip "https://github.com/virginiakm1988/ML2022-Spring/blob/main/HW03/food11.zip?raw=true"

--2022-03-24 08:14:32--  https://github.com/virginiakm1988/ML2022-Spring/blob/main/HW03/food11.zip?raw=true
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/virginiakm1988/ML2022-Spring/raw/main/HW03/food11.zip [following]
--2022-03-24 08:14:32--  https://github.com/virginiakm1988/ML2022-Spring/raw/main/HW03/food11.zip
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://media.githubusercontent.com/media/virginiakm1988/ML2022-Spring/main/HW03/food11.zip [following]
--2022-03-24 08:14:32--  https://media.githubusercontent.com/media/virginiakm1988/ML2022-Spring/main/HW03/food11.zip
Resolving media.githubusercontent.com (media.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to media.githubusercontent.com (media.githubusercontent.com)|185

In [2]:
%%capture
! unzip food11.zip

# Training

In [3]:
import numpy as np
import pandas as pd
from glob import glob
import torch
import os
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from pathlib import Path
import torch.nn as nn
from tqdm import tqdm
from typing import List
from torch.utils.tensorboard import SummaryWriter
import random

In [4]:
# make the training be deterministic
my_seed = 6666
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(my_seed)
random.seed(my_seed)
torch.manual_seed(my_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(my_seed)

## Dataset

In [5]:
N_REPEAT = 2

class FoodDataset(Dataset):

    def __init__(self, path: str, transform=None, testing=False, n_repeat=1, clean=False):
        super().__init__()
        if clean:
            self.files = [
                    f for f in sorted(glob(os.path.join(path, '*.jpg')))
                if int(
                    Path(f).stem.split('_')[0]
                ) not in [1, 10]
            ]
        else:
            self.files = sorted(glob(os.path.join(path, '*.jpg')))
        assert len(self.files) > 0
        if n_repeat > 1:
            self.files = [f for f in self.files for _ in range(n_repeat)]
        self.transform = transform
        self.testing = testing

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_name = self.files[idx]
        if self.transform:
            image = self.transform(Image.open(file_name))
        else:
            image = Image.open(file_name)

        if self.testing:
            label = -1
        else:
            tags = Path(file_name).stem.split('_')
            assert len(tags) == 2
            label = int(tags[0])

        return (image, label)


class MyDataset(Dataset):

    def __init__(self, files: List[Path], n_repeat: int = 0):
        super().__init__()
        assert len(files) > 0
        self.files = files
        if n_repeat > 0:
            self.files = self.files * n_repeat
        self.transform = None

    def __len__(self):
        return len(self.files)

    def set_transform(self, transform):
        self.transform = transform

    def __getitem__(self, idx):
        file_name = self.files[idx]
        if self.transform:
            image = self.transform(Image.open(file_name))
        else:
            image = Image.open(file_name)

        tags = Path(file_name).stem.split('_')
        assert len(tags) == 2
        label = int(tags[0])

        return (image, label)

In [6]:
def random_split(ratio=0.8):
    files1 = list(Path('./food11/training').glob('*.jpg'))
    files2 = list(Path('./food11/validation').glob('*.jpg'))
    files = files1+files2
    random.shuffle(files)
    n_train = int(len(files) * ratio)
    return {
        'train': MyDataset(files[:n_train], n_repeat=N_REPEAT),
        'valid': MyDataset(files[n_train:])
    }

## Transform

In [7]:
tf = {
    'train': transforms.Compose([
        #  transforms.RandomCrop(
        #      (512, 512),
        #      pad_if_needed=True
        #  ),
        #  transforms.RandomCrop(
        #      (640, 640),
        #      pad_if_needed=True
        #  ),
        #  transforms.CenterCrop((640, 640)),
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(
            degrees=(-20, 20),
            translate=(0.1, 0.3),
            scale=(0.5, 0.75)
        ),
        #  transforms.RandomRotation(degrees=(-10, 10)),
        transforms.ColorJitter(
            brightness=0.1,
            contrast=0.2,
            saturation=0,
            hue=0
        ),
        transforms.ToTensor(),
        transforms.RandomErasing(),
    ]),
    'valid': transforms.Compose([
        #  transforms.CenterCrop((512, 512)),
        #  transforms.CenterCrop((640, 640)),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]),
    'aug': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(
            degrees=(-20, 20),
            translate=(0.1, 0.3),
            scale=(0.5, 0.75)
        ),
        transforms.ColorJitter(
            brightness=0.1,
            contrast=0.2,
            saturation=0,
            hue=0
        ),
    ]),
}

In [8]:
datasets = random_split(ratio=0.7)
for key in datasets:
    datasets[key].set_transform(tf[key])

## Model

In [9]:
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 11)
        )


    def forward(self, x):
        return self.model(x)

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"

class_sample_count = [
    1356,
    573,
    2000,
    1313,
    1174,
    1774,
    587,
    376,
    1202,
    2000,
    941
]
loss_weight = (1 / torch.Tensor(class_sample_count)).to(device)

loaders = {
    key: DataLoader(
        datasets[key],
        batch_size=128,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    ) for key in datasets
}

CHECKPOINT = 'best.ckpt'
#  CHECKPOINT = 'clean.ckpt'
model = Classifier().to(device)
# model.load_state_dict(torch.load(CHECKPOINT))
criterion = nn.CrossEntropyLoss(weight=loss_weight)

def accuracy(logits, label):
    return (logits.argmax(dim=-1) == label).float().mean()

optimizer = torch.optim.Adam(
    model.parameters(),
    #  lr=0.0005,
    # lr=0.00001,
    #  lr=0.0001,
    lr=0.001,
    weight_decay=1e-5
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 25, 40, 55, 70, 85], gamma=0.5)

## Main

In [None]:
N_EPOCHS = 300
n_train = len(loaders['train'])
n_valid = len(loaders['valid'])
valid_freq = 1

best_accu = 0.
stagnation = 0
PATIENCE = 20
N_TEST_TIME_AUG = 5

writer = SummaryWriter()

for epoch in range(N_EPOCHS):
    print()
    print('Epoch:', epoch)

    # training
    model.train()
    running_accu = 0.
    running_loss = 0.
    train_bar = tqdm(loaders['train'])
    for (image, label) in train_bar:
        image = image.to(device)
        label = label.to(device)
        logits = model(image)
        loss = criterion(logits, label)
        accu = accuracy(logits, label)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss = loss.item()
        running_accu += accu
        running_loss += loss
        train_bar.set_description('[train] loss: %.3f, accu: %.3f' % (loss, accu))
    running_accu /= n_train
    running_loss /= n_train
    scheduler.step()
    print('[train] avg loss: %.3f, avg accu: %.3f, lr: %.5f' % (running_loss, running_accu, optimizer.param_groups[0]['lr']))
    writer.add_scalar('train/loss', running_loss, epoch)
    writer.add_scalar('train/accu', running_accu, epoch)

    if epoch % valid_freq == 0 and epoch > 15:
        # validation
        model.eval()
        running_accu = 0.
        running_loss = 0.
        valid_bar = tqdm(loaders['valid'])
        for (image, label) in valid_bar:
            image = image.to(device)
            label = label.to(device)
            with torch.no_grad():
                logits = model(image)
                if N_TEST_TIME_AUG > 0:
                    logits = nn.functional.softmax(logits, dim=1)
                    for _ in range(N_TEST_TIME_AUG):
                        logits += nn.functional.softmax(model(tf['aug'](image)), dim=1)
            accu = accuracy(logits, label).cpu()
            loss = criterion(logits, label).item()
            running_accu += accu
            running_loss += loss
            valid_bar.set_description('[valid] loss: %.3f, accu: %.3f' % (loss, accu))
        running_accu /= n_valid
        running_loss /= n_valid

        print('[valid] avg loss: %.3f, avg accu: %.3f' % (running_loss, running_accu))
        writer.add_scalar('valid/loss', running_loss, epoch)
        writer.add_scalar('valid/accu', running_accu, epoch)

        if running_accu > best_accu:
            torch.save(model.state_dict(), CHECKPOINT)
            best_accu = running_accu
            writer.add_scalar('best_accu', best_accu, epoch)
            stagnation = 0
        else:
            stagnation += 1
            if stagnation > PATIENCE:
                print('Early stopped')
                break


Epoch: 0


[train] loss: 2.211, accu: 0.148: 100%|██████████| 146/146 [04:33<00:00,  1.87s/it]


[train] avg loss: 2.397, avg accu: 0.135, lr: 0.00100

Epoch: 1


[train] loss: 2.427, accu: 0.130: 100%|██████████| 146/146 [04:38<00:00,  1.91s/it]


[train] avg loss: 2.211, avg accu: 0.177, lr: 0.00100

Epoch: 2


[train] loss: 1.928, accu: 0.204: 100%|██████████| 146/146 [04:35<00:00,  1.89s/it]


[train] avg loss: 2.093, avg accu: 0.229, lr: 0.00100

Epoch: 3


[train] loss: 1.785, accu: 0.222: 100%|██████████| 146/146 [04:36<00:00,  1.89s/it]


[train] avg loss: 1.929, avg accu: 0.274, lr: 0.00100

Epoch: 4


[train] loss: 1.962, accu: 0.315: 100%|██████████| 146/146 [04:38<00:00,  1.91s/it]


[train] avg loss: 1.838, avg accu: 0.317, lr: 0.00100

Epoch: 5


[train] loss: 1.609, accu: 0.444: 100%|██████████| 146/146 [04:35<00:00,  1.89s/it]


[train] avg loss: 1.731, avg accu: 0.350, lr: 0.00100

Epoch: 6


[train] loss: 2.063, accu: 0.278: 100%|██████████| 146/146 [04:35<00:00,  1.89s/it]


[train] avg loss: 1.650, avg accu: 0.380, lr: 0.00100

Epoch: 7


[train] loss: 1.634, accu: 0.444: 100%|██████████| 146/146 [04:33<00:00,  1.88s/it]


[train] avg loss: 1.578, avg accu: 0.415, lr: 0.00100

Epoch: 8


[train] loss: 1.350, accu: 0.389: 100%|██████████| 146/146 [04:35<00:00,  1.89s/it]


[train] avg loss: 1.513, avg accu: 0.433, lr: 0.00100

Epoch: 9


[train] loss: 1.334, accu: 0.444: 100%|██████████| 146/146 [04:35<00:00,  1.88s/it]


[train] avg loss: 1.450, avg accu: 0.461, lr: 0.00050

Epoch: 10


[train] loss: 1.209, accu: 0.463: 100%|██████████| 146/146 [04:33<00:00,  1.87s/it]


[train] avg loss: 1.328, avg accu: 0.505, lr: 0.00050

Epoch: 11


[train] loss: 1.308, accu: 0.574: 100%|██████████| 146/146 [04:30<00:00,  1.85s/it]


[train] avg loss: 1.278, avg accu: 0.521, lr: 0.00050

Epoch: 12


[train] loss: 1.172, accu: 0.556: 100%|██████████| 146/146 [04:24<00:00,  1.81s/it]


[train] avg loss: 1.242, avg accu: 0.537, lr: 0.00050

Epoch: 13


[train] loss: 1.390, accu: 0.500: 100%|██████████| 146/146 [04:26<00:00,  1.82s/it]


[train] avg loss: 1.207, avg accu: 0.545, lr: 0.00050

Epoch: 14


[train] loss: 0.929, accu: 0.685: 100%|██████████| 146/146 [04:25<00:00,  1.82s/it]


[train] avg loss: 1.181, avg accu: 0.558, lr: 0.00050

Epoch: 15


[train] loss: 1.123, accu: 0.630: 100%|██████████| 146/146 [04:28<00:00,  1.84s/it]


[train] avg loss: 1.152, avg accu: 0.569, lr: 0.00050

Epoch: 16


[train] loss: 1.330, accu: 0.444: 100%|██████████| 146/146 [04:28<00:00,  1.84s/it]


[train] avg loss: 1.116, avg accu: 0.584, lr: 0.00050


[valid] loss: 1.980, accu: 0.524: 100%|██████████| 32/32 [01:15<00:00,  2.37s/it]


[valid] avg loss: 1.389, avg accu: 0.572

Epoch: 17


[train] loss: 0.905, accu: 0.685: 100%|██████████| 146/146 [04:32<00:00,  1.86s/it]


[train] avg loss: 1.080, avg accu: 0.598, lr: 0.00050


[valid] loss: 1.316, accu: 0.524: 100%|██████████| 32/32 [01:16<00:00,  2.40s/it]


[valid] avg loss: 1.414, avg accu: 0.566

Epoch: 18


[train] loss: 1.019, accu: 0.593: 100%|██████████| 146/146 [04:35<00:00,  1.88s/it]


[train] avg loss: 1.062, avg accu: 0.602, lr: 0.00050


[valid] loss: 1.358, accu: 0.667: 100%|██████████| 32/32 [01:15<00:00,  2.37s/it]


[valid] avg loss: 1.222, avg accu: 0.625

Epoch: 19


[train] loss: 1.246, accu: 0.519: 100%|██████████| 146/146 [04:32<00:00,  1.87s/it]


[train] avg loss: 1.027, avg accu: 0.613, lr: 0.00050


[valid] loss: 0.734, accu: 0.810: 100%|██████████| 32/32 [01:16<00:00,  2.40s/it]


[valid] avg loss: 1.477, avg accu: 0.569

Epoch: 20


[train] loss: 1.077, accu: 0.611: 100%|██████████| 146/146 [04:36<00:00,  1.89s/it]


[train] avg loss: 1.007, avg accu: 0.623, lr: 0.00050


[valid] loss: 0.970, accu: 0.714: 100%|██████████| 32/32 [01:15<00:00,  2.34s/it]


[valid] avg loss: 1.541, avg accu: 0.553

Epoch: 21


[train] loss: 0.833, accu: 0.685: 100%|██████████| 146/146 [04:39<00:00,  1.91s/it]


[train] avg loss: 0.972, avg accu: 0.635, lr: 0.00050


[valid] loss: 1.590, accu: 0.524: 100%|██████████| 32/32 [01:17<00:00,  2.42s/it]


[valid] avg loss: 1.421, avg accu: 0.550

Epoch: 22


[train] loss: 0.800, accu: 0.722: 100%|██████████| 146/146 [04:36<00:00,  1.90s/it]


[train] avg loss: 0.959, avg accu: 0.636, lr: 0.00050


[valid] loss: 1.097, accu: 0.619: 100%|██████████| 32/32 [01:16<00:00,  2.39s/it]


[valid] avg loss: 1.127, avg accu: 0.623

Epoch: 23


[train] loss: 0.947, accu: 0.630: 100%|██████████| 146/146 [04:37<00:00,  1.90s/it]


[train] avg loss: 0.935, avg accu: 0.649, lr: 0.00050


[valid] loss: 1.297, accu: 0.571: 100%|██████████| 32/32 [01:15<00:00,  2.37s/it]


[valid] avg loss: 1.280, avg accu: 0.586

Epoch: 24


[train] loss: 1.034, accu: 0.630: 100%|██████████| 146/146 [04:38<00:00,  1.91s/it]


[train] avg loss: 0.914, avg accu: 0.656, lr: 0.00025


[valid] loss: 1.245, accu: 0.619: 100%|██████████| 32/32 [01:16<00:00,  2.38s/it]


[valid] avg loss: 1.302, avg accu: 0.630

Epoch: 25


[train] loss: 0.419, accu: 0.815: 100%|██████████| 146/146 [04:36<00:00,  1.90s/it]


[train] avg loss: 0.836, avg accu: 0.682, lr: 0.00025


[valid] loss: 0.802, accu: 0.762: 100%|██████████| 32/32 [01:15<00:00,  2.35s/it]


[valid] avg loss: 1.133, avg accu: 0.663

Epoch: 26


[train] loss: 0.837, accu: 0.667: 100%|██████████| 146/146 [04:39<00:00,  1.91s/it]


[train] avg loss: 0.789, avg accu: 0.695, lr: 0.00025


[valid] loss: 0.821, accu: 0.714: 100%|██████████| 32/32 [01:15<00:00,  2.37s/it]


[valid] avg loss: 1.158, avg accu: 0.656

Epoch: 27


[train] loss: 0.757, accu: 0.704: 100%|██████████| 146/146 [04:33<00:00,  1.88s/it]


[train] avg loss: 0.788, avg accu: 0.697, lr: 0.00025


[valid] loss: 0.719, accu: 0.762: 100%|██████████| 32/32 [01:16<00:00,  2.40s/it]


[valid] avg loss: 1.048, avg accu: 0.700

Epoch: 28


[train] loss: 0.611, accu: 0.759: 100%|██████████| 146/146 [04:39<00:00,  1.92s/it]


[train] avg loss: 0.773, avg accu: 0.703, lr: 0.00025


[valid] loss: 0.746, accu: 0.762: 100%|██████████| 32/32 [01:16<00:00,  2.40s/it]


[valid] avg loss: 1.191, avg accu: 0.664

Epoch: 29


[train] loss: 0.757, accu: 0.759: 100%|██████████| 146/146 [04:42<00:00,  1.94s/it]


[train] avg loss: 0.768, avg accu: 0.706, lr: 0.00025


[valid] loss: 0.510, accu: 0.857: 100%|██████████| 32/32 [01:16<00:00,  2.38s/it]


[valid] avg loss: 1.111, avg accu: 0.679

Epoch: 30


[train] loss: 0.814, accu: 0.648: 100%|██████████| 146/146 [04:35<00:00,  1.89s/it]


[train] avg loss: 0.738, avg accu: 0.715, lr: 0.00025


[valid] loss: 1.342, accu: 0.619: 100%|██████████| 32/32 [01:17<00:00,  2.43s/it]


[valid] avg loss: 1.046, avg accu: 0.694

Epoch: 31


[train] loss: 0.652, accu: 0.815: 100%|██████████| 146/146 [04:39<00:00,  1.91s/it]


[train] avg loss: 0.721, avg accu: 0.720, lr: 0.00025


[valid] loss: 0.658, accu: 0.714: 100%|██████████| 32/32 [01:16<00:00,  2.40s/it]


[valid] avg loss: 1.245, avg accu: 0.665

Epoch: 32


[train] loss: 0.742, accu: 0.741: 100%|██████████| 146/146 [04:38<00:00,  1.91s/it]


[train] avg loss: 0.719, avg accu: 0.723, lr: 0.00025


[valid] loss: 1.295, accu: 0.476: 100%|██████████| 32/32 [01:15<00:00,  2.36s/it]


[valid] avg loss: 1.062, avg accu: 0.671

Epoch: 33


[train] loss: 0.677, accu: 0.852: 100%|██████████| 146/146 [04:42<00:00,  1.93s/it]


[train] avg loss: 0.701, avg accu: 0.729, lr: 0.00025


[valid] loss: 1.180, accu: 0.667: 100%|██████████| 32/32 [01:17<00:00,  2.44s/it]


[valid] avg loss: 1.018, avg accu: 0.700

Epoch: 34


[train] loss: 0.745, accu: 0.667: 100%|██████████| 146/146 [04:38<00:00,  1.91s/it]


[train] avg loss: 0.696, avg accu: 0.732, lr: 0.00025


[valid] loss: 0.675, accu: 0.857: 100%|██████████| 32/32 [01:17<00:00,  2.43s/it]


[valid] avg loss: 1.180, avg accu: 0.663

Epoch: 35


[train] loss: 0.790, accu: 0.685: 100%|██████████| 146/146 [04:44<00:00,  1.95s/it]


[train] avg loss: 0.675, avg accu: 0.736, lr: 0.00025


[valid] loss: 1.367, accu: 0.714: 100%|██████████| 32/32 [01:16<00:00,  2.40s/it]


[valid] avg loss: 1.334, avg accu: 0.671

Epoch: 36


[train] loss: 0.698, accu: 0.778: 100%|██████████| 146/146 [04:43<00:00,  1.94s/it]


[train] avg loss: 0.681, avg accu: 0.736, lr: 0.00025


[valid] loss: 1.261, accu: 0.714: 100%|██████████| 32/32 [01:17<00:00,  2.41s/it]


[valid] avg loss: 1.090, avg accu: 0.716

Epoch: 37


[train] loss: 0.501, accu: 0.759: 100%|██████████| 146/146 [04:43<00:00,  1.94s/it]


[train] avg loss: 0.655, avg accu: 0.748, lr: 0.00025


[valid] loss: 1.539, accu: 0.571: 100%|██████████| 32/32 [01:17<00:00,  2.43s/it]


[valid] avg loss: 1.070, avg accu: 0.693

Epoch: 38


[train] loss: 0.569, accu: 0.815: 100%|██████████| 146/146 [04:44<00:00,  1.95s/it]


[train] avg loss: 0.655, avg accu: 0.743, lr: 0.00025


[valid] loss: 1.529, accu: 0.571: 100%|██████████| 32/32 [01:16<00:00,  2.39s/it]


[valid] avg loss: 1.087, avg accu: 0.694

Epoch: 39


[train] loss: 0.605, accu: 0.815: 100%|██████████| 146/146 [04:41<00:00,  1.93s/it]


[train] avg loss: 0.641, avg accu: 0.749, lr: 0.00013


[valid] loss: 1.312, accu: 0.619: 100%|██████████| 32/32 [01:17<00:00,  2.41s/it]


[valid] avg loss: 1.035, avg accu: 0.704

Epoch: 40


[train] loss: 0.494, accu: 0.815: 100%|██████████| 146/146 [04:39<00:00,  1.92s/it]


[train] avg loss: 0.594, avg accu: 0.761, lr: 0.00013


[valid] loss: 1.105, accu: 0.714: 100%|██████████| 32/32 [01:16<00:00,  2.39s/it]


[valid] avg loss: 1.114, avg accu: 0.708

Epoch: 41


[train] loss: 0.609, accu: 0.789:  82%|████████▏ | 120/146 [03:51<00:50,  1.94s/it]

In [None]:
test_set = FoodDataset(
    './food11/test',
    transform=tf['valid'],
    testing=True,
)
test_loader = DataLoader(
    test_set,
    batch_size=128,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

In [None]:
model_best = Classifier().to(device)
model_best.load_state_dict(torch.load(CHECKPOINT))
model_best.eval()
prediction = []
with torch.no_grad():
    for (image, _) in test_loader:
        image = image.to(device)
        logits = model_best(image)
        if N_TEST_TIME_AUG > 0:
            logits = nn.functional.softmax(logits, dim=1)
            for _ in range(N_TEST_TIME_AUG):
                logits += nn.functional.softmax(model_best(tf['aug'](image)), dim=1)
        pred = np.argmax(
            logits.cpu().data.numpy(),
            axis=1
        )
        prediction += pred.squeeze().tolist()

def pad4(i):
    return '0' * (4 - len(str(i))) + str(i)

df = pd.DataFrame()
df['Id'] = [pad4(i) for i in range(1, len(test_set) + 1)]
df['Category'] = prediction
df.to_csv('submission.csv',index = False)