In [1]:
# Download the dataset
!gdown --id '1awF7pZ9Dz7X1jn1_QAiKN-_v56veCEKy' --output food-11.zip
!unzip -q food-11.zip

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision.datasets import DatasetFolder

# tqdm 顯示進度條的工具
from tqdm.auto import tqdm

Downloading...
From: https://drive.google.com/uc?id=1awF7pZ9Dz7X1jn1_QAiKN-_v56veCEKy
To: /content/food-11.zip
100% 963M/963M [00:18<00:00, 52.1MB/s]


In [2]:


import torchvision.transforms as transforms

# 在訓練中進行數據增強
train_tfm = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomCrop((224, 224)),
    transforms.ToTensor(),
])

# 測試和驗證中只進行圖像轉換
test_tfm = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
])



In [9]:
# Batch size for training, validation, and testing.
# A greater batch size usually gives a more stable gradient.
# But the GPU memory is limited, so please adjust it carefully.
batch_size = 64

# Construct datasets.
# The argument "loader" tells how torchvision reads the data.
train_set = DatasetFolder("food-11/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
valid_set = DatasetFolder("food-11/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
unlabeled_set = DatasetFolder("food-11/training/unlabeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
test_set = DatasetFolder("food-11/testing", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)

# Construct data loaders.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [10]:
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        # 使用 ResNet-50 模型作為基礎架構
        self.model = models.resnet50(pretrained=False)

        # 替換最後一層全連接層
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, 11)

    def forward(self, x):
        x = self.model(x)
        return x

# 創建模型實例
model = Classifier()

# 定義損失函數
criterion = nn.CrossEntropyLoss()

# 定義優化器，例如使用 Adam 優化器
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)


In [11]:
def get_pseudo_labels(dataset, model, threshold=0.65):
    # This functions generates pseudo-labels of a dataset using given model.
    # It returns an instance of DatasetFolder containing images whose prediction confidences exceed a given threshold.
    # You are NOT allowed to use any models trained on external data for pseudo-labeling.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Construct a data loader.
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Make sure the model is in eval mode.
    model.eval()
    # Define softmax function.
    softmax = nn.Softmax(dim=-1)

    # Iterate over the dataset by batches.
    for batch in tqdm(data_loader):
        img, _ = batch

        # Forward the data
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits = model(img.to(device))

        # Obtain the probability distributions by applying softmax on logits.
        probs = softmax(logits)

        # ---------- TODO ----------
        # Filter the data and construct a new dataset.

    # # Turn off the eval mode.
    model.train()
    return dataset

In [13]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# 假設已經定義了以下函式和變數：
# get_pseudo_labels(unlabeled_set, model): 給定未標記數據和訓練好的模型，返回使用模型生成的伪標籤的數據集。
# train_set: 有標籤的訓練數據集。
# unlabeled_set: 未標記的數據集。
# valid_loader: 驗證數據的 DataLoader。
# batch_size: 批次大小。

# "cuda" 只有當有 GPU 可用時使用。
device = "cuda" if torch.cuda.is_available() else "cpu"

# 初始化模型並將其放置在指定的設備上。
model = Classifier().to(device)
model.device = device

# 對於分類任務，我們使用交叉熵作為性能衡量標準。
criterion = nn.CrossEntropyLoss()

# 初始化優化器，您可以自行微調一些超參數，如學習率。
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-5)

# 訓練的 epoch 數量。
n_epochs = 80

# 是否進行半監督學習。
do_semi = False
from torch.optim.lr_scheduler import StepLR

# 在每個 epoch 後降低學習率
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

