In [1]:
# Download the dataset
!gdown --id '1awF7pZ9Dz7X1jn1_QAiKN-_v56veCEKy' --output food-11.zip
!unzip -q food-11.zip

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision.datasets import DatasetFolder

# tqdm 顯示進度條的工具
from tqdm.auto import tqdm

Downloading...
From: https://drive.google.com/uc?id=1awF7pZ9Dz7X1jn1_QAiKN-_v56veCEKy
To: /content/food-11.zip
100% 963M/963M [00:17<00:00, 56.5MB/s]


In [2]:
import torchvision.transforms as transforms

# 在訓練中進行數據增強
train_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),   # 隨機水平翻轉
    transforms.RandomRotation(15),       # 隨機旋轉（角度為15度）
    transforms.RandomCrop((120, 120)),   # 隨機裁剪到120x120大小
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),   # 隨機顏色調整
    transforms.ToTensor(),
])

# 測試和驗證中只進行圖像轉換
test_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])


In [3]:
# Batch size for training, validation, and testing.
# A greater batch size usually gives a more stable gradient.
# But the GPU memory is limited, so please adjust it carefully.
batch_size = 64

# Construct datasets.
# The argument "loader" tells how torchvision reads the data.
train_set = DatasetFolder("food-11/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
valid_set = DatasetFolder("food-11/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
unlabeled_set = DatasetFolder("food-11/training/unlabeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
test_set = DatasetFolder("food-11/testing", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)

# Construct data loaders.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)



In [4]:

import torchvision.models as models
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        # 載入 ResNet-18 模型
        self.resnet = models.resnet18(pretrained=False)

        # 替換最後一層全連接層
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, 11)

    def forward(self, x):
        x = self.resnet(x)
        return x


In [5]:
def get_pseudo_labels(dataset, model, threshold=0.65):
    # This functions generates pseudo-labels of a dataset using given model.
    # It returns an instance of DatasetFolder containing images whose prediction confidences exceed a given threshold.
    # You are NOT allowed to use any models trained on external data for pseudo-labeling.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Construct a data loader.
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Make sure the model is in eval mode.
    model.eval()
    # Define softmax function.
    softmax = nn.Softmax(dim=-1)

    # Iterate over the dataset by batches.
    for batch in tqdm(data_loader):
        img, _ = batch

        # Forward the data
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits = model(img.to(device))

        # Obtain the probability distributions by applying softmax on logits.
        probs = softmax(logits)

        # ---------- TODO ----------
        # Filter the data and construct a new dataset.

    # # Turn off the eval mode.
    model.train()
    return dataset

In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# 假設已經定義了以下函式和變數：
# get_pseudo_labels(unlabeled_set, model): 給定未標記數據和訓練好的模型，返回使用模型生成的伪標籤的數據集。
# train_set: 有標籤的訓練數據集。
# unlabeled_set: 未標記的數據集。
# valid_loader: 驗證數據的 DataLoader。
# batch_size: 批次大小。

# "cuda" 只有當有 GPU 可用時使用。
device = "cuda" if torch.cuda.is_available() else "cpu"

# 初始化模型並將其放置在指定的設備上。
model = Classifier().to(device)
model.device = device

# 對於分類任務，我們使用交叉熵作為性能衡量標準。
criterion = nn.CrossEntropyLoss()

# 初始化優化器，您可以自行微調一些超參數，如學習率。
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-5)

# 訓練的 epoch 數量。
n_epochs = 80

# 是否進行半監督學習。
do_semi = False

for epoch in range(n_epochs):
    # ---------- TODO ----------
    # 在每個 epoch 中，為半監督學習重新標記未標記數據集。
    # 然後，您可以將有標籤數據集和伪標籤數據集結合起來進行訓練。
    if do_semi:
        # 使用訓練好的模型為未標記數據獲取伪標籤。
        pseudo_set = get_pseudo_labels(unlabeled_set, model)

        # 構建新的數據集和數據加載器用於訓練。
        # 這僅在半監督學習中使用。
        concat_dataset = ConcatDataset([train_set, pseudo_set])
        train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

    # ---------- 訓練 ----------
    # 在訓練之前確保模型處於訓練模式。
    model.train()

    # 這些變量用於記錄訓練過程中的信息。
    train_loss = []
    train_accs = []

    # 逐批次迭代訓練數據集。
    for batch in tqdm(train_loader):

        # 一批次包含圖像數據和相應的標籤。
        imgs, labels = batch

        # 將數據傳遞到模型中（確保數據和模型在同一設備上）。
        logits = model(imgs.to(device))

        # 計算交叉熵損失。
        # 在計算交叉熵之前，我們不需要在 logits 上應用 softmax，因為它會自動完成。
        loss = criterion(logits, labels.to(device))

        # 清除先前步驟中參數中存儲的梯度。
        optimizer.zero_grad()

        # 計算參數的梯度。
        loss.backward()

        # 限制梯度的範數以實現穩定訓練。
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)

        # 使用計算出的梯度更新參數。
        optimizer.step()

        # 計算當前批次的準確率。
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # 記錄損失和準確率。
        train_loss.append(loss.item())
        train_accs.append(acc)

    # 訓練集的平均損失和準確率是記錄值的平均值。
    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)

    # 輸出訓練信息。
    print(f"[ 訓練 | {epoch + 1:03d}/{n_epochs:03d} ] 損失 = {train_loss:.5f}, 準確率 = {train_acc:.5f}")

    # ---------- 驗證 ----------
    # 確保模型處於評估模式，以便一些模塊如 dropout 被禁用並正常工作。
    model.eval()

    # 這些變量用於記錄驗證過程中的信息。
    valid_loss = []
    valid_accs = []

    # 逐批次迭代驗證數據集。
    for batch in tqdm(valid_loader):

        # 一批次包含圖像數據和相應的標籤。
        imgs, labels = batch

        # 在驗證過程中我們不需要梯度。
        # 使用 torch.no_grad() 可以加速前向過程。
        with torch.no_grad():
            logits = model(imgs.to(device))

        # 我們仍然可以計算損失（但沒有梯度）。
        loss = criterion(logits, labels.to(device))

        # 計算當前批次的準確率。
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # 記錄損失和準確率。
        valid_loss.append(loss.item())
        valid_accs.append(acc)

    # 整個驗證集的平均損失和準確率是記錄值的平均值。
    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)

    # 輸出驗證信息。
    print(f"[ 驗證 | {epoch + 1:03d}/{n_epochs:03d} ] 損失 = {valid_loss:.5f}, 準確率 = {valid_acc:.5f}")


