# Cat/Dog Classification

In [None]:
import os

# Experiment switches come from the environment. A switch is on only when its
# variable is exactly the string 'True'; the second tuple item is the value
# assumed when the variable is unset.
from_scratch, augment, mixup = (
    os.environ.get(variable, fallback) == 'True'
    for variable, fallback in (
        ('FROM_SCRATCH', 'True'),
        ('AUGMENT', 'False'),
        ('MIXUP', 'False'),
    )
)

print(f'from_scratch: {from_scratch}')
print(f'augment: {augment}')
print(f'mixup: {mixup}')

# Hyperparameters in the order (num_epochs, lr, min_lr, weight_decay):
# first tuple for training from scratch, second for fine-tuning.
num_epochs, lr, min_lr, weight_decay = (
    (1000, 1e-4, 1e-4, 1e-2) if from_scratch else (3, 1e-5, 1e-5, 1e-2)
)

# Early-stopping patience, measured in epochs
early_stop = 100

## Dataset

In [None]:
import torch
import platform
import numpy as np
import pandas as pd
import torch.nn as nn
from PIL import Image
from glob import glob
from tabulate import tabulate
from torchinfo import summary
import matplotlib.pyplot as plt
from torchvision.transforms import v2
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import Dataset, DataLoader, default_collate
from sklearn.model_selection import train_test_split, StratifiedKFold

from metaformer import create_my_caformer, create_standard_caformer


def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Imported locally because the notebook's import cell does not bring in
    # the standard-library `random` module.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed the generators of all CUDA devices, not only the current one.
    torch.cuda.manual_seed_all(seed)


seed_everything(seed=42)

# Load data into a dataframe. The file lists are sorted so the row order does
# not depend on the order in which glob happens to return the paths.
cat_files = glob('./data/cat/*.jpg')
dog_files = glob('./data/dog/*.jpg')
cat_files.sort()
dog_files.sort()
df = pd.DataFrame(
    {
        'filename': np.concatenate([cat_files, dog_files]),
        'class': ['cat'] * len(cat_files) + ['dog'] * len(dog_files),
    }
)

# Split data into training (80%) and testing (20%) sets, stratified by class.
# No random_state is passed, so the split follows NumPy's global random state.
train_df, test_df = train_test_split(df, test_size=0.2, stratify=df['class'])
# Take an explicit copy: a column is added below, and writing into the object
# returned by the split can trigger pandas' SettingWithCopyWarning.
train_df = train_df.copy()

# Split training set into 5 folds, stratified by class (only fold 0 is used in this notebook)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
train_df['fold'] = -1
# Look the column up by name instead of assuming it is the last one.
fold_column = train_df.columns.get_loc('fold')
for i, (_, test_index) in enumerate(skf.split(train_df, train_df['class'])):
    # test_index holds positional row indices, hence iloc.
    train_df.iloc[test_index, fold_column] = i

# Per-class sample counts for every fold and for the test split
table = pd.DataFrame(
    {
        f'fold{i}': train_df[train_df['fold'] == i].groupby('class').size()
        for i in range(5)
    }
    | {'test': test_df.groupby('class').size()}
)
print(tabulate(table, headers='keys'))

# Show random images in the training set: 5 cats on the top row, 5 dogs below
cat_files = train_df[train_df['class'] == 'cat']['filename'].sample(5)
dog_files = train_df[train_df['class'] == 'dog']['filename'].sample(5)
fig, (cat_axes, dog_axes) = plt.subplots(2, 5, figsize=(5, 2))
for i, (cat_file, dog_file) in enumerate(zip(cat_files, dog_files)):
    cat_image = Image.open(cat_file)
    dog_image = Image.open(dog_file)
    cat_axes[i].imshow(cat_image)
    dog_axes[i].imshow(dog_image)
    cat_axes[i].axis('off')
    dog_axes[i].axis('off')

In [None]:
# Mini-batch size and the side length (in pixels) of the square model input
batch_size = 32
input_size = 384

# Pre-computed per-channel mean and standard deviation used for normalization
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Final steps shared by every pipeline: convert to an image tensor,
# cast to float32 with value scaling, then normalize per channel.
basic_transforms = [
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=mean, std=std),
]

# No loader worker processes on the host named 'nedsmac.local', 8 elsewhere
if platform.node() == 'nedsmac.local':
    num_workers = 0
else:
    num_workers = 8

# Run on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


class CatDogDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from image files.

    Args:
        filenames: Indexable sequence of image file paths.
        labels: Indexable sequence aligned with ``filenames``; the value
            ``'dog'`` maps to label 1 and any other value to label 0.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, filenames, labels, transform=None):
        self.filenames = filenames
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is always converted to RGB, so grayscale, CMYK or palette
        files still produce three channels for the transform.
        """
        # Image.open is lazy; convert() forces the decode, and the context
        # manager closes the underlying file handle right afterwards.
        with Image.open(self.filenames[idx]) as raw_image:
            image = raw_image.convert('RGB')
        label = 1 if self.labels[idx] == 'dog' else 0
        if self.transform:
            image = self.transform(image)
        return image, label


# Training pipeline: random resized crop and horizontal flip, RandAugment when
# `augment` is set, followed by the shared tensor conversion and normalization.
train_steps = [
    v2.RandomResizedCrop(input_size, scale=(0.8, 1.0), ratio=(0.8, 1.2)),
    v2.RandomHorizontalFlip(),
]
if augment:
    train_steps.append(v2.RandAugment())
train_transform = v2.Compose(train_steps + basic_transforms)

# Without MixUp the default collation is used as is. With MixUp, every
# collated (images, labels) batch is passed through the MixUp transform.
# Note that `mixup` is rebound here from the boolean flag to the transform.
if not mixup:
    collate_fn = default_collate
else:
    mixup = v2.MixUp(num_classes=2, alpha=0.2)

    def collate_fn(batch):
        return mixup(*default_collate(batch))


# Fold 0 is held out for validation; every other fold is used for training
in_val_fold = train_df['fold'] == 0
train_rows = train_df[~in_val_fold]
val_rows = train_df[in_val_fold]

train_dataset = CatDogDataset(
    train_rows['filename'].values,
    train_rows['class'].values,
    train_transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
)

# Validation and test data use a deterministic pipeline:
# resize, center crop, then the shared conversion and normalization.
eval_transform = v2.Compose(
    [v2.Resize(input_size), v2.CenterCrop(input_size), *basic_transforms]
)
val_dataset = CatDogDataset(
    val_rows['filename'].values,
    val_rows['class'].values,
    eval_transform,
)
test_dataset = CatDogDataset(
    test_df['filename'].values,
    test_df['class'].values,
    eval_transform,
)
val_loader, test_loader = (
    DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    for dataset in (val_dataset, test_dataset)
)

# Show one batch of transformed images from the training and validation sets
std_t = torch.tensor(std).view(1, 3, 1, 1)
mean_t = torch.tensor(mean).view(1, 3, 1, 1)
for loader in (train_loader, val_loader):
    images, labels = next(iter(loader))

    # Undo the normalization and clip to the displayable [0, 1] range
    images = torch.clamp(images * std_t + mean_t, 0, 1)

    # 4 x 8 grid, filled row by row
    fig, axes = plt.subplots(4, 8, figsize=(8, 4))
    for i, (image, label) in enumerate(zip(images, labels)):
        panel = axes[i // 8, i % 8]
        # Channels-first tensor -> channels-last layout expected by imshow
        panel.imshow(image.permute(1, 2, 0))
        panel.axis('off')
    plt.tight_layout()

## Train

In [None]:
# Build the network: the custom CAFormer when training from scratch, otherwise
# the standard CAFormer initialised from a pretrained checkpoint.
if not from_scratch:
    model = create_standard_caformer()

    # Load pretrained weights (https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth)
    state_dict = torch.load('./caformer_s18_384.pth', map_location='cpu')
    # Remove the checkpoint entries for the head so the model keeps its own
    # head parameters; strict=False tolerates the keys that are then missing.
    head_param_names = [
        name for name, _ in model.named_parameters() if name.startswith('head')
    ]
    for name in head_param_names:
        state_dict.pop(name)
    model.load_state_dict(state_dict, strict=False)

    # Freeze all layers except the head
    for name, param in model.named_parameters():
        if name.startswith('head'):
            continue
        param.requires_grad_(False)
else:
    model = create_my_caformer()

model = model.to(device)
print(summary(model, input_size=(1, 3, input_size, input_size)))

In [None]:
# File stem used when writing outputs for this run
if from_scratch:
    model_name = 'model_from_scratch'
else:
    model_name = 'model_finetuned'

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)
# Cosine annealing from the initial learning rate towards min_lr over
# num_epochs scheduler steps
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=min_lr,
)


def validate():
    """Evaluate the global ``model`` on ``val_loader`` and print the metrics.

    The printed line starts with `` | `` so that it continues a progress line
    already written by the caller. The model's previous train/eval mode is
    restored before returning, so calling this in the middle of a training
    epoch does not leave the model in eval mode.

    Returns:
        Tuple ``(val_loss, val_acc, val_f1)``: the per-sample mean of
        ``criterion`` over the validation set, the accuracy, and the
        macro-averaged F1 score.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    y_pred = []
    y_true = []
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            # The criterion yields a batch mean (assuming its default 'mean'
            # reduction), so weight it by the batch size: the sum divided by
            # the dataset size is then a true per-sample mean, even when the
            # last batch is smaller.
            total_loss += loss.item() * labels.size(0)
            preds = outputs.argmax(dim=1)
            y_pred.extend(preds.cpu().numpy())
            y_true.extend(labels.cpu().numpy())
    val_loss = total_loss / len(val_loader.dataset)
    val_acc = accuracy_score(y_true, y_pred)
    val_f1 = f1_score(y_true, y_pred, average='macro')
    print(
        f' | Val. Loss: {val_loss:.6f}, Val. Acc.: {val_acc:.6f}, Val. F1: {val_f1:.6f}'
    )
    # Put the model back into whichever mode the caller had it in.
    model.train(was_training)
    return val_loss, val_acc, val_f1


