# Visual Transformer with Linformer

Training Visual Transformer on *Dogs vs Cats Data*

* Dogs vs. Cats Redux: Kernels Edition - https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition
* Base Code - https://www.kaggle.com/reukki/pytorch-cnn-tutorial-with-cats-and-dogs/
* Efficient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention

In [1]:
!pip -q install vit_pytorch linformer

## Import Libraries

In [2]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from vit_pytorch.efficient import ViT


In [3]:
print(f"Torch: {torch.__version__}")

Torch: 2.7.0+cu118


In [4]:
import torch
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())


Torch: 2.7.0+cu118
CUDA available: True


In [5]:
# Training settings
batch_size = 64
epochs = 100
lr = 3e-5
gamma = 0.7
seed = 42

In [6]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

In [7]:
device = 'cuda'

## Load Data

In [17]:
# os.makedirs('data', exist_ok=True)
root_path = r"E:\Project\RemoteSensing\Datasets" 

In [19]:
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)


test_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

# 使用 ImageFolder 加载数据
train_data = datasets.ImageFolder(root=os.path.join(root_path, 'train'), transform=train_transforms)
valid_data   = datasets.ImageFolder(root=os.path.join(root_path, 'val'), transform=val_transforms)
test_data  = datasets.ImageFolder(root=os.path.join(root_path, 'test'), transform=val_transforms)

# DataLoader
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
valid_loader   = DataLoader(valid_data, batch_size=64, shuffle=False)
test_loader  = DataLoader(test_data, batch_size=64, shuffle=False)

In [20]:
print(f"Train: {len(train_data)} images, {len(train_loader)} batches")
print(f"Val:   {len(valid_data)} images, {len(valid_loader)} batches")
print(f"Test:  {len(test_data)} images, {len(test_loader)} batches")
print(f"Classes: {train_data.classes}")

Train: 1259 images, 20 batches
Val:   155 images, 3 batches
Test:  162 images, 3 batches
Classes: ['Beach', 'Bridge', 'Pond', 'Port', 'River']


In [21]:
print(len(train_data), len(train_loader))

1259 20


In [22]:
print(len(valid_data), len(valid_loader))

155 3


## Efficient Attention

### Linformer

In [23]:
efficient_transformer = Linformer(
    dim=128,
    seq_len=49+1,  # 7x7 patches + 1 cls-token
    depth=12,
    heads=8,
    k=64
)

### Visual Transformer

In [24]:
model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=5,
    transformer=efficient_transformer,
    channels=3,
).to(device)

### Training

In [25]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [26]:
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
best_val_acc = 0

In [27]:
def load_checkpoint(checkpoint_path, model, optimizer, scheduler):
    """加载检查点函数"""
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_acc = checkpoint['best_val_acc']
        print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
        return start_epoch, best_val_acc
    else:
        return 0, 0

In [28]:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    # 在这里添加检查点保存逻辑
    # 1. 保存最佳模型
    if epoch_val_accuracy > best_val_acc:
        best_val_acc = epoch_val_accuracy
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'loss': epoch_val_loss,
        }
        torch.save(checkpoint, os.path.join(checkpoint_dir, 'best_model.pth'))
        print(f"Saved best model with validation accuracy: {best_val_acc:.4f}")
    
    # 2. 定期保存检查点（每10个epoch）
    if (epoch + 1) % 10 == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'loss': epoch_val_loss,
        }
        torch.save(checkpoint, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'))
        print(f"Saved regular checkpoint at epoch {epoch+1}")

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )


  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.3364