100%|██████████| 49/49 [00:35<00:00,  1.40it/s]


[ 訓練 | 001/080 ] 損失 = 2.28793, 準確率 = 0.18112


100%|██████████| 11/11 [00:05<00:00,  1.96it/s]


[ 驗證 | 001/080 ] 損失 = 2.87774, 準確率 = 0.17670


100%|██████████| 49/49 [00:30<00:00,  1.63it/s]


[ 訓練 | 002/080 ] 損失 = 2.05977, 準確率 = 0.27392


100%|██████████| 11/11 [00:06<00:00,  1.64it/s]


[ 驗證 | 002/080 ] 損失 = 2.11986, 準確率 = 0.26136


100%|██████████| 49/49 [00:28<00:00,  1.72it/s]


[ 訓練 | 003/080 ] 損失 = 1.94652, 準確率 = 0.31983


100%|██████████| 11/11 [00:06<00:00,  1.73it/s]


[ 驗證 | 003/080 ] 損失 = 2.06885, 準確率 = 0.29545


100%|██████████| 49/49 [00:30<00:00,  1.62it/s]


[ 訓練 | 004/080 ] 損失 = 1.86542, 準確率 = 0.33514


100%|██████████| 11/11 [00:05<00:00,  1.92it/s]


[ 驗證 | 004/080 ] 損失 = 2.05259, 準確率 = 0.29801


100%|██████████| 49/49 [00:30<00:00,  1.63it/s]


[ 訓練 | 005/080 ] 損失 = 1.79919, 準確率 = 0.37245


100%|██████████| 11/11 [00:07<00:00,  1.52it/s]


[ 驗證 | 005/080 ] 損失 = 2.18315, 準確率 = 0.27386


100%|██████████| 49/49 [00:29<00:00,  1.65it/s]


[ 訓練 | 006/080 ] 損失 = 1.69637, 準確率 = 0.40338


100%|██████████| 11/11 [00:07<00:00,  1.55it/s]


[ 驗證 | 006/080 ] 損失 = 1.71691, 準確率 = 0.41108


100%|██████████| 49/49 [00:29<00:00,  1.69it/s]


[ 訓練 | 007/080 ] 損失 = 1.62096, 準確率 = 0.43144


