# FOOD IMAGE CLASSFICATION

食物图片分类

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!unzip -q /content/drive/MyDrive/ML2022/food-11.zip

Mounted at /content/drive


## 导入包

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder
from tqdm import tqdm
import matplotlib.pyplot as plt

## 数据集的导入

### Augmentation

In [None]:
train_tfm = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

test_tfm = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

### `Dataset` & `DataLoader`

In [None]:
batch_size = 128

train_set = DatasetFolder("food-11/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
valid_set = DatasetFolder("food-11/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
test_set = DatasetFolder("food-11/testing", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

## 模型设置

### 简单 CNN

In [None]:
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # torch.nn.MaxPool2d(kernel_size, stride, padding)
        # n = (W - kernel_size + 2 * padding) / stride + 1

        # input image size: [3, 224, 224]
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(4, 4, 0),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(256 * 14 * 14, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 11)
        )

    def forward(self, x):
        # input (x): [batch_size, 3, 224, 224]
        # output: [batch_size, 11]

        # Extract features by convolutional layers.
        x = self.cnn_layers(x)

        # The extracted feature map must be flatten before going to fully-connected layers.
        x = x.flatten(1)

        # The features are transformed by fully-connected layers to obtain the final logits.
        x = self.fc_layers(x)
        return x

### Config

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize a model, and put it on the device specified.
model = Classifier().to(device)
model.device = device

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-5)

# The number of training epochs.
n_epochs = 100

## 训练

In [None]:
final_acc = []
train_final_acc = []
epochs = [i+1 for i in range(n_epochs)]
best_acc = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()

    # These are used to record information in training.
    train_loss = []
    train_accs = []

    # Iterate the training set by batches.
    for batch in tqdm(train_loader):

        imgs, labels = batch

        # Forward the data. (Make sure data and model are on the same device.)
        logits = model(imgs.to(device))

        # Calculate the cross-entropy loss.
        loss = criterion(logits, labels.to(device))

        # Gradients stored in the parameters in the previous step should be cleared out first.
        optimizer.zero_grad()

        # Compute the gradients for parameters.
        loss.backward()

        # Clip the gradient norms for stable training.
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)

        # Update the parameters with computed gradients.
        optimizer.step()

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # Record the loss and accuracy.
        train_loss.append(loss.item())
        train_accs.append(acc)

    # The average loss and accuracy of the training set is the average of the recorded values.
    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)
    train_final_acc.append(train_acc.item())
    # Print the information.
    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")

    # ---------- Validation ----------
    model.eval()

    # These are used to record information in validation.
    valid_loss = []
    valid_accs = []

    # Iterate the validation set by batches.
    for batch in tqdm(valid_loader):

        # A batch consists of image data and corresponding labels.
        imgs, labels = batch

        # We don't need gradient in validation.
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits = model(imgs.to(device))

        # We can still compute the loss (but not the gradient).
        loss = criterion(logits, labels.to(device))

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # Record the loss and accuracy.
        valid_loss.append(loss.item())
        valid_accs.append(acc)

    # The average loss and accuracy for entire validation set is the average of the recorded values.
    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)
    if len(final_acc) == 0 or valid_acc >= max(final_acc):
        best_acc = valid_acc
        print(f"[ BEST Valid acc | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")
    final_acc.append(valid_acc.item())
    torch.save(model.state_dict(), './model.pth')
    torch.save(optimizer.state_dict(), './optimizer.pth')

100%|██████████| 25/25 [00:12<00:00,  2.02it/s]


[ Train | 001/100 ] loss = 2.70709, acc = 0.16437


100%|██████████| 6/6 [00:02<00:00,  2.25it/s]


[ BEST Valid acc | 001/100 ] loss = 3.16860, acc = 0.18906


100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 002/100 ] loss = 1.99226, acc = 0.28062


100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
100%|██████████| 25/25 [00:12<00:00,  2.03it/s]


[ Train | 003/100 ] loss = 1.89130, acc = 0.33531


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]


