In [3]:
!pip install timm
!pip install ipywidgets



# New model: CoAtNet
A hybrid architecture that combines convolution and attention.

In [4]:
import torch
import numpy as np
import evaluate
from datasets import load_from_disk
import timm
from transformers import (
    TrainingArguments, 
    Trainer
)
from torchvision.transforms import (
    Compose, Resize, CenterCrop, ToTensor, Normalize, 
    RandomHorizontalFlip, RandomResizedCrop
)

DATA_PATH = "processed_bird_data"
OUTPUT_DIR = "new_model_checkpoints"

# Using 'coatnet_0_rw_224': the '0' variant is the smallest CoAtNet, chosen to
# limit overfitting.
MODEL_NAME = "coatnet_0_rw_224"

# Seed torch and numpy up front so the random model initialisation (and any
# other torch/numpy randomness in later cells) is reproducible on a fresh
# "Restart & Run All". 42 matches the default TrainingArguments.seed.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

Applying aggressive data augmentation to reduce overfitting:

1. Random resized crop/zoom (scale 0.8–1.0). This encourages the model to recognize a bird from its distinctive features (e.g. the head or wings) rather than from the background (e.g. trees).
2. Random horizontal flipping, which effectively doubles the variety of training images.

In [5]:
print("Loading dataset.")
dataset = load_from_disk(DATA_PATH)