def _save_if_best(val_loss, epoch):
    """Checkpoint the model if ``val_loss`` improves on the best seen so far.

    Updates the module-level ``best_val_loss`` and ``best_epoch`` and writes
    the current weights, so the saved checkpoint always corresponds to the
    validation loss recorded as best.
    """
    global best_val_loss, best_epoch
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), f'/output/{model_name}.pth')


# Rows of [epoch or step, train_loss, val_loss, val_acc, val_f1]
results = []
best_val_loss = np.inf
step = 0  # optimizer steps taken so far, across epochs
best_epoch = 0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by the batch size so that dividing by the
        # dataset size below gives a per-sample mean.
        running_loss += loss.item() * images.size(0)
        # '\r' rewrites the same console line for every batch.
        print(
            f'\rEpoch {epoch + 1}/{num_epochs} | '
            f'Batch {i + 1}/{len(train_loader)}, '
            f'Loss: {loss.item():.6f} | '
            f'LR {optimizer.param_groups[0]["lr"]:.6f}',
            end='',
        )
        # Fine-tuning validates every 5 optimizer steps; the row then stores
        # the step count and the loss of the latest batch.
        if not from_scratch and (step + 1) % 5 == 0:
            val_loss, val_acc, val_f1 = validate()
            results.append([step + 1, loss.item(), val_loss, val_acc, val_f1])
            # Checkpoint now, while the weights still match this val_loss.
            _save_if_best(val_loss, epoch)
            # validate() switches to eval mode; resume training mode for the
            # remaining batches of this epoch.
            model.train()
        step += 1
    train_loss = running_loss / len(train_loader.dataset)
    if from_scratch:
        # Training from scratch validates once per epoch.
        print(f' | Train. Loss: {train_loss:.6f}', end='')
        val_loss, val_acc, val_f1 = validate()
        results.append([epoch + 1, train_loss, val_loss, val_acc, val_f1])
        _save_if_best(val_loss, epoch)
    scheduler.step()
    # Rewrite the metrics file every epoch so partial runs keep their history.
    pd.DataFrame(
        results, columns=['epoch', 'train_loss', 'val_loss', 'val_acc', 'val_f1']
    ).to_csv(f'/output/{model_name}.csv', index=False)
    if epoch - best_epoch >= early_stop:
        print(f'\nEarly stopping at epoch {epoch + 1}')
        break
print(f'\nBest Epoch: {best_epoch + 1} | Best Val. Loss: {best_val_loss:.6f}')

## Test

In [None]:
# Rebuild the architecture that was trained and pick its checkpoint file
if from_scratch:
    model = create_my_caformer()
    checkpoint_path = '/output/model_from_scratch.pth'
else:
    model = create_standard_caformer()
    checkpoint_path = '/output/model_finetuned.pth'

# Load the weights on the CPU first, then move the model to the target device
checkpoint_state = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint_state)
model = model.to(device)


def test(loader):
    """Print accuracy and macro-averaged F1 of the global ``model`` on ``loader``.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
    """
    model.eval()
    predictions, targets = [], []
    # Gradients are not needed for evaluation
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            logits = model(batch_images.to(device))
            # The predicted class is the index of the largest output
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())
    test_acc = accuracy_score(targets, predictions)
    test_f1 = f1_score(targets, predictions, average='macro')
    print(f'Acc.: {test_acc:.6f}, F1: {test_f1:.6f}')


# Report the metrics on the validation loader first, then on the test loader
for split_name, split_loader in (('Validation', val_loader), ('Test', test_loader)):
    print(split_name)
    test(split_loader)