for epoch in range(n_epochs):
    # ---------- TODO ----------
    # 在每個 epoch 中，為半監督學習重新標記未標記數據集。
    # 然後，您可以將有標籤數據集和伪標籤數據集結合起來進行訓練。
    if do_semi:
        # 使用訓練好的模型為未標記數據獲取伪標籤。
        pseudo_set = get_pseudo_labels(unlabeled_set, model)

        # 構建新的數據集和數據加載器用於訓練。
        # 這僅在半監督學習中使用。
        concat_dataset = ConcatDataset([train_set, pseudo_set])
        train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

    # ---------- 訓練 ----------
    # 在訓練之前確保模型處於訓練模式。
    model.train()

    # 這些變量用於記錄訓練過程中的信息。
    train_loss = []
    train_accs = []

    # 逐批次迭代訓練數據集。
    for batch in tqdm(train_loader):

        # 一批次包含圖像數據和相應的標籤。
        imgs, labels = batch

        # 將數據傳遞到模型中（確保數據和模型在同一設備上）。
        logits = model(imgs.to(device))

        # 計算交叉熵損失。
        # 在計算交叉熵之前，我們不需要在 logits 上應用 softmax，因為它會自動完成。
        loss = criterion(logits, labels.to(device))

        # 清除先前步驟中參數中存儲的梯度。
        optimizer.zero_grad()

        # 計算參數的梯度。
        loss.backward()

        # 限制梯度的範數以實現穩定訓練。
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)

        # 使用計算出的梯度更新參數。
        optimizer.step()

        # 計算當前批次的準確率。
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # 記錄損失和準確率。
        train_loss.append(loss.item())
        train_accs.append(acc)

    # 訓練集的平均損失和準確率是記錄值的平均值。
    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)

    # 輸出訓練信息。
    print(f"[ 訓練 | {epoch + 1:03d}/{n_epochs:03d} ] 損失 = {train_loss:.5f}, 準確率 = {train_acc:.5f}")

    # ---------- 驗證 ----------
    # 確保模型處於評估模式，以便一些模塊如 dropout 被禁用並正常工作。
    model.eval()

    # 這些變量用於記錄驗證過程中的信息。
    valid_loss = []
    valid_accs = []

    # 逐批次迭代驗證數據集。
    for batch in tqdm(valid_loader):

        # 一批次包含圖像數據和相應的標籤。
        imgs, labels = batch

        # 在驗證過程中我們不需要梯度。
        # 使用 torch.no_grad() 可以加速前向過程。
        with torch.no_grad():
            logits = model(imgs.to(device))

        # 我們仍然可以計算損失（但沒有梯度）。
        loss = criterion(logits, labels.to(device))

        # 計算當前批次的準確率。
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # 記錄損失和準確率。
        valid_loss.append(loss.item())
        valid_accs.append(acc)

    # 整個驗證集的平均損失和準確率是記錄值的平均值。
    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)

    # 輸出驗證信息。
    print(f"[ 驗證 | {epoch + 1:03d}/{n_epochs:03d} ] 損失 = {valid_loss:.5f}, 準確率 = {valid_acc:.5f}")


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


[ 訓練 | 001/080 ] 損失 = 2.48113, 準確率 = 0.13202


100%|██████████| 11/11 [00:07<00:00,  1.43it/s]


[ 驗證 | 001/080 ] 損失 = 2.68275, 準確率 = 0.09233


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


[ 訓練 | 002/080 ] 損失 = 2.32436, 準確率 = 0.17219


100%|██████████| 11/11 [00:08<00:00,  1.27it/s]


[ 驗證 | 002/080 ] 損失 = 2.34622, 準確率 = 0.15227


100%|██████████| 49/49 [00:48<00:00,  1.00it/s]


[ 訓練 | 003/080 ] 損失 = 2.33626, 準確率 = 0.18878


100%|██████████| 11/11 [00:07<00:00,  1.52it/s]


[ 驗證 | 003/080 ] 損失 = 2.35350, 準確率 = 0.17443


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


[ 訓練 | 004/080 ] 損失 = 2.26576, 準確率 = 0.20631


100%|██████████| 11/11 [00:06<00:00,  1.61it/s]


[ 驗證 | 004/080 ] 損失 = 2.22867, 準確率 = 0.20426


100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


[ 訓練 | 005/080 ] 損失 = 2.12438, 準確率 = 0.27073


100%|██████████| 11/11 [00:08<00:00,  1.34it/s]


[ 驗證 | 005/080 ] 損失 = 2.21062, 準確率 = 0.26136


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 006/080 ] 損失 = 2.06283, 準確率 = 0.28125


100%|██████████| 11/11 [00:08<00:00,  1.28it/s]


[ 驗證 | 006/080 ] 損失 = 2.06839, 準確率 = 0.26847


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 007/080 ] 損失 = 2.01228, 準確率 = 0.29751