Epoch : 1 - loss : 1.6096 - acc: 0.1902 - val_loss : 1.5582 - val_acc: 0.3364



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.4720
Epoch : 2 - loss : 1.5558 - acc: 0.3026 - val_loss : 1.4855 - val_acc: 0.4720



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 3 - loss : 1.5026 - acc: 0.3397 - val_loss : 1.3923 - val_acc: 0.4616



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 4 - loss : 1.4474 - acc: 0.3612 - val_loss : 1.4478 - val_acc: 0.2903



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 5 - loss : 1.4206 - acc: 0.3725 - val_loss : 1.3726 - val_acc: 0.4176



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 6 - loss : 1.4096 - acc: 0.3818 - val_loss : 1.3416 - val_acc: 0.4475



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.4728
Epoch : 7 - loss : 1.3998 - acc: 0.3838 - val_loss : 1.3172 - val_acc: 0.4728



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 8 - loss : 1.3883 - acc: 0.3975 - val_loss : 1.3187 - val_acc: 0.4689



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 9 - loss : 1.3857 - acc: 0.3916 - val_loss : 1.3828 - val_acc: 0.3605



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.4873
Saved regular checkpoint at epoch 10
Epoch : 10 - loss : 1.3712 - acc: 0.3787 - val_loss : 1.2914 - val_acc: 0.4873



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.4996
Epoch : 11 - loss : 1.3581 - acc: 0.4108 - val_loss : 1.2610 - val_acc: 0.4996



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.5575
Epoch : 12 - loss : 1.3320 - acc: 0.4393 - val_loss : 1.2313 - val_acc: 0.5575



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6173
Epoch : 13 - loss : 1.3156 - acc: 0.4428 - val_loss : 1.1970 - val_acc: 0.6173



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 14 - loss : 1.2843 - acc: 0.4635 - val_loss : 1.1671 - val_acc: 0.5926



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 15 - loss : 1.2941 - acc: 0.4643 - val_loss : 1.2089 - val_acc: 0.5133



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 16 - loss : 1.2672 - acc: 0.4835 - val_loss : 1.1733 - val_acc: 0.5478



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 17 - loss : 1.2431 - acc: 0.4896 - val_loss : 1.1849 - val_acc: 0.5089



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 18 - loss : 1.2271 - acc: 0.4889 - val_loss : 1.1523 - val_acc: 0.4919



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6285
Epoch : 19 - loss : 1.1515 - acc: 0.5401 - val_loss : 0.9748 - val_acc: 0.6285



  0%|          | 0/20 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 20
Epoch : 20 - loss : 1.1655 - acc: 0.5295 - val_loss : 1.0431 - val_acc: 0.5610



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 21 - loss : 1.1212 - acc: 0.5493 - val_loss : 0.9723 - val_acc: 0.6194



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6512
Epoch : 22 - loss : 1.0722 - acc: 0.5599 - val_loss : 0.8937 - val_acc: 0.6512



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 23 - loss : 1.1018 - acc: 0.5357 - val_loss : 0.9832 - val_acc: 0.5772



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 24 - loss : 1.0898 - acc: 0.5482 - val_loss : 0.9382 - val_acc: 0.6123



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 25 - loss : 1.0934 - acc: 0.5708 - val_loss : 0.8845 - val_acc: 0.6447



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 26 - loss : 1.0602 - acc: 0.5786 - val_loss : 0.9589 - val_acc: 0.6103



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6545
Epoch : 27 - loss : 1.0515 - acc: 0.5732 - val_loss : 0.9050 - val_acc: 0.6545



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 28 - loss : 1.0428 - acc: 0.5830 - val_loss : 1.0373 - val_acc: 0.5239



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 29 - loss : 1.0326 - acc: 0.5755 - val_loss : 0.8535 - val_acc: 0.6460



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6584
Saved regular checkpoint at epoch 30
Epoch : 30 - loss : 1.0538 - acc: 0.5787 - val_loss : 0.8502 - val_acc: 0.6584



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7045
Epoch : 31 - loss : 1.0323 - acc: 0.5806 - val_loss : 0.8123 - val_acc: 0.7045



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 32 - loss : 1.0336 - acc: 0.5752 - val_loss : 0.8236 - val_acc: 0.6518



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 33 - loss : 1.0639 - acc: 0.5787 - val_loss : 0.8450 - val_acc: 0.6655



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 34 - loss : 1.0088 - acc: 0.5775 - val_loss : 0.8336 - val_acc: 0.6570



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 35 - loss : 1.0435 - acc: 0.5681 - val_loss : 0.8251 - val_acc: 0.6375



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 36 - loss : 1.0661 - acc: 0.5619 - val_loss : 0.8669 - val_acc: 0.6480



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 37 - loss : 1.0648 - acc: 0.5779 - val_loss : 0.8904 - val_acc: 0.6474



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 38 - loss : 1.0188 - acc: 0.5939 - val_loss : 0.8469 - val_acc: 0.6356



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 39 - loss : 1.0232 - acc: 0.5807 - val_loss : 0.8192 - val_acc: 0.6570



  0%|          | 0/20 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 40