100%|██████████| 11/11 [00:05<00:00,  1.92it/s]


[ 驗證 | 007/080 ] 損失 = 1.78315, 準確率 = 0.40000


100%|██████████| 49/49 [00:30<00:00,  1.60it/s]


[ 訓練 | 008/080 ] 損失 = 1.62968, 準確率 = 0.42570


100%|██████████| 11/11 [00:06<00:00,  1.63it/s]


[ 驗證 | 008/080 ] 損失 = 1.75587, 準確率 = 0.38807


100%|██████████| 49/49 [00:28<00:00,  1.71it/s]


[ 訓練 | 009/080 ] 損失 = 1.54089, 準確率 = 0.45631


100%|██████████| 11/11 [00:06<00:00,  1.59it/s]


[ 驗證 | 009/080 ] 損失 = 1.73740, 準確率 = 0.41591


100%|██████████| 49/49 [00:29<00:00,  1.69it/s]


[ 訓練 | 010/080 ] 損失 = 1.50913, 準確率 = 0.48565


100%|██████████| 11/11 [00:05<00:00,  1.95it/s]


[ 驗證 | 010/080 ] 損失 = 1.84337, 準確率 = 0.38920


100%|██████████| 49/49 [00:28<00:00,  1.71it/s]


[ 訓練 | 011/080 ] 損失 = 1.44320, 準確率 = 0.49936


100%|██████████| 11/11 [00:07<00:00,  1.54it/s]


[ 驗證 | 011/080 ] 損失 = 2.22301, 準確率 = 0.31761


100%|██████████| 49/49 [00:28<00:00,  1.70it/s]


[ 訓練 | 012/080 ] 損失 = 1.37988, 準確率 = 0.51786


100%|██████████| 11/11 [00:05<00:00,  1.91it/s]


[ 驗證 | 012/080 ] 損失 = 2.09208, 準確率 = 0.31449


100%|██████████| 49/49 [00:28<00:00,  1.69it/s]


[ 訓練 | 013/080 ] 損失 = 1.32165, 準確率 = 0.55102


100%|██████████| 11/11 [00:06<00:00,  1.66it/s]


[ 驗證 | 013/080 ] 損失 = 1.73576, 準確率 = 0.41108


100%|██████████| 49/49 [00:30<00:00,  1.58it/s]


[ 訓練 | 014/080 ] 損失 = 1.27256, 準確率 = 0.55899


100%|██████████| 11/11 [00:07<00:00,  1.48it/s]


[ 驗證 | 014/080 ] 損失 = 1.88122, 準確率 = 0.40170


100%|██████████| 49/49 [00:28<00:00,  1.71it/s]


[ 訓練 | 015/080 ] 損失 = 1.21484, 準確率 = 0.57781


100%|██████████| 11/11 [00:05<00:00,  1.92it/s]


[ 驗證 | 015/080 ] 損失 = 1.78207, 準確率 = 0.45710


100%|██████████| 49/49 [00:28<00:00,  1.71it/s]


[ 訓練 | 016/080 ] 損失 = 1.17649, 準確率 = 0.59407


100%|██████████| 11/11 [00:07<00:00,  1.50it/s]


[ 驗證 | 016/080 ] 損失 = 1.63987, 準確率 = 0.46733


100%|██████████| 49/49 [00:29<00:00,  1.66it/s]


[ 訓練 | 017/080 ] 損失 = 1.12439, 準確率 = 0.62054


100%|██████████| 11/11 [00:06<00:00,  1.76it/s]


[ 驗證 | 017/080 ] 損失 = 1.66086, 準確率 = 0.48040


100%|██████████| 49/49 [00:29<00:00,  1.66it/s]


[ 訓練 | 018/080 ] 損失 = 1.09468, 準確率 = 0.62213


100%|██████████| 11/11 [00:05<00:00,  1.96it/s]


[ 驗證 | 018/080 ] 損失 = 2.69344, 準確率 = 0.36960


100%|██████████| 49/49 [00:28<00:00,  1.69it/s]


[ 訓練 | 019/080 ] 損失 = 1.00170, 準確率 = 0.65912


100%|██████████| 11/11 [00:07<00:00,  1.51it/s]


[ 驗證 | 019/080 ] 損失 = 1.77774, 準確率 = 0.42955


100%|██████████| 49/49 [00:30<00:00,  1.60it/s]