100%|██████████| 11/11 [00:06<00:00,  1.59it/s]


[ 驗證 | 007/080 ] 損失 = 1.93021, 準確率 = 0.34063


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


[ 訓練 | 008/080 ] 損失 = 1.89882, 準確率 = 0.33514


100%|██████████| 11/11 [00:07<00:00,  1.45it/s]


[ 驗證 | 008/080 ] 損失 = 2.28082, 準確率 = 0.25540


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 009/080 ] 損失 = 1.85563, 準確率 = 0.35172


100%|██████████| 11/11 [00:08<00:00,  1.30it/s]


[ 驗證 | 009/080 ] 損失 = 1.74104, 準確率 = 0.41477


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


[ 訓練 | 010/080 ] 損失 = 1.79630, 準確率 = 0.37755


100%|██████████| 11/11 [00:06<00:00,  1.59it/s]


[ 驗證 | 010/080 ] 損失 = 1.92940, 準確率 = 0.33693


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 011/080 ] 損失 = 1.77874, 準確率 = 0.38489


100%|██████████| 11/11 [00:07<00:00,  1.55it/s]


[ 驗證 | 011/080 ] 損失 = 2.43972, 準確率 = 0.23665


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


[ 訓練 | 012/080 ] 損失 = 1.73076, 準確率 = 0.40306


100%|██████████| 11/11 [00:08<00:00,  1.27it/s]


[ 驗證 | 012/080 ] 損失 = 2.19420, 準確率 = 0.29318


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 013/080 ] 損失 = 1.69476, 準確率 = 0.41135


100%|██████████| 11/11 [00:07<00:00,  1.40it/s]


[ 驗證 | 013/080 ] 損失 = 1.77113, 準確率 = 0.41506


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 014/080 ] 損失 = 1.67637, 準確率 = 0.41486


100%|██████████| 11/11 [00:07<00:00,  1.54it/s]


[ 驗證 | 014/080 ] 損失 = 1.78785, 準確率 = 0.38636


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 015/080 ] 損失 = 1.58474, 準確率 = 0.44962


100%|██████████| 11/11 [00:07<00:00,  1.42it/s]


[ 驗證 | 015/080 ] 損失 = 1.72072, 準確率 = 0.41136


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 016/080 ] 損失 = 1.58105, 準確率 = 0.46269


100%|██████████| 11/11 [00:08<00:00,  1.37it/s]


[ 驗證 | 016/080 ] 損失 = 1.93012, 準確率 = 0.33182


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 017/080 ] 損失 = 1.53771, 準確率 = 0.47321


100%|██████████| 11/11 [00:07<00:00,  1.48it/s]


[ 驗證 | 017/080 ] 損失 = 1.87175, 準確率 = 0.37273


100%|██████████| 49/49 [00:51<00:00,  1.06s/it]


[ 訓練 | 018/080 ] 損失 = 1.46049, 準確率 = 0.49936


100%|██████████| 11/11 [00:07<00:00,  1.39it/s]


[ 驗證 | 018/080 ] 損失 = 2.04208, 準確率 = 0.35142


100%|██████████| 49/49 [00:48<00:00,  1.01it/s]


[ 訓練 | 019/080 ] 損失 = 1.44578, 準確率 = 0.50925


100%|██████████| 11/11 [00:08<00:00,  1.33it/s]


[ 驗證 | 019/080 ] 損失 = 1.87167, 準確率 = 0.38381


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 020/080 ] 損失 = 1.41571, 準確率 = 0.50638


100%|██████████| 11/11 [00:07<00:00,  1.52it/s]


[ 驗證 | 020/080 ] 損失 = 2.04903, 準確率 = 0.36335


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 021/080 ] 損失 = 1.41977, 準確率 = 0.50893


100%|██████████| 11/11 [00:06<00:00,  1.58it/s]


[ 驗證 | 021/080 ] 損失 = 1.81913, 準確率 = 0.43750


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


[ 訓練 | 022/080 ] 損失 = 1.36286, 準確率 = 0.52583


100%|██████████| 11/11 [00:08<00:00,  1.27it/s]