# Per-channel ImageNet mean/std, the conventional normalisation for
# torchvision/timm image models.
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: random crop covering 80-100% of the image (resized to
# 224x224) plus a random horizontal flip. These are random, so they only act
# as augmentation if they are re-applied every time an example is read.
_train_transforms = Compose([
    RandomResizedCrop(224, scale=(0.8, 1.0)),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Validation pipeline: deterministic resize + center crop, so every
# evaluation sees exactly the same pixels.
_val_transforms = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    normalize,
])


def _transform_batch(batch, transform):
    """Build a model-ready batch from a raw `datasets` batch.

    Args:
        batch: dict of column name -> list of values; must contain an
            "image" column of PIL images.
        transform: callable mapping an RGB PIL image to a tensor.

    Returns:
        A new dict holding every column except "image" unchanged (e.g. the
        label) plus "pixel_values", the list of transformed tensors. The raw
        "image" column is dropped so the data collator never receives PIL
        images (training runs with remove_unused_columns=False).
    """
    output = {column: values for column, values in batch.items() if column != "image"}
    output["pixel_values"] = [transform(image.convert("RGB")) for image in batch["image"]]
    return output


def train_transforms(batch):
    """Apply the random training augmentation to a batch."""
    return _transform_batch(batch, _train_transforms)


def val_transforms(batch):
    """Apply the deterministic validation preprocessing to a batch."""
    return _transform_batch(batch, _val_transforms)


# Transforms are attached lazily with .with_transform() instead of being baked
# in with .map(): .map() runs the random augmentation only once and caches the
# result, so every epoch would see the identical crop/flip (and it stores
# ~150k floats per image on disk). The transform functions drop the raw
# "image" column from their output, so raw images never reach the collator.
print("Applying transforms and removing raw images...")
dataset["train"] = dataset["train"].with_transform(train_transforms)
dataset["validation"] = dataset["validation"].with_transform(val_transforms)

# Sanity check: examples now come out as 3x224x224 tensors and the raw image
# column is gone.
_sample = dataset["validation"][0]
assert "image" not in _sample
assert tuple(_sample["pixel_values"].shape) == (3, 224, 224)

print("Data ready.")

Loading dataset.
Applying transforms and removing raw images...
Data ready.


# Training
Initializing the CoAtNet model with random weights (no pretraining).

It uses standard convolutional layers in the early stages to extract low-level features (edges, textures), and then Transformer layers in the final stages to capture the global shape of the bird. Also:

1. A relatively high learning rate (5e-4) with a 10% warmup;
2. 15 epochs;
3. Regularization (weight decay of 0.05) to discourage the model from memorizing the training images.

(We use a custom `TimmTrainer` class because timm models take the image tensor as a positional argument `x`, whereas the batches given to the Hugging Face Trainer carry it under the `pixel_values` key.)

In [None]:
import sys

# using the 'timm' library for CoAtNet: update: maybe not
model = timm.create_model(
    MODEL_NAME, 
    pretrained=False, 
    num_classes=200 
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

class TimmTrainer(Trainer):
    """`Trainer` subclass that drives a timm image classifier.

    The batches built for this notebook carry the images under the
    "pixel_values" key, but the timm model is called with the image tensor as
    a plain positional argument and returns raw logits. This subclass bridges
    the two: it calls `model(pixel_values)`, computes cross-entropy itself,
    keeps running training loss/accuracy, and prints a train-vs-validation
    summary every time evaluation runs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Running sums over the training batches seen since the last evaluate().
        self.epoch_train_loss = 0.0  # sum of per-batch mean losses
        self.epoch_train_acc = 0.0   # sum of per-batch accuracies (fractions)
        self.epoch_steps = 0         # number of training batches accumulated

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the cross-entropy loss for one batch (plus logits if requested).

        Args:
            model: the timm model; `model(pixel_values)` returns class logits.
            inputs: batch dict containing "pixel_values" and "labels".
            return_outputs: if True, return `(loss, logits)` instead of `loss`.
            num_items_in_batch: accepted to match the `Trainer.compute_loss`
                signature; unused because the loss is a plain per-batch mean.
        """
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")

        logits = model(pixel_values)
        loss = torch.nn.functional.cross_entropy(logits, labels)

        if model.training:
            # Bookkeeping only -- keep it out of the autograd graph.
            with torch.no_grad():
                batch_loss = loss.item()
                batch_acc = (logits.argmax(dim=-1) == labels).float().mean().item()

                self.epoch_train_loss += batch_loss
                self.epoch_train_acc += batch_acc
                self.epoch_steps += 1

                if self.epoch_steps % 20 == 0:
                    current_epoch_float = self.state.epoch if self.state.epoch is not None else 0
                    print(f" >> Epoch: {current_epoch_float:.2f} | Batch: {self.epoch_steps} | Curr Loss: {batch_loss:.4f}")

        return (loss, logits) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run one evaluation batch and return `(loss, logits, labels)`.

        When `prediction_loss_only` is True, logits and labels are returned as
        None. `ignore_keys` is accepted for signature compatibility and unused
        (the model returns only logits).
        """
        # As in the base implementation, move the batch onto the model's
        # device before the forward pass.
        inputs = self._prepare_inputs(inputs)
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")

        with torch.no_grad():
            logits = model(pixel_values)
            loss = None
            if labels is not None:
                loss = torch.nn.functional.cross_entropy(logits, labels)

        if prediction_loss_only:
            return (loss, None, None)
        return (loss, logits, labels)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Run the standard evaluation, then print a train/validation summary.

        Training averages cover the batches seen since the previous call, and
        the running counters are reset afterwards. Returns the metrics dict
        from `Trainer.evaluate` unchanged.
        """
        metrics = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)

        steps = self.epoch_steps
        avg_train_loss = self.epoch_train_loss / steps if steps > 0 else 0
        avg_train_acc = self.epoch_train_acc / steps if steps > 0 else 0

        # compute_metrics returns {"accuracy": ...}; the keys arrive prefixed.
        val_loss = metrics.get(f"{metric_key_prefix}_loss", 0.0)
        val_acc = metrics.get(f"{metric_key_prefix}_accuracy", 0.0)

        # round() guards against float epochs such as 2.9999 at a boundary.
        epoch_num = round(self.state.epoch) if self.state.epoch else 0

        print("\n" + "="*80)
        print(f" Epoch {epoch_num}")
        print(f" Training Loss:   {avg_train_loss:.4f} | Training Acc:   {avg_train_acc*100:.2f}%")
        print(f" Validation Loss: {val_loss:.4f}       | Validation Acc: {val_acc*100:.2f}%")
        print("="*80 + "\n")

        # Start a fresh accumulation window for the next epoch.
        self.epoch_train_loss = 0.0
        self.epoch_train_acc = 0.0
        self.epoch_steps = 0

        return metrics

# Top-1 accuracy metric loaded from the Hugging Face `evaluate` library.
accuracy = evaluate.load("accuracy")

def compute_metrics(p):
    """Turn the Trainer's evaluation output into an accuracy dict.

    `p.predictions` holds the per-class logits; the predicted class is the
    arg-max over axis 1, compared against `p.label_ids`.
    """
    logits, references = p.predictions, p.label_ids
    predicted_classes = np.argmax(logits, axis=1)
    return accuracy.compute(predictions=predicted_classes, references=references)

# Training configuration. The best checkpoint (by validation accuracy) is
# reloaded at the end of training, so the model saved afterwards is the best
# epoch rather than simply the last one.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,  # eval runs without gradients; larger than the default of 8 for speed
    num_train_epochs=15,
    learning_rate=5e-4,
    weight_decay=0.05,  # regularization
    warmup_ratio=0.1,   # warm up the learning rate over the first 10% of steps

    disable_tqdm=True,  # rely on TimmTrainer's own per-epoch prints instead of progress bars

    logging_strategy="epoch",
    save_strategy="epoch",
    eval_strategy="epoch",  # must match save_strategy for load_best_model_at_end

    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # compute_metrics' "accuracy" (logged as eval_accuracy)
    greater_is_better=True,

    dataloader_num_workers=0,
    remove_unused_columns=False,  # keep pixel_values/labels; timm's forward(x) signature doesn't name them
    report_to="none"
)

# Build the trainer and run training.
trainer = TimmTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

trainer.train()

# With load_best_model_at_end=True the trainer has reloaded the best
# checkpoint before this point, so this saves the best epoch; otherwise it
# saves the weights from the last epoch. Report which one it is.
trainer.save_model("final_new_model")
saved_kind = "Best" if training_args.load_best_model_at_end else "Final"
print(f"{saved_kind} model saved.")



In [None]:
# test