[ 訓練 | 020/080 ] 損失 = 0.98786, 準確率 = 0.65338


100%|██████████| 11/11 [00:05<00:00,  1.92it/s]


[ 驗證 | 020/080 ] 損失 = 2.13673, 準確率 = 0.40795


100%|██████████| 49/49 [00:30<00:00,  1.62it/s]


[ 訓練 | 021/080 ] 損失 = 0.89629, 準確率 = 0.69101


100%|██████████| 11/11 [00:06<00:00,  1.80it/s]


[ 驗證 | 021/080 ] 損失 = 1.95492, 準確率 = 0.40938


100%|██████████| 49/49 [00:28<00:00,  1.69it/s]


[ 訓練 | 022/080 ] 損失 = 0.84854, 準確率 = 0.71205


100%|██████████| 11/11 [00:07<00:00,  1.51it/s]


[ 驗證 | 022/080 ] 損失 = 1.72694, 準確率 = 0.47415


100%|██████████| 49/49 [00:29<00:00,  1.69it/s]


[ 訓練 | 023/080 ] 損失 = 0.77720, 準確率 = 0.73182


100%|██████████| 11/11 [00:05<00:00,  1.92it/s]


[ 驗證 | 023/080 ] 損失 = 1.73426, 準確率 = 0.50483


100%|██████████| 49/49 [00:28<00:00,  1.69it/s]


[ 訓練 | 024/080 ] 損失 = 0.80166, 準確率 = 0.72258


100%|██████████| 11/11 [00:07<00:00,  1.49it/s]


[ 驗證 | 024/080 ] 損失 = 1.86387, 準確率 = 0.47841


100%|██████████| 49/49 [00:28<00:00,  1.69it/s]


[ 訓練 | 025/080 ] 損失 = 0.70966, 準確率 = 0.75223


100%|██████████| 11/11 [00:06<00:00,  1.73it/s]


[ 驗證 | 025/080 ] 損失 = 2.80053, 準確率 = 0.36193


100%|██████████| 49/49 [00:31<00:00,  1.56it/s]


[ 訓練 | 026/080 ] 損失 = 0.73356, 準確率 = 0.74585


100%|██████████| 11/11 [00:05<00:00,  1.92it/s]


[ 驗證 | 026/080 ] 損失 = 1.97814, 準確率 = 0.45483


100%|██████████| 49/49 [00:28<00:00,  1.69it/s]


[ 訓練 | 027/080 ] 損失 = 0.66310, 準確率 = 0.77041


100%|██████████| 11/11 [00:07<00:00,  1.51it/s]


[ 驗證 | 027/080 ] 損失 = 2.28179, 準確率 = 0.40369


100%|██████████| 49/49 [00:29<00:00,  1.68it/s]


[ 訓練 | 028/080 ] 損失 = 0.66581, 準確率 = 0.77615


100%|██████████| 11/11 [00:05<00:00,  1.89it/s]


[ 驗證 | 028/080 ] 損失 = 2.02337, 準確率 = 0.46080


100%|██████████| 49/49 [00:30<00:00,  1.63it/s]


[ 訓練 | 029/080 ] 損失 = 0.61349, 準確率 = 0.78699


100%|██████████| 11/11 [00:05<00:00,  1.91it/s]


[ 驗證 | 029/080 ] 損失 = 2.55586, 準確率 = 0.35227


100%|██████████| 49/49 [00:29<00:00,  1.67it/s]


[ 訓練 | 030/080 ] 損失 = 0.52283, 準確率 = 0.82972


100%|██████████| 11/11 [00:07<00:00,  1.47it/s]


[ 驗證 | 030/080 ] 損失 = 1.90347, 準確率 = 0.47415


100%|██████████| 49/49 [00:29<00:00,  1.67it/s]


[ 訓練 | 031/080 ] 損失 = 0.50950, 準確率 = 0.82653


100%|██████████| 11/11 [00:05<00:00,  1.87it/s]


[ 驗證 | 031/080 ] 損失 = 2.13779, 準確率 = 0.43551


100%|██████████| 49/49 [00:30<00:00,  1.63it/s]


[ 訓練 | 032/080 ] 損失 = 0.49201, 準確率 = 0.83450


100%|██████████| 11/11 [00:06<00:00,  1.79it/s]


[ 驗證 | 032/080 ] 損失 = 1.89956, 準確率 = 0.51080


100%|██████████| 49/49 [00:29<00:00,  1.69it/s]