[ 驗證 | 022/080 ] 損失 = 1.50573, 準確率 = 0.48778


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


[ 訓練 | 023/080 ] 損失 = 1.33472, 準確率 = 0.54018


100%|██████████| 11/11 [00:07<00:00,  1.50it/s]


[ 驗證 | 023/080 ] 損失 = 1.73069, 準確率 = 0.43608


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 024/080 ] 損失 = 1.26643, 準確率 = 0.56505


100%|██████████| 11/11 [00:07<00:00,  1.52it/s]


[ 驗證 | 024/080 ] 損失 = 1.66230, 準確率 = 0.46932


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


[ 訓練 | 025/080 ] 損失 = 1.25575, 準確率 = 0.56696


100%|██████████| 11/11 [00:08<00:00,  1.24it/s]


[ 驗證 | 025/080 ] 損失 = 1.88031, 準確率 = 0.42188


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 026/080 ] 損失 = 1.23541, 準確率 = 0.57621


100%|██████████| 11/11 [00:07<00:00,  1.41it/s]


[ 驗證 | 026/080 ] 損失 = 1.92084, 準確率 = 0.38580


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


[ 訓練 | 027/080 ] 損失 = 1.17124, 準確率 = 0.60204


100%|██████████| 11/11 [00:07<00:00,  1.54it/s]


[ 驗證 | 027/080 ] 損失 = 1.68644, 準確率 = 0.45881


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


[ 訓練 | 028/080 ] 損失 = 1.14591, 準確率 = 0.61798


100%|██████████| 11/11 [00:08<00:00,  1.31it/s]


[ 驗證 | 028/080 ] 損失 = 2.00312, 準確率 = 0.41250


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 029/080 ] 損失 = 1.15878, 準確率 = 0.60491


100%|██████████| 11/11 [00:07<00:00,  1.43it/s]


[ 驗證 | 029/080 ] 損失 = 1.59529, 準確率 = 0.48835


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 030/080 ] 損失 = 1.06706, 準確率 = 0.62915


100%|██████████| 11/11 [00:07<00:00,  1.51it/s]


[ 驗證 | 030/080 ] 損失 = 1.63490, 準確率 = 0.52131


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


[ 訓練 | 031/080 ] 損失 = 1.05664, 準確率 = 0.64700


100%|██████████| 11/11 [00:08<00:00,  1.30it/s]


[ 驗證 | 031/080 ] 損失 = 1.76252, 準確率 = 0.47528


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 032/080 ] 損失 = 1.01768, 準確率 = 0.65083


100%|██████████| 11/11 [00:07<00:00,  1.55it/s]


[ 驗證 | 032/080 ] 損失 = 1.65817, 準確率 = 0.48977


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 033/080 ] 損失 = 0.99780, 準確率 = 0.65689


100%|██████████| 11/11 [00:06<00:00,  1.59it/s]


[ 驗證 | 033/080 ] 損失 = 1.59882, 準確率 = 0.52159


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


[ 訓練 | 034/080 ] 損失 = 1.02500, 準確率 = 0.65210


100%|██████████| 11/11 [00:08<00:00,  1.33it/s]


[ 驗證 | 034/080 ] 損失 = 1.78044, 準確率 = 0.45227


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 035/080 ] 損失 = 0.97358, 準確率 = 0.66263


100%|██████████| 11/11 [00:07<00:00,  1.40it/s]


[ 驗證 | 035/080 ] 損失 = 2.17561, 準確率 = 0.39517


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 036/080 ] 損失 = 0.91685, 準確率 = 0.69005


100%|██████████| 11/11 [00:09<00:00,  1.20it/s]


[ 驗證 | 036/080 ] 損失 = 1.70140, 準確率 = 0.49148


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 037/080 ] 損失 = 0.93306, 準確率 = 0.67953


100%|██████████| 11/11 [00:08<00:00,  1.26it/s]


[ 驗證 | 037/080 ] 損失 = 2.02262, 準確率 = 0.45994


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 038/080 ] 損失 = 0.90888, 準確率 = 0.68814


100%|██████████| 11/11 [00:06<00:00,  1.61it/s]


[ 驗證 | 038/080 ] 損失 = 1.99144, 準確率 = 0.42472


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


