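# Vision Transformer (ViT) with Linformer

Training a Vision Transformer with Linformer (efficient) attention on a 5-class remote-sensing scene dataset (Beach, Bridge, Pond, Port, River), adapted from a *Dogs vs Cats* tutorial.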

* Dogs vs. Cats Redux: Kernels Edition (dataset used by the original tutorial) - https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition
* Base Code - https://www.kaggle.com/reukki/pytorch-cnn-tutorial-with-cats-and-dogs/
* Efficient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention

In [20]:
!pip -q install vit_pytorch linformer

## Import Libraries

In [21]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import seaborn as sns
import time

from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score,cohen_kappa_score

from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from vit_pytorch.efficient import ViT


In [22]:
# Report the installed PyTorch version.
print("Torch: %s" % torch.__version__)

Torch: 2.5.1


In [23]:
# torch is already imported in the imports cell and its version printed in the
# previous cell; this cell only reports whether a GPU is visible to the kernel.
print("CUDA available:", torch.cuda.is_available())


Torch: 2.5.1
CUDA available: True


In [24]:
# Training settings
batch_size, epochs = 64, 200
lr, gamma = 3e-5, 0.7   # Adam learning rate; StepLR decay factor
seed = 42               # passed to seed_everything

In [25]:
def seed_everything(seed):
    """Seed every RNG the notebook relies on so runs are reproducible.

    Seeds Python's `random`, NumPy's legacy global RNG and torch (CPU and all
    CUDA devices), exports PYTHONHASHSEED, and configures cuDNN to use
    deterministic kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects child processes started afterwards; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers the current device as well
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms per input shape and may
    # select non-deterministic ones, defeating `deterministic = True`.
    torch.backends.cudnn.benchmark = False

seed_everything(seed=seed)  # apply the configured seed before any model/data is created

In [26]:
# Use the GPU when one is visible, otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Load Data

In [27]:
# Dataset root; set the DATASET_ROOT environment variable to run elsewhere.
# The loading cell reads <root>/train, <root>/val and <root>/test with ImageFolder
# (one sub-directory per class).
root_path = os.environ.get("DATASET_ROOT", r"/scratch/tshu2/jyu197/Datasets")

In [28]:
# Training: resize to 224x224, then random-resized-crop and horizontal flip
# for augmentation.
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Validation / test: deterministic resize + center crop, no augmentation.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

test_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

# Load each split with ImageFolder (one sub-directory per class).
train_data = datasets.ImageFolder(root=os.path.join(root_path, 'train'), transform=train_transforms)
valid_data = datasets.ImageFolder(root=os.path.join(root_path, 'val'), transform=val_transforms)
test_data = datasets.ImageFolder(root=os.path.join(root_path, 'test'), transform=test_transforms)

# ImageFolder assigns label indices from the sorted folder names of each split;
# the splits must agree or validation/test labels would not match model outputs.
assert train_data.classes == valid_data.classes == test_data.classes, (
    "train/val/test class folders differ"
)

# DataLoaders use the configured batch size.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [29]:
# Summarise each split: image count and number of mini-batches per epoch.
print("\n".join(
    f"{split_label} {len(split_ds)} images, {len(split_dl)} batches"
    for split_label, split_ds, split_dl in (
        ("Train:", train_data, train_loader),
        ("Val:  ", valid_data, valid_loader),
        ("Test: ", test_data, test_loader),
    )
))
print("Classes: " + str(train_data.classes))

Train: 5041 images, 79 batches
Val:   1263 images, 20 batches
Test:  359 images, 6 batches
Classes: ['Beach', 'Bridge', 'Pond', 'Port', 'River']


In [30]:
print(f"{len(train_data)} {len(train_loader)}")  # training images, training batches

5041 79


In [31]:
print(f"{len(valid_data)} {len(valid_loader)}")  # validation images, validation batches

1263 20


## Efficient Attention

### Linformer

In [32]:
# Linformer backbone for the ViT below. The sequence length is the number of
# 32x32 patches in a 224x224 image (7 * 7 = 49) plus one CLS token.
efficient_transformer = Linformer(
    seq_len=(224 // 32) ** 2 + 1,
    dim=128,
    depth=12,
    heads=8,
    k=64,
)

### Vision Transformer (ViT)

In [33]:
# ViT on 224x224 RGB images split into 32x32 patches, using the Linformer above
# as its transformer (dim kept equal to the Linformer's dim=128).
model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    # One output logit per ImageFolder class, instead of a hard-coded count.
    num_classes=len(train_data.classes),
    transformer=efficient_transformer,
    channels=3,
).to(device)

### Training

In [34]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [35]:
# Directory for best/periodic checkpoints and the training-curve figure;
# set the CHECKPOINT_DIR environment variable to write somewhere else.
checkpoint_dir = os.environ.get(
    'CHECKPOINT_DIR', '/scratch/tshu2/jyu197/remotesensing/examples/checkpoints'
)
os.makedirs(checkpoint_dir, exist_ok=True)
best_val_acc = 0  # best validation accuracy seen so far; updated by the training loop

In [36]:
# Per-epoch metrics, appended by the training loop and plotted by plot_curves.
history = {key: [] for key in ("train_loss", "val_loss", "train_acc", "val_acc")}

In [37]:
def load_checkpoint(checkpoint_path, model, optimizer, scheduler, map_location='cpu'):
    """Restore training state from a checkpoint file, if one exists.

    Args:
        checkpoint_path: path of a file containing the keys 'epoch',
            'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict'
            and 'best_val_acc'.
        model, optimizer, scheduler: objects whose state dicts are restored
            in place.
        map_location: forwarded to torch.load. The 'cpu' default lets a
            checkpoint saved on a GPU be opened on a CPU-only machine;
            load_state_dict then copies the tensors onto the devices of the
            model's parameters.

    Returns:
        (start_epoch, best_val_acc). start_epoch is the epoch after the saved
        one and best_val_acc is a Python float. Returns (0, 0) when the file
        does not exist.

    NOTE(review): not called by any visible cell; training always starts at
    epoch 0.
    """
    if not os.path.exists(checkpoint_path):
        return 0, 0

    print(f"Loading checkpoint from {checkpoint_path}")
    # torch.load unpickles the file -- only load checkpoints you wrote yourself.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    # Older checkpoints stored this as a (possibly CUDA) tensor; normalise to float.
    best_val_acc = float(checkpoint['best_val_acc'])
    print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
    return start_epoch, best_val_acc

In [None]:
def _run_epoch(loader, train):
    """Run one full pass of `model` over `loader`.

    With train=True the model is put in training mode, gradients are enabled
    and `optimizer` is stepped after every batch. With train=False the model is
    put in eval mode and gradients are disabled.

    Returns:
        (mean loss per sample, accuracy in [0, 1]) as Python floats. Both are
        weighted by batch size, so a smaller final batch is not over-counted.
        Returns (0.0, 0.0) for an empty loader.
    """
    model.train(train)
    loss_sum, correct, seen = 0.0, 0, 0
    batches = tqdm(loader) if train else loader  # progress bar for training only
    with torch.set_grad_enabled(train):
        for data, label in batches:
            data = data.to(device)
            label = label.to(device)

            output = model(data)
            loss = criterion(output, label)  # mean over the batch (default reduction)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_n = label.size(0)
            # .item() keeps the running totals as plain floats instead of
            # tensors attached to the autograd graph.
            loss_sum += loss.item() * batch_n
            correct += (output.argmax(dim=1) == label).sum().item()
            seen += batch_n

    if seen == 0:
        return 0.0, 0.0
    return loss_sum / seen, correct / seen


def _save_checkpoint(path, epoch, val_loss, best_acc):
    """Save model/optimizer/scheduler state and bookkeeping to `path`.

    Uses the keys that load_checkpoint reads back.
    """
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_acc,
            'loss': val_loss,
        },
        path,
    )


for epoch in range(epochs):
    epoch_loss, epoch_accuracy = _run_epoch(train_loader, train=True)
    epoch_val_loss, epoch_val_accuracy = _run_epoch(valid_loader, train=False)

    # NOTE(review): `scheduler` is created and checkpointed but never stepped,
    # so the learning rate stays at `lr` for every epoch. With step_size=1 and
    # gamma=0.7, calling scheduler.step() here would shrink the learning rate
    # to a negligible value long before `epochs` -- confirm the intended
    # schedule before enabling it.

    # 1. Keep the checkpoint with the best validation accuracy so far.
    if epoch_val_accuracy > best_val_acc:
        best_val_acc = epoch_val_accuracy
        _save_checkpoint(os.path.join(checkpoint_dir, 'best_model.pth'),
                         epoch, epoch_val_loss, best_val_acc)
        print(f"Saved best model with validation accuracy: {best_val_acc:.4f}")

    # 2. Periodic checkpoint every 10 epochs.
    if (epoch + 1) % 10 == 0:
        _save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'),
                         epoch, epoch_val_loss, best_val_acc)
        print(f"Saved regular checkpoint at epoch {epoch+1}")

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    # Accuracies are stored as percentages for plotting.
    history["train_loss"].append(epoch_loss)
    history["val_loss"].append(epoch_val_loss)
    history["train_acc"].append(epoch_accuracy * 100)
    history["val_acc"].append(epoch_val_accuracy * 100)


  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.3511
Epoch : 1 - loss : 1.5457 - acc: 0.3005 - val_loss : 1.4493 - val_acc: 0.3511



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 2 - loss : 1.4093 - acc: 0.3881 - val_loss : 1.4101 - val_acc: 0.3486



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.4838
Epoch : 3 - loss : 1.3558 - acc: 0.4301 - val_loss : 1.2599 - val_acc: 0.4838



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.5708
Epoch : 4 - loss : 1.2504 - acc: 0.4865 - val_loss : 1.0849 - val_acc: 0.5708



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6127
Epoch : 5 - loss : 1.1107 - acc: 0.5497 - val_loss : 0.9655 - val_acc: 0.6127



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6159
Epoch : 6 - loss : 1.1030 - acc: 0.5456 - val_loss : 0.9614 - val_acc: 0.6159



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6255
Epoch : 7 - loss : 1.0528 - acc: 0.5734 - val_loss : 0.9175 - val_acc: 0.6255



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6433
Epoch : 8 - loss : 1.0498 - acc: 0.5751 - val_loss : 0.8929 - val_acc: 0.6433



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6732
Epoch : 9 - loss : 1.0120 - acc: 0.5984 - val_loss : 0.8831 - val_acc: 0.6732



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 10
Epoch : 10 - loss : 1.0120 - acc: 0.5939 - val_loss : 0.8494 - val_acc: 0.6669



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6752
Epoch : 11 - loss : 1.0021 - acc: 0.6022 - val_loss : 0.8522 - val_acc: 0.6752



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6821
Epoch : 12 - loss : 0.9957 - acc: 0.6089 - val_loss : 0.8428 - val_acc: 0.6821



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6925
Epoch : 13 - loss : 0.9808 - acc: 0.6182 - val_loss : 0.8020 - val_acc: 0.6925



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 14 - loss : 0.9634 - acc: 0.6270 - val_loss : 0.8549 - val_acc: 0.6677



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7070
Epoch : 15 - loss : 0.9654 - acc: 0.6242 - val_loss : 0.7955 - val_acc: 0.7070



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7198
Epoch : 16 - loss : 0.9522 - acc: 0.6341 - val_loss : 0.7952 - val_acc: 0.7198



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7364
Epoch : 17 - loss : 0.9201 - acc: 0.6488 - val_loss : 0.7459 - val_acc: 0.7364



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7488
Epoch : 18 - loss : 0.9077 - acc: 0.6464 - val_loss : 0.7169 - val_acc: 0.7488



  0%|          | 0/79 [00:00<?, ?it/s]

In [None]:
sns.set_theme(style="whitegrid")


def _to_float_list(values):
    """Convert a sequence of floats and/or scalar tensors to a list of floats."""
    return [
        v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
        for v in values
    ]


def plot_curves(hist, out="curves.png"):
    """Plot training/validation loss and accuracy side by side and save to `out`.

    Args:
        hist: dict with "train_loss", "val_loss", "train_acc" and "val_acc"
            lists (floats or scalar tensors), one entry per epoch; accuracies
            are plotted as given, labelled as percentages.
        out: path of the image file to write (300 dpi).

    The figure is closed after saving, so nothing is displayed inline.
    """
    train_loss = _to_float_list(hist["train_loss"])
    val_loss = _to_float_list(hist["val_loss"])
    train_acc = _to_float_list(hist["train_acc"])
    val_acc = _to_float_list(hist["val_acc"])

    # Named to avoid shadowing the notebook-level `epochs` setting.
    epoch_axis = range(1, len(train_loss) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    ax1.plot(epoch_axis, train_loss, label="Train Loss")
    ax1.plot(epoch_axis, val_loss, label="Val Loss")
    ax1.set_title("Training & Validation Loss")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax1.grid(alpha=0.3)

    ax2.plot(epoch_axis, train_acc, label="Train Acc@1")
    ax2.plot(epoch_axis, val_acc, label="Val Acc@1")
    ax2.set_title("Training & Validation Accuracy")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy (%)")
    ax2.legend()
    ax2.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(out, dpi=300)
    # Close exactly this figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Curves saved to {out}")


curves_path = os.path.join(checkpoint_dir, "training_curves.png")
plot_curves(history, curves_path)


Curves saved to curves.png


In [34]:
def evaluate(loader=None):
    """Score `model` on a labelled loader and print classification metrics.

    Args:
        loader: DataLoader yielding (images, labels); defaults to `test_loader`.
            Its label indices are interpreted with `train_data.classes`.

    Returns:
        dict with "oa", "macc", "kappa", "precision", "recall", "f1" and
        "per_class_acc" (percent, one entry per class). Returns an empty dict
        if the loader yields no batches.

    NOTE(review): this scores the model as it currently is in memory, not the
    saved best_model.pth checkpoint.
    """
    if loader is None:
        loader = test_loader
    # Model outputs are indexed in the training split's class order.
    class_names = train_data.classes
    num_classes = len(class_names)

    true_batches, pred_batches = [], []
    model.eval()
    start = time.time()
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device, non_blocking=True)
            preds = model(imgs).argmax(dim=1)
            # Collect whole batches; per-sample bookkeeping happens once on
            # NumPy arrays below instead of element-wise on GPU tensors.
            true_batches.append(labels.cpu().numpy())
            pred_batches.append(preds.cpu().numpy())
    elapsed = time.time() - start

    if not true_batches:
        print("Test completed. Samples: 0")
        return {}

    y_true = np.concatenate(true_batches)
    y_pred = np.concatenate(pred_batches)

    # Per-class sample counts and correct predictions (same length by construction).
    support = np.bincount(y_true, minlength=num_classes)
    hits = np.bincount(y_true, weights=(y_true == y_pred).astype(float), minlength=num_classes)
    present = support > 0

    oa = accuracy_score(y_true, y_pred)
    # Mean of per-class accuracies over the classes that occur in y_true.
    macc = float(np.mean(hits[present] / support[present]))
    kappa = cohen_kappa_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(f"Test completed. Samples: {len(y_true)}  Time: {elapsed:.2f}s")
    print(f"Overall Accuracy (OA) : {oa:.4f}")
    print(f"Mean Accuracy (mAcc)  : {macc:.4f}")
    print(f"Cohen-Kappa           : {kappa:.4f}")
    print(f"Precision (macro)     : {prec:.4f}")
    print(f"Recall    (macro)     : {rec:.4f}")
    print(f"F1-score  (macro)     : {f1:.4f}")

    print("Per-class accuracy:")
    per_class_acc = []
    for idx, cls_name in enumerate(class_names):
        total = int(support[idx])
        correct = int(hits[idx])
        acc_cls = 100.0 * correct / total if total else 0.0
        per_class_acc.append(acc_cls)
        print(f" [{idx:02d}] {cls_name:<20s}: {acc_cls:6.2f}%  ({correct}/{total})")

    return {
        "oa": oa,
        "macc": macc,
        "kappa": kappa,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "per_class_acc": per_class_acc,
    }


test_metrics = evaluate()

Test completed. Samples: 359  Time: 8.93s
Overall Accuracy (OA) : 0.6741
Mean Accuracy (mAcc)  : 0.6670
Cohen-Kappa           : 0.5916
Precision (macro)     : 0.6726
Recall    (macro)     : 0.6670
F1-score  (macro)     : 0.6648
Per-class accuracy:
 [00] Beach               :  84.72%  (61/72)
 [01] Bridge              :  37.88%  (25/66)
 [02] Pond                :  74.03%  (57/77)
 [03] Port                :  60.87%  (42/69)
 [04] River               :  76.00%  (57/75)