[ 訓練 | 033/080 ] 損失 = 0.46238, 準確率 = 0.84311


100%|██████████| 11/11 [00:07<00:00,  1.52it/s]


[ 驗證 | 033/080 ] 損失 = 2.36516, 準確率 = 0.42074


100%|██████████| 49/49 [00:29<00:00,  1.67it/s]


[ 訓練 | 034/080 ] 損失 = 0.46055, 準確率 = 0.84088


100%|██████████| 11/11 [00:05<00:00,  1.89it/s]


[ 驗證 | 034/080 ] 損失 = 2.09620, 準確率 = 0.47074


100%|██████████| 49/49 [00:29<00:00,  1.67it/s]


[ 訓練 | 035/080 ] 損失 = 0.41942, 準確率 = 0.86129


100%|██████████| 11/11 [00:07<00:00,  1.49it/s]


[ 驗證 | 035/080 ] 損失 = 1.91323, 準確率 = 0.47983


100%|██████████| 49/49 [00:31<00:00,  1.58it/s]


[ 訓練 | 036/080 ] 損失 = 0.45251, 準確率 = 0.85108


100%|██████████| 11/11 [00:07<00:00,  1.54it/s]


[ 驗證 | 036/080 ] 損失 = 2.08035, 準確率 = 0.47301


100%|██████████| 49/49 [00:29<00:00,  1.64it/s]


[ 訓練 | 037/080 ] 損失 = 0.37574, 準確率 = 0.87213


100%|██████████| 11/11 [00:05<00:00,  1.87it/s]


[ 驗證 | 037/080 ] 損失 = 2.22943, 準確率 = 0.46392


100%|██████████| 49/49 [00:29<00:00,  1.66it/s]


[ 訓練 | 038/080 ] 損失 = 0.36107, 準確率 = 0.88265


100%|██████████| 11/11 [00:07<00:00,  1.49it/s]


[ 驗證 | 038/080 ] 損失 = 2.21710, 準確率 = 0.41193


100%|██████████| 49/49 [00:29<00:00,  1.65it/s]


[ 訓練 | 039/080 ] 損失 = 0.35977, 準確率 = 0.88520


100%|██████████| 11/11 [00:06<00:00,  1.66it/s]


[ 驗證 | 039/080 ] 損失 = 2.13889, 準確率 = 0.48125


100%|██████████| 49/49 [00:30<00:00,  1.63it/s]


[ 訓練 | 040/080 ] 損失 = 0.36012, 準確率 = 0.88393


100%|██████████| 11/11 [00:05<00:00,  1.88it/s]


[ 驗證 | 040/080 ] 損失 = 2.13863, 準確率 = 0.46165


100%|██████████| 49/49 [00:29<00:00,  1.66it/s]


[ 訓練 | 041/080 ] 損失 = 0.32195, 準確率 = 0.88839


100%|██████████| 11/11 [00:07<00:00,  1.46it/s]


[ 驗證 | 041/080 ] 損失 = 2.38443, 準確率 = 0.45312


100%|██████████| 49/49 [00:31<00:00,  1.57it/s]


[ 訓練 | 042/080 ] 損失 = 0.29329, 準確率 = 0.90051


100%|██████████| 11/11 [00:06<00:00,  1.75it/s]


[ 驗證 | 042/080 ] 損失 = 2.30126, 準確率 = 0.45483


100%|██████████| 49/49 [00:30<00:00,  1.62it/s]


[ 訓練 | 043/080 ] 損失 = 0.30049, 準確率 = 0.90147


100%|██████████| 11/11 [00:05<00:00,  1.93it/s]


[ 驗證 | 043/080 ] 損失 = 2.23684, 準確率 = 0.46364


100%|██████████| 49/49 [00:29<00:00,  1.69it/s]


[ 訓練 | 044/080 ] 損失 = 0.29980, 準確率 = 0.89892


100%|██████████| 11/11 [00:07<00:00,  1.49it/s]


[ 驗證 | 044/080 ] 損失 = 2.06016, 準確率 = 0.47500


100%|██████████| 49/49 [00:29<00:00,  1.67it/s]


[ 訓練 | 045/080 ] 損失 = 0.27536, 準確率 = 0.90370


100%|██████████| 11/11 [00:05<00:00,  1.91it/s]


[ 驗證 | 045/080 ] 損失 = 2.12504, 準確率 = 0.46392