[ 訓練 | 039/080 ] 損失 = 0.85110, 準確率 = 0.71556


100%|██████████| 11/11 [00:07<00:00,  1.54it/s]


[ 驗證 | 039/080 ] 損失 = 1.64510, 準確率 = 0.50455


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 040/080 ] 損失 = 0.84684, 準確率 = 0.71046


100%|██████████| 11/11 [00:08<00:00,  1.26it/s]


[ 驗證 | 040/080 ] 損失 = 1.77301, 準確率 = 0.48778


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 041/080 ] 損失 = 0.78634, 準確率 = 0.72513


100%|██████████| 11/11 [00:07<00:00,  1.47it/s]


[ 驗證 | 041/080 ] 損失 = 1.34578, 準確率 = 0.55256


100%|██████████| 49/49 [00:51<00:00,  1.06s/it]


[ 訓練 | 042/080 ] 損失 = 0.80772, 準確率 = 0.72577


100%|██████████| 11/11 [00:07<00:00,  1.53it/s]


[ 驗證 | 042/080 ] 損失 = 1.46016, 準確率 = 0.55852


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 043/080 ] 損失 = 0.74234, 準確率 = 0.75032


100%|██████████| 11/11 [00:08<00:00,  1.24it/s]


[ 驗證 | 043/080 ] 損失 = 1.54101, 準確率 = 0.56648


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 044/080 ] 損失 = 0.76858, 準確率 = 0.73724


100%|██████████| 11/11 [00:07<00:00,  1.53it/s]


[ 驗證 | 044/080 ] 損失 = 1.50731, 準確率 = 0.56307


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


[ 訓練 | 045/080 ] 損失 = 0.73734, 準確率 = 0.75128


100%|██████████| 11/11 [00:06<00:00,  1.58it/s]


[ 驗證 | 045/080 ] 損失 = 1.72263, 準確率 = 0.52472


100%|██████████| 49/49 [00:52<00:00,  1.06s/it]


[ 訓練 | 046/080 ] 損失 = 0.68824, 準確率 = 0.76435


100%|██████████| 11/11 [00:08<00:00,  1.36it/s]


[ 驗證 | 046/080 ] 損失 = 1.46664, 準確率 = 0.56165


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 047/080 ] 損失 = 0.65022, 準確率 = 0.78412


100%|██████████| 11/11 [00:06<00:00,  1.57it/s]


[ 驗證 | 047/080 ] 損失 = 1.69470, 準確率 = 0.51506


100%|██████████| 49/49 [00:53<00:00,  1.08s/it]


[ 訓練 | 048/080 ] 損失 = 0.64694, 準確率 = 0.78316


100%|██████████| 11/11 [00:08<00:00,  1.31it/s]


[ 驗證 | 048/080 ] 損失 = 2.07616, 準確率 = 0.48182


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 049/080 ] 損失 = 0.62691, 準確率 = 0.79050


100%|██████████| 11/11 [00:07<00:00,  1.48it/s]


[ 驗證 | 049/080 ] 損失 = 1.99441, 準確率 = 0.49432


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 050/080 ] 損失 = 0.62304, 準確率 = 0.78731


100%|██████████| 11/11 [00:06<00:00,  1.57it/s]


[ 驗證 | 050/080 ] 損失 = 1.73552, 準確率 = 0.54460


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 051/080 ] 損失 = 0.57671, 準確率 = 0.79719


100%|██████████| 11/11 [00:07<00:00,  1.38it/s]


[ 驗證 | 051/080 ] 損失 = 2.17519, 準確率 = 0.51847


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 052/080 ] 損失 = 0.63250, 準確率 = 0.78540


100%|██████████| 11/11 [00:08<00:00,  1.31it/s]


[ 驗證 | 052/080 ] 損失 = 1.64608, 準確率 = 0.53636


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 053/080 ] 損失 = 0.58190, 準確率 = 0.79560


100%|██████████| 11/11 [00:06<00:00,  1.59it/s]


[ 驗證 | 053/080 ] 損失 = 1.77824, 準確率 = 0.54233


100%|██████████| 49/49 [00:52<00:00,  1.06s/it]