[ BEST Valid acc | 003/100 ] loss = 7.97834, acc = 0.22969


100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 004/100 ] loss = 1.77952, acc = 0.37000


100%|██████████| 6/6 [00:02<00:00,  2.44it/s]
100%|██████████| 25/25 [00:12<00:00,  2.03it/s]


[ Train | 005/100 ] loss = 1.61020, acc = 0.44656


100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 006/100 ] loss = 1.54198, acc = 0.46469


100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 007/100 ] loss = 1.39017, acc = 0.54188


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 008/100 ] loss = 1.21895, acc = 0.59969


100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
100%|██████████| 25/25 [00:12<00:00,  2.02it/s]


[ Train | 009/100 ] loss = 1.12199, acc = 0.62062


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]


[ BEST Valid acc | 009/100 ] loss = 13.00376, acc = 0.23281


100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 010/100 ] loss = 0.97557, acc = 0.68656


100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 011/100 ] loss = 0.87769, acc = 0.71156


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.07it/s]


[ Train | 012/100 ] loss = 0.70699, acc = 0.78156


100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
100%|██████████| 25/25 [00:12<00:00,  2.07it/s]


[ Train | 013/100 ] loss = 0.60324, acc = 0.80281


100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 014/100 ] loss = 0.50331, acc = 0.84750


100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 015/100 ] loss = 0.42428, acc = 0.87500


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]


[ BEST Valid acc | 015/100 ] loss = 14.89803, acc = 0.25495


100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 016/100 ] loss = 0.42740, acc = 0.87469


100%|██████████| 6/6 [00:02<00:00,  2.25it/s]
100%|██████████| 25/25 [00:12<00:00,  2.02it/s]


[ Train | 017/100 ] loss = 0.33294, acc = 0.91094


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 018/100 ] loss = 0.22328, acc = 0.94406


100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 019/100 ] loss = 0.17750, acc = 0.95875


100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 020/100 ] loss = 0.11581, acc = 0.98312


100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 25/25 [00:12<00:00,  2.07it/s]


[ Train | 021/100 ] loss = 0.12346, acc = 0.96969


100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
100%|██████████| 25/25 [00:12<00:00,  2.03it/s]


[ Train | 022/100 ] loss = 0.10125, acc = 0.97781


100%|██████████| 6/6 [00:02<00:00,  2.45it/s]
100%|██████████| 25/25 [00:12<00:00,  2.07it/s]


[ Train | 023/100 ] loss = 0.08225, acc = 0.98656


100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 024/100 ] loss = 0.05305, acc = 0.99125


100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
100%|██████████| 25/25 [00:12<00:00,  2.02it/s]


[ Train | 025/100 ] loss = 0.08698, acc = 0.97938


100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 026/100 ] loss = 0.07621, acc = 0.98531


100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 027/100 ] loss = 0.05323, acc = 0.99187


100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 028/100 ] loss = 0.03464, acc = 0.99687


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 029/100 ] loss = 0.04972, acc = 0.99094


100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 030/100 ] loss = 0.04013, acc = 0.98937


100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 031/100 ] loss = 0.06856, acc = 0.98469


100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 032/100 ] loss = 0.09132, acc = 0.97406


100%|██████████| 6/6 [00:02<00:00,  2.45it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 033/100 ] loss = 0.08992, acc = 0.97719


100%|██████████| 6/6 [00:02<00:00,  2.25it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 034/100 ] loss = 0.06027, acc = 0.98500


100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 035/100 ] loss = 0.04195, acc = 0.99375


100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 25/25 [00:12<00:00,  2.01it/s]


[ Train | 036/100 ] loss = 0.04550, acc = 0.98750


100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
100%|██████████| 25/25 [00:12<00:00,  2.07it/s]


[ Train | 037/100 ] loss = 0.03755, acc = 0.99281


100%|██████████| 6/6 [00:02<00:00,  2.36it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 038/100 ] loss = 0.01558, acc = 0.99906


100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 039/100 ] loss = 0.01646, acc = 0.99469