100%|██████████| 49/49 [00:29<00:00,  1.65it/s]


[ 訓練 | 046/080 ] 損失 = 0.26490, 準確率 = 0.91550


100%|██████████| 11/11 [00:06<00:00,  1.72it/s]


[ 驗證 | 046/080 ] 損失 = 2.31183, 準確率 = 0.48040


100%|██████████| 49/49 [00:29<00:00,  1.68it/s]


[ 訓練 | 047/080 ] 損失 = 0.23050, 準確率 = 0.92188


100%|██████████| 11/11 [00:07<00:00,  1.50it/s]


[ 驗證 | 047/080 ] 損失 = 2.50511, 準確率 = 0.45000


100%|██████████| 49/49 [00:29<00:00,  1.66it/s]


[ 訓練 | 048/080 ] 損失 = 0.23457, 準確率 = 0.91996


100%|██████████| 11/11 [00:05<00:00,  1.92it/s]


[ 驗證 | 048/080 ] 損失 = 2.93763, 準確率 = 0.39801


100%|██████████| 49/49 [00:31<00:00,  1.58it/s]


[ 訓練 | 049/080 ] 損失 = 0.29131, 準確率 = 0.89987


100%|██████████| 11/11 [00:07<00:00,  1.50it/s]


[ 驗證 | 049/080 ] 損失 = 2.36885, 準確率 = 0.48295


100%|██████████| 49/49 [00:29<00:00,  1.68it/s]


[ 訓練 | 050/080 ] 損失 = 0.23018, 準確率 = 0.92379


100%|██████████| 11/11 [00:05<00:00,  1.94it/s]


[ 驗證 | 050/080 ] 損失 = 2.13255, 準確率 = 0.49773


100%|██████████| 49/49 [00:29<00:00,  1.64it/s]


[ 訓練 | 051/080 ] 損失 = 0.21474, 準確率 = 0.93017


100%|██████████| 11/11 [00:06<00:00,  1.76it/s]


[ 驗證 | 051/080 ] 損失 = 2.21101, 準確率 = 0.48068


100%|██████████| 49/49 [00:28<00:00,  1.70it/s]


[ 訓練 | 052/080 ] 損失 = 0.20913, 準確率 = 0.92698


100%|██████████| 11/11 [00:07<00:00,  1.51it/s]


[ 驗證 | 052/080 ] 損失 = 2.54175, 準確率 = 0.45199


100%|██████████| 49/49 [00:29<00:00,  1.69it/s]


[ 訓練 | 053/080 ] 損失 = 0.19069, 準確率 = 0.94165


100%|██████████| 11/11 [00:05<00:00,  1.94it/s]


[ 驗證 | 053/080 ] 損失 = 2.34682, 準確率 = 0.45994


100%|██████████| 49/49 [00:28<00:00,  1.70it/s]


[ 訓練 | 054/080 ] 損失 = 0.18281, 準確率 = 0.93750


100%|██████████| 11/11 [00:07<00:00,  1.48it/s]


[ 驗證 | 054/080 ] 損失 = 2.57653, 準確率 = 0.45739


100%|██████████| 49/49 [00:30<00:00,  1.59it/s]


[ 訓練 | 055/080 ] 損失 = 0.18510, 準確率 = 0.93208


100%|██████████| 11/11 [00:06<00:00,  1.69it/s]


[ 驗證 | 055/080 ] 損失 = 2.44167, 準確率 = 0.46619


100%|██████████| 49/49 [00:29<00:00,  1.65it/s]


[ 訓練 | 056/080 ] 損失 = 0.21007, 準確率 = 0.92761


100%|██████████| 11/11 [00:05<00:00,  1.91it/s]


[ 驗證 | 056/080 ] 損失 = 2.61631, 準確率 = 0.44631


100%|██████████| 49/49 [00:29<00:00,  1.68it/s]


[ 訓練 | 057/080 ] 損失 = 0.20874, 準確率 = 0.93048


100%|██████████| 11/11 [00:07<00:00,  1.49it/s]


[ 驗證 | 057/080 ] 損失 = 2.26556, 準確率 = 0.48778


100%|██████████| 49/49 [00:28<00:00,  1.70it/s]


[ 訓練 | 058/080 ] 損失 = 0.18181, 準確率 = 0.94483


100%|██████████| 11/11 [00:05<00:00,  1.95it/s]