[ 訓練 | 054/080 ] 損失 = 0.54571, 準確率 = 0.81186


100%|██████████| 11/11 [00:06<00:00,  1.58it/s]


[ 驗證 | 054/080 ] 損失 = 2.72257, 準確率 = 0.41193


100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


[ 訓練 | 055/080 ] 損失 = 0.54272, 準確率 = 0.81027


100%|██████████| 11/11 [00:07<00:00,  1.40it/s]


[ 驗證 | 055/080 ] 損失 = 1.57162, 準確率 = 0.55881


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


[ 訓練 | 056/080 ] 損失 = 0.51713, 準確率 = 0.82812


100%|██████████| 11/11 [00:07<00:00,  1.51it/s]


[ 驗證 | 056/080 ] 損失 = 2.26334, 準確率 = 0.46506


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 057/080 ] 損失 = 0.48858, 準確率 = 0.83195


100%|██████████| 11/11 [00:08<00:00,  1.33it/s]


[ 驗證 | 057/080 ] 損失 = 2.17572, 準確率 = 0.50142


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 058/080 ] 損失 = 0.49168, 準確率 = 0.83865


100%|██████████| 11/11 [00:07<00:00,  1.54it/s]


[ 驗證 | 058/080 ] 損失 = 1.64264, 準確率 = 0.58324


100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


[ 訓練 | 059/080 ] 損失 = 0.46932, 準確率 = 0.83737


100%|██████████| 11/11 [00:08<00:00,  1.37it/s]


[ 驗證 | 059/080 ] 損失 = 1.73420, 準確率 = 0.58551


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 060/080 ] 損失 = 0.46245, 準確率 = 0.84566


100%|██████████| 11/11 [00:08<00:00,  1.25it/s]


[ 驗證 | 060/080 ] 損失 = 1.76595, 準確率 = 0.56733


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 061/080 ] 損失 = 0.42635, 準確率 = 0.86193


100%|██████████| 11/11 [00:07<00:00,  1.52it/s]


[ 驗證 | 061/080 ] 損失 = 2.18259, 準確率 = 0.52358


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


[ 訓練 | 062/080 ] 損失 = 0.40142, 準確率 = 0.86448


100%|██████████| 11/11 [00:08<00:00,  1.29it/s]


[ 驗證 | 062/080 ] 損失 = 1.81268, 準確率 = 0.54347


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


[ 訓練 | 063/080 ] 損失 = 0.36151, 準確率 = 0.87117


100%|██████████| 11/11 [00:08<00:00,  1.24it/s]


[ 驗證 | 063/080 ] 損失 = 2.08886, 準確率 = 0.52955


100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


[ 訓練 | 064/080 ] 損失 = 0.40051, 準確率 = 0.86798


100%|██████████| 11/11 [00:07<00:00,  1.51it/s]


[ 驗證 | 064/080 ] 損失 = 1.74831, 準確率 = 0.59062


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


[ 訓練 | 065/080 ] 損失 = 0.40828, 準確率 = 0.86129


100%|██████████| 11/11 [00:08<00:00,  1.30it/s]


[ 驗證 | 065/080 ] 損失 = 1.68214, 準確率 = 0.55966


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 066/080 ] 損失 = 0.39866, 準確率 = 0.86320


100%|██████████| 11/11 [00:07<00:00,  1.44it/s]


[ 驗證 | 066/080 ] 損失 = 1.96380, 準確率 = 0.55881


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 067/080 ] 損失 = 0.38710, 準確率 = 0.87149


100%|██████████| 11/11 [00:06<00:00,  1.60it/s]


[ 驗證 | 067/080 ] 損失 = 1.70489, 準確率 = 0.58864


100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


[ 訓練 | 068/080 ] 損失 = 0.30688, 準確率 = 0.88871


100%|██████████| 11/11 [00:08<00:00,  1.25it/s]


[ 驗證 | 068/080 ] 損失 = 1.87826, 準確率 = 0.58466


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 069/080 ] 損失 = 0.39183, 準確率 = 0.87372


100%|██████████| 11/11 [00:07<00:00,  1.43it/s]


[ 驗證 | 069/080 ] 損失 = 2.52740, 準確率 = 0.49148


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 070/080 ] 損失 = 0.31267, 準確率 = 0.89381