Epoch : 40 - loss : 0.9895 - acc: 0.6040 - val_loss : 0.8202 - val_acc: 0.6726



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 41 - loss : 1.0392 - acc: 0.5807 - val_loss : 0.8865 - val_acc: 0.6161



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 42 - loss : 1.0381 - acc: 0.5826 - val_loss : 0.9594 - val_acc: 0.6331



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 43 - loss : 0.9998 - acc: 0.6013 - val_loss : 0.8189 - val_acc: 0.6655



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 44 - loss : 0.9618 - acc: 0.6290 - val_loss : 0.8254 - val_acc: 0.6721



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 45 - loss : 0.9745 - acc: 0.6009 - val_loss : 0.8522 - val_acc: 0.6337



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 46 - loss : 0.9782 - acc: 0.6076 - val_loss : 0.8046 - val_acc: 0.6759



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 47 - loss : 0.9880 - acc: 0.6017 - val_loss : 0.8066 - val_acc: 0.6343



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 48 - loss : 0.9970 - acc: 0.5896 - val_loss : 0.8079 - val_acc: 0.6831



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 49 - loss : 0.9969 - acc: 0.6142 - val_loss : 0.8068 - val_acc: 0.6726



  0%|          | 0/20 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 50
Epoch : 50 - loss : 0.9940 - acc: 0.6174 - val_loss : 0.8131 - val_acc: 0.6564



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 51 - loss : 0.9738 - acc: 0.6279 - val_loss : 0.7999 - val_acc: 0.6707



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 52 - loss : 1.0023 - acc: 0.5939 - val_loss : 0.7944 - val_acc: 0.6707



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 53 - loss : 0.9886 - acc: 0.6115 - val_loss : 0.7822 - val_acc: 0.6779



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 54 - loss : 0.9710 - acc: 0.6150 - val_loss : 0.8091 - val_acc: 0.6532



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 55 - loss : 0.9408 - acc: 0.6447 - val_loss : 0.7644 - val_acc: 0.6902



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 56 - loss : 0.9741 - acc: 0.6099 - val_loss : 0.8022 - val_acc: 0.6707



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 57 - loss : 0.9789 - acc: 0.6044 - val_loss : 0.8043 - val_acc: 0.6916



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 58 - loss : 0.9564 - acc: 0.6150 - val_loss : 0.7695 - val_acc: 0.6850



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 59 - loss : 0.9943 - acc: 0.6084 - val_loss : 0.8199 - val_acc: 0.6929



  0%|          | 0/20 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 60
Epoch : 60 - loss : 0.9575 - acc: 0.6146 - val_loss : 0.7770 - val_acc: 0.6532



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 61 - loss : 0.9542 - acc: 0.6232 - val_loss : 0.7988 - val_acc: 0.6825



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 62 - loss : 0.9855 - acc: 0.6170 - val_loss : 0.8624 - val_acc: 0.6507



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 63 - loss : 0.9564 - acc: 0.6161 - val_loss : 0.7946 - val_acc: 0.6649



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 64 - loss : 0.9649 - acc: 0.6263 - val_loss : 0.7807 - val_acc: 0.6773



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 65 - loss : 0.9449 - acc: 0.6345 - val_loss : 0.7849 - val_acc: 0.6792



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 66 - loss : 0.9680 - acc: 0.6185 - val_loss : 0.7973 - val_acc: 0.6786



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 67 - loss : 0.9421 - acc: 0.6314 - val_loss : 0.8110 - val_acc: 0.6701



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7052
Epoch : 68 - loss : 0.9358 - acc: 0.6303 - val_loss : 0.7759 - val_acc: 0.7052



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 69 - loss : 0.9267 - acc: 0.6563 - val_loss : 0.7519 - val_acc: 0.6896



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7124
Saved regular checkpoint at epoch 70
Epoch : 70 - loss : 0.9367 - acc: 0.6404 - val_loss : 0.7785 - val_acc: 0.7124



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 71 - loss : 0.9271 - acc: 0.6361 - val_loss : 0.7910 - val_acc: 0.6721



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 72 - loss : 0.9701 - acc: 0.6228 - val_loss : 0.7969 - val_acc: 0.6929



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 73 - loss : 0.9555 - acc: 0.6264 - val_loss : 0.8633 - val_acc: 0.6422



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 74 - loss : 0.9213 - acc: 0.6341 - val_loss : 0.7696 - val_acc: 0.6844



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 75 - loss : 0.9355 - acc: 0.6385 - val_loss : 0.7797 - val_acc: 0.6759



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 76 - loss : 0.9356 - acc: 0.6283 - val_loss : 0.7856 - val_acc: 0.6753



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 77 - loss : 0.9250 - acc: 0.6470 - val_loss : 0.7451 - val_acc: 0.6916



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 78 - loss : 0.9219 - acc: 0.6510 - val_loss : 0.7653 - val_acc: 0.7052



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 79 - loss : 0.9564 - acc: 0.6264 - val_loss : 0.7698 - val_acc: 0.6962



  0%|          | 0/20 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 80
