In [1]:
#!pip install timm
#!pip install ipywidgets

# New model version
ConvNeXt architecture - a modernized ConvNet that borrows design ideas from Vision Transformers

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_from_disk
import timm
import numpy as np
from tqdm.auto import tqdm
from torchvision.transforms import (
    Compose, Resize, CenterCrop, ToTensor, Normalize, 
    RandomHorizontalFlip, RandomResizedCrop
)

# --- Experiment configuration ---------------------------------------------
DATA_PATH = "processed_bird_data"    # dataset directory handed to load_from_disk
OUTPUT_DIR = "convnext_checkpoints"  # NOTE(review): not referenced by the visible cells - confirm intended use
MODEL_NAME = "convnext_tiny"         # architecture name handed to timm.create_model
BATCH_SIZE = 32
EPOCHS = 25
LEARNING_RATE = 4e-3
SEED = 42

# Seed PyTorch's global RNG so that a "Restart Kernel -> Run All" starts from
# the same random state on every run.
torch.manual_seed(SEED)

# Pick the fastest available backend: CUDA first, then Apple-Silicon MPS,
# and finally plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


Applying Aggressive Data Augmentation to prevent overfitting:

1. Applies a random resized crop (scale 0.8–1.0). This forces the model to recognize a bird by its specific features (e.g., the head or the wings) instead of the background (e.g., trees).
2. Incorporates random horizontal flipping, which effectively doubles the variety of the training data.

In [3]:
# Per-channel mean / std used to standardise every image (the widely used
# ImageNet statistics).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Training pipeline: random crop covering 80-100% of the image resized to
# 224x224, random horizontal flip, then tensor conversion + normalisation.
_train_steps = [
    RandomResizedCrop(224, scale=(0.8, 1.0)),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
]
train_transforms = Compose(_train_steps)

# Validation pipeline: deterministic resize to 256 followed by a 224x224
# centre crop, then the same tensor conversion + normalisation.
_val_steps = [
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    normalize,
]
val_transforms = Compose(_val_steps)

def apply_train_transforms(batch):
    """Attach augmented training tensors to ``batch`` under ``"pixel_values"``.

    Every entry of ``batch["image"]`` is converted to RGB and passed through
    ``train_transforms``. The batch dict is modified in place and returned.
    """
    augmented = []
    for picture in batch["image"]:
        rgb_picture = picture.convert("RGB")
        augmented.append(train_transforms(rgb_picture))
    batch["pixel_values"] = augmented
    return batch

def apply_val_transforms(batch):
    """Attach deterministic validation tensors to ``batch`` under ``"pixel_values"``.

    Every entry of ``batch["image"]`` is converted to RGB and passed through
    ``val_transforms``. The batch dict is modified in place and returned.
    """
    processed = []
    for picture in batch["image"]:
        rgb_picture = picture.convert("RGB")
        processed.append(val_transforms(rgb_picture))
    batch["pixel_values"] = processed
    return batch

print("Loading and transforming data...")
dataset = load_from_disk(DATA_PATH)

# Register the per-split transform callbacks: augmentation for the training
# split, the deterministic pipeline for the validation split.
_split_transforms = (
    ("train", apply_train_transforms),
    ("validation", apply_val_transforms),
)
for _split_name, _transform_fn in _split_transforms:
    dataset[_split_name].set_transform(_transform_fn)

def collate_fn(batch):
    """Merge a list of samples into an ``(images, labels)`` tensor pair.

    Each sample is read through its ``"pixel_values"`` and ``"label"`` keys;
    the images are stacked along a new leading dimension and the labels are
    packed into a single tensor.
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample["pixel_values"])
        targets.append(sample["label"])
    return torch.stack(images), torch.tensor(targets)

# Both loaders share the batch size and the collate function; only the
# training loader shuffles its samples.
_loader_kwargs = {"batch_size": BATCH_SIZE, "collate_fn": collate_fn}
train_loader = DataLoader(dataset["train"], shuffle=True, **_loader_kwargs)
val_loader = DataLoader(dataset["validation"], shuffle=False, **_loader_kwargs)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

Loading and transforming data...
Train batches: 105 | Val batches: 19


# Training
Initializing a ConvNeXt model with random weights (no pretraining).

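ConvNeXt is a purely convolutional network whose design borrows ideas from Vision Transformers (a patchify stem, large depthwise kernels, LayerNorm and GELU). Its early stages extract low-level features (edges, textures), while its later, lower-resolution stages have a large receptive field that helps capture the global shape of the bird. Also: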

1. High learning rate (4e-3), decayed with cosine annealing;
2. 25 epochs;
3. Uses regularization (AdamW weight decay of 0.05) to prevent the model from memorizing the training images.

(A custom PyTorch training loop is used instead of the Hugging Face Trainer, because the timm library and the Hugging Face Trainer use different input variable names (`x` vs `pixel_values`).)

In [4]:
# Number of output classes requested from timm for the classification head.
NUM_CLASSES = 200

# Randomly initialised network (pretrained=False), moved to the selected device.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES).to(device)

# Loss, optimiser (AdamW with weight decay 0.05) and a cosine learning-rate
# schedule whose period spans the full EPOCHS budget.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.05,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Training accuracy:
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return epoch metrics.

    Args:
        model: network to train; switched to training mode.
        loader: iterable yielding ``(images, labels)`` batches.
        optimizer: optimiser stepped once per batch.
        criterion: loss called as ``criterion(logits, labels)``; its value is
            weighted by the batch size, i.e. it is assumed to be a per-batch
            mean (the default for ``nn.CrossEntropyLoss``).
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: the sample-weighted mean loss and the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples, since the averages are
            undefined (previously this surfaced as a bare ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    # Progress bar for the batch
    pbar = tqdm(loader, desc="Training", leave=False)

    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar loss once: every .item() call forces a
        # device-to-host synchronisation.
        batch_loss = loss.item()
        batch_size = images.size(0)

        # Metrics
        total_loss += batch_loss * batch_size
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

        # Update progress bar
        pbar.set_postfix({"loss": batch_loss})

    if total == 0:
        raise ValueError("train_one_epoch: loader produced no samples; cannot compute metrics")

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without computing gradients.

    Args:
        model: network to evaluate; switched to evaluation mode and left in
            that mode on return.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(logits, labels)``; its value is
            weighted by the batch size, i.e. it is assumed to be a per-batch
            mean (the default for ``nn.CrossEntropyLoss``).
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: the sample-weighted mean loss and the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples, since the averages are
            undefined (previously this surfaced as a bare ZeroDivisionError).
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating", leave=False):
            images, labels = images.to(device), labels.to(device)

            logits = model(images)
            loss = criterion(logits, labels)

            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("evaluate: loader produced no samples; cannot compute metrics")

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy


# Header of the per-epoch results table.
print("\nStarting Training.")
print(f"{'Epoch':<6} | {'Train Loss':<10} | {'Train Acc':<10} | {'Val Loss':<10} | {'Val Acc':<10}")
print("-" * 55)

# Highest validation accuracy seen so far; drives checkpointing below.
best_val_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    # One optimisation pass, then a gradient-free pass over the validation set.
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Checkpoint only when validation accuracy strictly improves, so the file
    # on disk always holds the weights of the best epoch; "(*)" marks such
    # epochs in the printed table.
    improved = val_acc > best_val_acc
    if improved:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "final_convnext_model.pth")
    saved_msg = "(*)" if improved else ""

    print(f"{epoch:<6} | {train_loss:<10.4f} | {train_acc*100:<9.2f}% | {val_loss:<10.4f} | {val_acc*100:<9.2f}% {saved_msg}")

print("\nTraining Complete. Best Validation Accuracy: {:.2f}%".format(best_val_acc*100))


Starting Training.
Epoch  | Train Loss | Train Acc  | Val Loss   | Val Acc   
-------------------------------------------------------


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

1      | 5.7056     | 1.08     % | 5.2438     | 0.85     % (*)


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

2      | 5.1264     | 1.17     % | 5.1334     | 0.85     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

3      | 5.0372     | 1.95     % | 5.0969     | 1.19     % (*)


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

4      | 4.9746     | 2.19     % | 5.0303     | 2.21     % (*)


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

5      | 4.8575     | 2.70     % | 4.9020     | 1.36     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

6      | 4.7404     | 2.94     % | 4.9221     | 2.04     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

7      | 4.6349     | 3.66     % | 4.7153     | 3.23     % (*)


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

8      | 4.5758     | 4.11     % | 4.6843     | 2.89     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

9      | 4.4802     | 5.45     % | 4.6315     | 5.43     % (*)


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

10     | 4.4366     | 5.00     % | 4.6008     | 3.74     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

11     | 4.3320     | 6.32     % | 4.6042     | 5.26     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

12     | 4.2674     | 7.25     % | 4.5950     | 5.26     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

13     | 4.2164     | 7.85     % | 4.5049     | 5.60     % (*)


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

14     | 4.1394     | 8.21     % | 4.4888     | 6.11     % (*)


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

15     | 4.0328     | 9.20     % | 4.4520     | 5.94     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

16     | 3.9577     | 10.19    % | 4.3858     | 7.13     % (*)


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

17     | 3.8553     | 11.72    % | 4.4052     | 5.60     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

18     | 3.7696     | 12.35    % | 4.3408     | 8.15     % (*)


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

19     | 3.6757     | 13.84    % | 4.3436     | 9.00     % (*)


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

20     | 3.5587     | 15.43    % | 4.3301     | 7.47     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

21     | 3.4670     | 17.35    % | 4.3846     | 7.13     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

22     | 3.3847     | 18.34    % | 4.3811     | 8.32     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

23     | 3.3065     | 19.96    % | 4.3969     | 7.64     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

24     | 3.2684     | 20.47    % | 4.4003     | 7.64     % 


Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

25     | 3.2458     | 20.98    % | 4.4075     | 7.30     % 

Training Complete. Best Validation Accuracy: 9.00%