100%|██████████| 11/11 [00:06<00:00,  1.59it/s]


[ 驗證 | 070/080 ] 損失 = 1.67802, 準確率 = 0.62074


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


[ 訓練 | 071/080 ] 損失 = 0.30953, 準確率 = 0.89318


100%|██████████| 11/11 [00:08<00:00,  1.34it/s]


[ 驗證 | 071/080 ] 損失 = 1.83176, 準確率 = 0.58324


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


[ 訓練 | 072/080 ] 損失 = 0.32656, 準確率 = 0.88903


100%|██████████| 11/11 [00:08<00:00,  1.28it/s]


[ 驗證 | 072/080 ] 損失 = 2.81905, 準確率 = 0.44659


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


[ 訓練 | 073/080 ] 損失 = 0.31415, 準確率 = 0.88776


100%|██████████| 11/11 [00:06<00:00,  1.58it/s]


[ 驗證 | 073/080 ] 損失 = 1.62613, 準確率 = 0.62727


100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


[ 訓練 | 074/080 ] 損失 = 0.33880, 準確率 = 0.88648


100%|██████████| 11/11 [00:08<00:00,  1.33it/s]


[ 驗證 | 074/080 ] 損失 = 2.13617, 準確率 = 0.53267


100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


[ 訓練 | 075/080 ] 損失 = 0.31745, 準確率 = 0.89126


100%|██████████| 11/11 [00:07<00:00,  1.42it/s]


[ 驗證 | 075/080 ] 損失 = 1.92879, 準確率 = 0.55739


100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


[ 訓練 | 076/080 ] 損失 = 0.26815, 準確率 = 0.90689


100%|██████████| 11/11 [00:07<00:00,  1.54it/s]


[ 驗證 | 076/080 ] 損失 = 2.42726, 準確率 = 0.47557


100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


[ 訓練 | 077/080 ] 損失 = 0.30748, 準確率 = 0.89445


100%|██████████| 11/11 [00:08<00:00,  1.31it/s]


[ 驗證 | 077/080 ] 損失 = 3.03413, 準確率 = 0.42699


100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


[ 訓練 | 078/080 ] 損失 = 0.26149, 準確率 = 0.91135


100%|██████████| 11/11 [00:07<00:00,  1.55it/s]


[ 驗證 | 078/080 ] 損失 = 1.53369, 準確率 = 0.63750


100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


[ 訓練 | 079/080 ] 損失 = 0.26073, 準確率 = 0.91295


100%|██████████| 11/11 [00:08<00:00,  1.29it/s]


[ 驗證 | 079/080 ] 損失 = 1.78843, 準確率 = 0.61108


100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


[ 訓練 | 080/080 ] 損失 = 0.26892, 準確率 = 0.91167


100%|██████████| 11/11 [00:08<00:00,  1.37it/s]

[ 驗證 | 080/080 ] 損失 = 1.73879, 準確率 = 0.60824





In [7]:
# Make sure the model is in eval mode.
# Some modules like Dropout or BatchNorm affect if the model is in training mode.
model.eval()

# Initialize a list to store the predictions.
predictions = []

# Iterate the testing set by batches.
for batch in tqdm(test_loader):
    # A batch consists of image data and corresponding labels.
    # But here the variable "labels" is useless since we do not have the ground-truth.
    # If printing out the labels, you will find that it is always 0.
    # This is because the wrapper (DatasetFolder) returns images and labels for each batch,
    # so we have to create fake labels to make it work normally.
    imgs, labels = batch

    # We don't need gradient in testing, and we don't even have labels to compute loss.
    # Using torch.no_grad() accelerates the forward process.
    with torch.no_grad():
        logits = model(imgs.to(device))

    # Take the class with greatest logit as prediction and record it.
    predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

100%|██████████| 27/27 [00:39<00:00,  1.45s/it]


In [8]:
# Save predictions into the file.
with open("predict.csv", "w") as f:

    # The first row must be "Id, Category"
    f.write("Id,Category\n")

    # For the rest of the rows, each image id corresponds to a predicted class.
    for i, pred in  enumerate(predictions):
         f.write(f"{i},{pred}\n")