In [7]:
#!pip install timm
#!pip install ipywidgets

# New model version
hybrid architecture - Convolution + Attention.

In [None]:
import torch
import numpy as np
import evaluate
from datasets import load_from_disk
import timm
from transformers import (
    TrainingArguments, 
    Trainer
)
from torchvision.transforms import (
    Compose, Resize, CenterCrop, ToTensor, Normalize, 
    RandomHorizontalFlip, RandomResizedCrop
)

DATA_PATH = "processed_bird_data"
OUTPUT_DIR = "new_model_checkpoints"

# Using 'coatnet_0_rw_224' model, with '0' to use the smallest version to avoid overfitting:
MODEL_NAME = "coatnet_0_rw_224"

Applying Aggressive Data Augmentation to prevent overfitting:

1. Does random resize/zoom (scale 0.8â€“1.0). Forces the model to recognize a bird by looking at its specific feature (e.g. the head, wing, etc.) instead of the background (e.g. tress).
2. Incorporates horizontal flipping to double the training data.

In [9]:
print("Loading dataset.")
dataset = load_from_disk(DATA_PATH)

# Transformers:
# Normalization values for ImageNet:
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training: Random crops + Flips (noise helps it learn):
_train_transforms = Compose([
    RandomResizedCrop(224, scale=(0.8, 1.0)), 
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Validation - deterministic center crop
_val_transforms = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    normalize,
])

def train_transforms(batch):
    batch["pixel_values"] = [_train_transforms(image.convert("RGB")) for image in batch["image"]]
    return batch

def val_transforms(batch):
    batch["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in batch["image"]]
    return batch

# Use .map() instead of .set_transform() because set_transform keeps the original columns accessible, and we want to delete those raw pngs to avoid program crashing.
print("Applying transforms and removing raw images...")

dataset["train"] = dataset["train"].map(
    train_transforms, 
    batched=True, 
    remove_columns=["image"]
)

dataset["validation"] = dataset["validation"].map(
    val_transforms, 
    batched=True, 
    remove_columns=["image"]
)

print("Data ready.")

Loading dataset.
Applying transforms and removing raw images...
Data ready.


# Training
Initializing CoAtNet model with random weights, no pretraining. 

It uses standard Convolutional layers in the early stages to extract low-level features (edges, textures), and then uses Transformer layers in the final stages to understand the global shape of the bird. Also:

1. High learning rate;
2. 15 epochs;
3. Uses regularization to prevent the model from memorizing the training images.

(using TimmTrainer class as the timm library and Hugging Face Trainer have different variable names (pixel_values vs x)).

In [None]:
# Use 'timm' library for CoAtNet:
model = timm.create_model(
    MODEL_NAME, 
    pretrained=False, 
    num_classes=200 
)

device = "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)

# Making TIMM compatible with HuggingFace:
class TimmTrainer(Trainer):
    
    # Training Step (calculate Gradient Loss)
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")
        
        logits = model(pixel_values)
        
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        
        return (loss, logits) if return_outputs else loss
    
    # Validation/Prediction Step (getting the "unexpected keyword" error)
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")
        
        with torch.no_grad():
            # Only pass images to the model, not labels
            logits = model(pixel_values)
            
        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            
        if prediction_loss_only:
            return (loss, None, None)
            
        return (loss, logits, labels)

# Accuracy metrics:
accuracy = evaluate.load("accuracy")

def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    result = accuracy.compute(predictions=preds, references=p.label_ids)
    
    # THIS is where the print statement belongs!
    print(f"\n Epoch down! Validation Accuracy: {result['accuracy']*100:.2f}%")
    return result

# Training (taking batch size 32)
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=32,
    num_train_epochs=15,
    learning_rate=5e-4,
    weight_decay=0.05,
    warmup_ratio=0.1,
    
    # Checkpoints and logging:
    logging_strategy="steps",
    logging_steps=10,
    save_strategy="epoch",
    eval_strategy="epoch",
    
    dataloader_num_workers=0,
    remove_unused_columns=False,
    report_to="none"
)

# Starting the actual training:
trainer = TimmTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

print("Starting the training.")
trainer.train()

trainer.save_model("final_new_model")
print("Saved. Close me.")

Starting the training.




Epoch,Training Loss,Validation Loss,Accuracy
1,5.053,5.123315,0.01528
2,4.7468,4.753685,0.032258
3,4.4123,4.5004,0.052632
4,3.8033,4.211554,0.08489
5,3.7027,4.044211,0.108659
6,3.3558,3.809991,0.135823
7,2.9299,3.822908,0.154499
8,2.5442,3.825937,0.149406
9,2.2277,3.633424,0.174873
10,1.951,3.776318,0.186757



 Epoch down! Validation Accuracy: 1.53%





 Epoch down! Validation Accuracy: 3.23%





 Epoch down! Validation Accuracy: 5.26%





 Epoch down! Validation Accuracy: 8.49%





 Epoch down! Validation Accuracy: 10.87%





 Epoch down! Validation Accuracy: 13.58%





 Epoch down! Validation Accuracy: 15.45%





 Epoch down! Validation Accuracy: 14.94%





 Epoch down! Validation Accuracy: 17.49%





 Epoch down! Validation Accuracy: 18.68%





 Epoch down! Validation Accuracy: 20.03%





 Epoch down! Validation Accuracy: 24.62%





 Epoch down! Validation Accuracy: 23.09%





 Epoch down! Validation Accuracy: 22.92%





 Epoch down! Validation Accuracy: 22.92%
Saved. Close me.