Epoch : 80 - loss : 0.8805 - acc: 0.6689 - val_loss : 0.7208 - val_acc: 0.6935



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 81 - loss : 0.9142 - acc: 0.6455 - val_loss : 0.7704 - val_acc: 0.6611



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 82 - loss : 0.9208 - acc: 0.6428 - val_loss : 0.7603 - val_acc: 0.6825



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 83 - loss : 0.9096 - acc: 0.6545 - val_loss : 0.7633 - val_acc: 0.6682



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 84 - loss : 0.8932 - acc: 0.6446 - val_loss : 0.7383 - val_acc: 0.6773



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 85 - loss : 0.8902 - acc: 0.6709 - val_loss : 0.7535 - val_acc: 0.6863



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 86 - loss : 0.8828 - acc: 0.6603 - val_loss : 0.7473 - val_acc: 0.6877



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 87 - loss : 0.8987 - acc: 0.6544 - val_loss : 0.7429 - val_acc: 0.6740



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 88 - loss : 0.9022 - acc: 0.6560 - val_loss : 0.7309 - val_acc: 0.6844



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 89 - loss : 0.9135 - acc: 0.6486 - val_loss : 0.7313 - val_acc: 0.7072



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7137
Saved regular checkpoint at epoch 90
Epoch : 90 - loss : 0.8805 - acc: 0.6685 - val_loss : 0.7483 - val_acc: 0.7137



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 91 - loss : 0.8664 - acc: 0.6747 - val_loss : 0.7140 - val_acc: 0.7085



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 92 - loss : 0.8659 - acc: 0.6606 - val_loss : 0.7190 - val_acc: 0.7039



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 93 - loss : 0.8953 - acc: 0.6623 - val_loss : 0.7191 - val_acc: 0.6948



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 94 - loss : 0.8921 - acc: 0.6623 - val_loss : 0.7176 - val_acc: 0.7014



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 95 - loss : 0.9006 - acc: 0.6447 - val_loss : 0.7205 - val_acc: 0.7124



  0%|          | 0/20 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7494
Epoch : 96 - loss : 0.8620 - acc: 0.6705 - val_loss : 0.6770 - val_acc: 0.7494



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 97 - loss : 0.8537 - acc: 0.6704 - val_loss : 0.6951 - val_acc: 0.7072



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 98 - loss : 0.8664 - acc: 0.6696 - val_loss : 0.7352 - val_acc: 0.7170



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch : 99 - loss : 0.9092 - acc: 0.6509 - val_loss : 0.7041 - val_acc: 0.7195



  0%|          | 0/20 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 100
Epoch : 100 - loss : 0.8610 - acc: 0.6786 - val_loss : 0.6925 - val_acc: 0.7195



In [29]:
# 加载最佳模型并进行测试
best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')
if os.path.exists(best_model_path):
    # 加载最佳模型
    checkpoint = torch.load(best_model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model with validation accuracy: {checkpoint['best_val_acc']:.4f}")
    
    # 初始化每个类别的统计
    class_correct = {classname: 0 for classname in train_data.classes}
    class_total = {classname: 0 for classname in train_data.classes}
    
    # 在测试集上评估
    model.eval()
    test_loss = 0
    test_accuracy = 0
    
    with torch.no_grad():
        for data, labels in tqdm(test_loader, desc="Testing"):
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            loss = criterion(outputs, labels)
            
            # 计算总体准确率
            _, predicted = torch.max(outputs.data, 1)
            correct = (predicted == labels).float()
            test_accuracy += correct.sum() / len(test_loader.dataset)
            test_loss += loss.item() / len(test_loader)
            
            # 计算每个类别的准确率
            for label, prediction, correct_pred in zip(labels, predicted, correct):
                label_class = train_data.classes[label]
                class_correct[label_class] += correct_pred.item()
                class_total[label_class] += 1
    
    # 打印总体结果
    print(f"\nTest Results:")
    print(f"Average Loss: {test_loss:.4f}")
    print(f"Overall Accuracy: {test_accuracy:.4f}")
    
    # 打印每个类别的准确率
    print("\nPer-class Accuracy:")
    for classname in train_data.classes:
        accuracy = class_correct[classname] / class_total[classname]
        print(f"{classname:}: {accuracy:.4f} ({int(class_correct[classname])}/{class_total[classname]})")
else:
    print("No best model checkpoint found!")

Loaded best model with validation accuracy: 0.7494


Testing:   0%|          | 0/3 [00:00<?, ?it/s]


Test Results:
Average Loss: 0.6606
Overall Accuracy: 0.7531

Per-class Accuracy:
Beach: 0.9688 (31/32)
Bridge: 0.5000 (15/30)
Pond: 0.7143 (25/35)
Port: 0.6774 (21/31)
River: 0.8824 (30/34)