100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 040/100 ] loss = 0.02784, acc = 0.99312


100%|██████████| 6/6 [00:02<00:00,  2.27it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 041/100 ] loss = 0.01524, acc = 0.99875


100%|██████████| 6/6 [00:02<00:00,  2.30it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 042/100 ] loss = 0.01360, acc = 0.99875


100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 043/100 ] loss = 0.01887, acc = 0.99719


100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 044/100 ] loss = 0.02299, acc = 0.99437


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 25/25 [00:12<00:00,  2.07it/s]


[ Train | 045/100 ] loss = 0.04775, acc = 0.98594


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 046/100 ] loss = 0.01903, acc = 0.99687


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 047/100 ] loss = 0.00739, acc = 0.99906


100%|██████████| 6/6 [00:02<00:00,  2.47it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 048/100 ] loss = 0.00497, acc = 1.00000


100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 049/100 ] loss = 0.05000, acc = 0.98219


100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 050/100 ] loss = 0.07084, acc = 0.97688


100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 051/100 ] loss = 0.09431, acc = 0.97062


100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 052/100 ] loss = 0.08538, acc = 0.97094


100%|██████████| 6/6 [00:02<00:00,  2.47it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 053/100 ] loss = 0.03325, acc = 0.98781


100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 054/100 ] loss = 0.03090, acc = 0.99312


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 055/100 ] loss = 0.01393, acc = 0.99750


100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 056/100 ] loss = 0.00785, acc = 0.99969


100%|██████████| 6/6 [00:02<00:00,  2.24it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 057/100 ] loss = 0.09871, acc = 0.97062


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 058/100 ] loss = 0.05008, acc = 0.98375


100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
100%|██████████| 25/25 [00:12<00:00,  2.01it/s]


[ Train | 059/100 ] loss = 0.03268, acc = 0.98687


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 060/100 ] loss = 0.03970, acc = 0.98750


100%|██████████| 6/6 [00:02<00:00,  2.46it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 061/100 ] loss = 0.05767, acc = 0.98438


100%|██████████| 6/6 [00:02<00:00,  2.28it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 062/100 ] loss = 0.03442, acc = 0.99406


100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 25/25 [00:12<00:00,  2.03it/s]


[ Train | 063/100 ] loss = 0.04743, acc = 0.98375


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 064/100 ] loss = 0.02746, acc = 0.99250


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 065/100 ] loss = 0.09845, acc = 0.96812


100%|██████████| 6/6 [00:02<00:00,  2.25it/s]
100%|██████████| 25/25 [00:12<00:00,  2.01it/s]


[ Train | 066/100 ] loss = 0.06398, acc = 0.98750


100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 25/25 [00:12<00:00,  2.07it/s]


[ Train | 067/100 ] loss = 0.04521, acc = 0.98500


100%|██████████| 6/6 [00:02<00:00,  2.22it/s]
100%|██████████| 25/25 [00:12<00:00,  2.03it/s]


[ Train | 068/100 ] loss = 0.05030, acc = 0.98875


100%|██████████| 6/6 [00:02<00:00,  2.44it/s]
100%|██████████| 25/25 [00:12<00:00,  2.03it/s]


[ Train | 069/100 ] loss = 0.02320, acc = 0.99219


100%|██████████| 6/6 [00:02<00:00,  2.28it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 070/100 ] loss = 0.00542, acc = 0.99937


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.01it/s]


[ Train | 071/100 ] loss = 0.00193, acc = 1.00000


100%|██████████| 6/6 [00:02<00:00,  2.46it/s]
100%|██████████| 25/25 [00:12<00:00,  2.02it/s]


[ Train | 072/100 ] loss = 0.00284, acc = 0.99969


100%|██████████| 6/6 [00:02<00:00,  2.45it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 073/100 ] loss = 0.00176, acc = 1.00000


100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 25/25 [00:12<00:00,  2.02it/s]


[ Train | 074/100 ] loss = 0.00086, acc = 1.00000