[ 驗證 | 058/080 ] 損失 = 2.42287, 準確率 = 0.48778


100%|██████████| 49/49 [00:28<00:00,  1.70it/s]


[ 訓練 | 059/080 ] 損失 = 0.16887, 準確率 = 0.94834


100%|██████████| 11/11 [00:06<00:00,  1.57it/s]


[ 驗證 | 059/080 ] 損失 = 2.16649, 準確率 = 0.49148


100%|██████████| 49/49 [00:28<00:00,  1.69it/s]


[ 訓練 | 060/080 ] 損失 = 0.16252, 準確率 = 0.94547


100%|██████████| 11/11 [00:06<00:00,  1.69it/s]


[ 驗證 | 060/080 ] 損失 = 2.41945, 準確率 = 0.47415


100%|██████████| 49/49 [00:29<00:00,  1.67it/s]


[ 訓練 | 061/080 ] 損失 = 0.15183, 準確率 = 0.95153


100%|██████████| 11/11 [00:05<00:00,  1.91it/s]


[ 驗證 | 061/080 ] 損失 = 2.44403, 準確率 = 0.46989


100%|██████████| 49/49 [00:28<00:00,  1.70it/s]


[ 訓練 | 062/080 ] 損失 = 0.13471, 準確率 = 0.95695


100%|██████████| 11/11 [00:08<00:00,  1.23it/s]


[ 驗證 | 062/080 ] 損失 = 2.56495, 準確率 = 0.46136


100%|██████████| 49/49 [00:29<00:00,  1.69it/s]


[ 訓練 | 063/080 ] 損失 = 0.13661, 準確率 = 0.95568


100%|██████████| 11/11 [00:05<00:00,  1.87it/s]


[ 驗證 | 063/080 ] 損失 = 2.35478, 準確率 = 0.50114


100%|██████████| 49/49 [00:30<00:00,  1.63it/s]


[ 訓練 | 064/080 ] 損失 = 0.18292, 準確率 = 0.94133


100%|██████████| 11/11 [00:06<00:00,  1.79it/s]


[ 驗證 | 064/080 ] 損失 = 2.41711, 準確率 = 0.47415


100%|██████████| 49/49 [00:28<00:00,  1.72it/s]


[ 訓練 | 065/080 ] 損失 = 0.15631, 準確率 = 0.94834


100%|██████████| 11/11 [00:06<00:00,  1.58it/s]


[ 驗證 | 065/080 ] 損失 = 2.37848, 準確率 = 0.46165


100%|██████████| 49/49 [00:28<00:00,  1.70it/s]


[ 訓練 | 066/080 ] 損失 = 0.14883, 準確率 = 0.95057


100%|██████████| 11/11 [00:05<00:00,  1.97it/s]


[ 驗證 | 066/080 ] 損失 = 2.42228, 準確率 = 0.50142


100%|██████████| 49/49 [00:28<00:00,  1.73it/s]


[ 訓練 | 067/080 ] 損失 = 0.11888, 準確率 = 0.95727


100%|██████████| 11/11 [00:07<00:00,  1.55it/s]


[ 驗證 | 067/080 ] 損失 = 2.16523, 準確率 = 0.51278


100%|██████████| 49/49 [00:28<00:00,  1.72it/s]


[ 訓練 | 068/080 ] 損失 = 0.13477, 準確率 = 0.95504


100%|██████████| 11/11 [00:07<00:00,  1.53it/s]


[ 驗證 | 068/080 ] 損失 = 2.66152, 準確率 = 0.47472


100%|██████████| 49/49 [00:28<00:00,  1.70it/s]


[ 訓練 | 069/080 ] 損失 = 0.13351, 準確率 = 0.95568


100%|██████████| 11/11 [00:06<00:00,  1.71it/s]


[ 驗證 | 069/080 ] 損失 = 2.32957, 準確率 = 0.49290


100%|██████████| 49/49 [00:28<00:00,  1.71it/s]


[ 訓練 | 070/080 ] 損失 = 0.12654, 準確率 = 0.95823


100%|██████████| 11/11 [00:07<00:00,  1.53it/s]


[ 驗證 | 070/080 ] 損失 = 2.76668, 準確率 = 0.45653


100%|██████████| 49/49 [00:28<00:00,  1.72it/s]


[ 訓練 | 071/080 ] 損失 = 0.16026, 準確率 = 0.94866


100%|██████████| 11/11 [00:05<00:00,  1.90it/s]