100%|██████████| 6/6 [00:02<00:00,  2.36it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 075/100 ] loss = 0.00067, acc = 1.00000


100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 076/100 ] loss = 0.00071, acc = 1.00000


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 25/25 [00:12<00:00,  2.07it/s]


[ Train | 077/100 ] loss = 0.06271, acc = 0.98969


100%|██████████| 6/6 [00:02<00:00,  2.27it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 078/100 ] loss = 0.07775, acc = 0.97719


100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 25/25 [00:12<00:00,  2.01it/s]


[ Train | 079/100 ] loss = 0.08284, acc = 0.97312


100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
100%|██████████| 25/25 [00:12<00:00,  2.02it/s]


[ Train | 080/100 ] loss = 0.04143, acc = 0.98594


100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 081/100 ] loss = 0.02718, acc = 0.99125


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 082/100 ] loss = 0.02112, acc = 0.99375


100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 25/25 [00:12<00:00,  2.03it/s]


[ Train | 083/100 ] loss = 0.01113, acc = 0.99625


100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 084/100 ] loss = 0.00589, acc = 0.99844


100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 085/100 ] loss = 0.00235, acc = 0.99969


100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 25/25 [00:12<00:00,  2.08it/s]


[ Train | 086/100 ] loss = 0.00171, acc = 0.99969


100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 087/100 ] loss = 0.00074, acc = 1.00000


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 088/100 ] loss = 0.00052, acc = 1.00000


100%|██████████| 6/6 [00:02<00:00,  2.28it/s]
100%|██████████| 25/25 [00:12<00:00,  2.00it/s]


[ Train | 089/100 ] loss = 0.00233, acc = 1.00000


100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
100%|██████████| 25/25 [00:12<00:00,  2.03it/s]


[ Train | 090/100 ] loss = 0.04297, acc = 0.98719


100%|██████████| 6/6 [00:02<00:00,  2.45it/s]
100%|██████████| 25/25 [00:12<00:00,  2.03it/s]


[ Train | 091/100 ] loss = 0.03636, acc = 0.98906


100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 092/100 ] loss = 0.03471, acc = 0.98781


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 25/25 [00:12<00:00,  2.06it/s]


[ Train | 093/100 ] loss = 0.07228, acc = 0.98156


100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
100%|██████████| 25/25 [00:12<00:00,  2.07it/s]


[ Train | 094/100 ] loss = 0.09875, acc = 0.97000


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 095/100 ] loss = 0.12898, acc = 0.95719


100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 096/100 ] loss = 0.10397, acc = 0.96219


100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 097/100 ] loss = 0.12092, acc = 0.96625


100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
100%|██████████| 25/25 [00:12<00:00,  2.07it/s]


[ Train | 098/100 ] loss = 0.03982, acc = 0.98844


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 25/25 [00:12<00:00,  2.05it/s]


[ Train | 099/100 ] loss = 0.01564, acc = 0.99531


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 25/25 [00:12<00:00,  2.04it/s]


[ Train | 100/100 ] loss = 0.00785, acc = 0.99781


100%|██████████| 6/6 [00:02<00:00,  2.41it/s]


## 结果输出：

In [None]:
fig = plt.figure()
plt.plot(epochs, train_final_acc, color='blue')
plt.plot(epochs, final_acc, color='red')
plt.legend(['Train Acc', 'Test Acc'], loc='upper right')
plt.xlabel('Epochs Number')
plt.ylabel('Accuracy')
plt.title('Accuracy')
fig.savefig("res.png")
plt.show()


with open("prediction.csv", "w") as f:
    f.write("Method,Acc\n")
    f.write(f"CNN whith no supervised learning,{best_acc:.3f}\n")

model.eval()

predictions = []

for batch in tqdm(test_loader):
    imgs, labels = batch

    with torch.no_grad():
        logits = model(imgs.to(device))

    predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

with open("predict.csv", "w") as f:

    f.write("Id,Category\n")
    for i, pred in  enumerate(predictions):
        f.write(f"{i},{pred}\n")