[ 驗證 | 071/080 ] 損失 = 2.70786, 準確率 = 0.46818


100%|██████████| 49/49 [00:28<00:00,  1.71it/s]


[ 訓練 | 072/080 ] 損失 = 0.14109, 準確率 = 0.95472


100%|██████████| 11/11 [00:07<00:00,  1.53it/s]


[ 驗證 | 072/080 ] 損失 = 2.91210, 準確率 = 0.42699


100%|██████████| 49/49 [00:28<00:00,  1.71it/s]


[ 訓練 | 073/080 ] 損失 = 0.13675, 準確率 = 0.95568


100%|██████████| 11/11 [00:05<00:00,  1.97it/s]


[ 驗證 | 073/080 ] 損失 = 2.40220, 準確率 = 0.47415


100%|██████████| 49/49 [00:28<00:00,  1.72it/s]


[ 訓練 | 074/080 ] 損失 = 0.13702, 準確率 = 0.95504


100%|██████████| 11/11 [00:07<00:00,  1.55it/s]


[ 驗證 | 074/080 ] 損失 = 2.37679, 準確率 = 0.49631


100%|██████████| 49/49 [00:29<00:00,  1.64it/s]


[ 訓練 | 075/080 ] 損失 = 0.13935, 準確率 = 0.95344


100%|██████████| 11/11 [00:05<00:00,  1.93it/s]


[ 驗證 | 075/080 ] 損失 = 2.34393, 準確率 = 0.47955


100%|██████████| 49/49 [00:28<00:00,  1.73it/s]


[ 訓練 | 076/080 ] 損失 = 0.11323, 準確率 = 0.96779


100%|██████████| 11/11 [00:06<00:00,  1.65it/s]


[ 驗證 | 076/080 ] 損失 = 2.44018, 準確率 = 0.48949


100%|██████████| 49/49 [00:28<00:00,  1.73it/s]


[ 訓練 | 077/080 ] 損失 = 0.12442, 準確率 = 0.95695


100%|██████████| 11/11 [00:05<00:00,  1.87it/s]


[ 驗證 | 077/080 ] 損失 = 2.73219, 準確率 = 0.46847


100%|██████████| 49/49 [00:28<00:00,  1.71it/s]


[ 訓練 | 078/080 ] 損失 = 0.14864, 準確率 = 0.94834


100%|██████████| 11/11 [00:05<00:00,  1.88it/s]


[ 驗證 | 078/080 ] 損失 = 2.53313, 準確率 = 0.48750


100%|██████████| 49/49 [00:28<00:00,  1.72it/s]


[ 訓練 | 079/080 ] 損失 = 0.12080, 準確率 = 0.95855


100%|██████████| 11/11 [00:07<00:00,  1.55it/s]


[ 驗證 | 079/080 ] 損失 = 2.42373, 準確率 = 0.48722


100%|██████████| 49/49 [00:28<00:00,  1.72it/s]


[ 訓練 | 080/080 ] 損失 = 0.15401, 準確率 = 0.94834


100%|██████████| 11/11 [00:05<00:00,  1.97it/s]

[ 驗證 | 080/080 ] 損失 = 2.48382, 準確率 = 0.49006





In [7]:
# Make sure the model is in eval mode.
# Some modules like Dropout or BatchNorm affect if the model is in training mode.
model.eval()

# Initialize a list to store the predictions.
predictions = []

# Iterate the testing set by batches.
for batch in tqdm(test_loader):
    # A batch consists of image data and corresponding labels.
    # But here the variable "labels" is useless since we do not have the ground-truth.
    # If printing out the labels, you will find that it is always 0.
    # This is because the wrapper (DatasetFolder) returns images and labels for each batch,
    # so we have to create fake labels to make it work normally.
    imgs, labels = batch

    # We don't need gradient in testing, and we don't even have labels to compute loss.
    # Using torch.no_grad() accelerates the forward process.
    with torch.no_grad():
        logits = model(imgs.to(device))

    # Take the class with greatest logit as prediction and record it.
    predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

100%|██████████| 53/53 [00:29<00:00,  1.80it/s]


In [8]:
# Save predictions into the file.
with open("predict.csv", "w") as f:

    # The first row must be "Id, Category"
    f.write("Id,Category\n")

    # For the rest of the rows, each image id corresponds to a predicted class.
    for i, pred in  enumerate(predictions):
         f.write(f"{i},{pred}\n")