In [1]:
#!pip install timm
#!pip install ipywidgets

# new model - coatnet
hybrid architecture - Convolution + Attention.

In [2]:
import torch
import numpy as np
import evaluate
from datasets import load_from_disk
import timm
from transformers import (
    TrainingArguments, 
    Trainer
)
from torchvision.transforms import (
    Compose, Resize, CenterCrop, ToTensor, Normalize, 
    RandomHorizontalFlip, RandomResizedCrop
)

DATA_PATH = "processed_bird_data"
OUTPUT_DIR = "new_model_checkpoints"

# Using 'coatnet_0_rw_224' model, with '0' to use the smallest version to avoid overfitting:
MODEL_NAME = "coatnet_0_rw_224"

Applying Aggressive Data Augmentation to prevent overfitting:

1. Does random resize/zoom (scale 0.8â€“1.0). Forces the model to recognize a bird by looking at its specific feature (e.g. the head, wing, etc.) instead of the background (e.g. tress).
2. Incorporates horizontal flipping to double the training data.

In [3]:
print("Loading dataset.")
dataset = load_from_disk(DATA_PATH)

# transformations:
# normalization of values for ImageNet:
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# training: random crops + flips (noise helps it learn):
_train_transforms = Compose([
    RandomResizedCrop(224, scale=(0.8, 1.0)), 
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# validation - deterministic center crop
_val_transforms = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    normalize,
])

def train_transforms(batch):
    batch["pixel_values"] = [_train_transforms(image.convert("RGB")) for image in batch["image"]]
    return batch

def val_transforms(batch):
    batch["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in batch["image"]]
    return batch

# using .map() instead of .set_transform() because set_transform keeps the original columns accessible, and we want to delete those raw pngs to avoid program crashing.
print("Applying transforms and removing raw images...")

dataset["train"] = dataset["train"].map(
    train_transforms, 
    batched=True, 
    remove_columns=["image"]
)

dataset["validation"] = dataset["validation"].map(
    val_transforms, 
    batched=True, 
    remove_columns=["image"]
)

print("Data ready.")

Loading dataset.
Applying transforms and removing raw images...
Data ready.


# Training
Initializing CoAtNet model with random weights, no pretraining. 

It uses standard Convolutional layers in the early stages to extract low-level features (edges, textures), and then uses Transformer layers in the final stages to understand the global shape of the bird. Also:

1. High learning rate;
2. 15 epochs;
3. Uses regularization to prevent the model from memorizing the training images.

(using TimmTrainer class as the timm library and Hugging Face Trainer have different variable names (pixel_values vs x)).

In [4]:
import sys

# using the 'timm' library for CoAtNet: update: maybe not
model = timm.create_model(
    MODEL_NAME, 
    pretrained=False, 
    num_classes=200 
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

class TimmTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epoch_train_loss = 0.0
        self.epoch_train_acc = 0.0
        self.epoch_steps = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")
        
        logits = model(pixel_values)
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        
        if model.training:
            with torch.no_grad():
                preds = torch.argmax(logits, dim=-1)
                acc = (preds == labels).float().mean().item()
                
                self.epoch_train_loss += loss.item()
                self.epoch_train_acc += acc
                self.epoch_steps += 1
                
                if self.epoch_steps % 20 == 0:
                    current_epoch_float = self.state.epoch if self.state.epoch is not None else 0
                    print(f" >> Epoch: {current_epoch_float:.2f} | Batch: {self.epoch_steps} | Curr Loss: {loss.item():.4f}")
        
        return (loss, logits) if return_outputs else loss
    
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")
        with torch.no_grad():
            logits = model(pixel_values)
        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        if prediction_loss_only:
            return (loss, None, None)
        return (loss, logits, labels)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        metrics = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
        
        avg_train_loss = self.epoch_train_loss / self.epoch_steps if self.epoch_steps > 0 else 0
        avg_train_acc = self.epoch_train_acc / self.epoch_steps if self.epoch_steps > 0 else 0
        
        val_loss = metrics.get(f"{metric_key_prefix}_loss", 0.0)
        val_acc = metrics.get(f"{metric_key_prefix}_accuracy", 0.0)
        
        epoch_num = int(self.state.epoch) if self.state.epoch else 0
        
        print("\n" + "="*80)
        print(f" Training Loss:   {avg_train_loss:.4f} | Training Acc:   {avg_train_acc*100:.2f}%")
        print(f" Validation Loss: {val_loss:.4f}       | Validation Acc: {val_acc*100:.2f}%")
        print("="*80 + "\n")
        
        # resetting counters
        self.epoch_train_loss = 0.0
        self.epoch_train_acc = 0.0
        self.epoch_steps = 0
        
        return metrics

# accuracy metrics:
accuracy = evaluate.load("accuracy")

def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    return accuracy.compute(predictions=preds, references=p.label_ids)

# training 
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=32, 
    num_train_epochs=30,
    learning_rate=5e-4,
    weight_decay=0.05,
    warmup_ratio=0.1,
    
    disable_tqdm=True, # for now bc i dont like it - Julia
    
    logging_strategy="epoch", 
    save_strategy="epoch",
    eval_strategy="epoch",
    
    dataloader_num_workers=0,
    remove_unused_columns=False,
    report_to="none"
)

# the actual training
trainer = TimmTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

trainer.train()

trainer.save_model("final_new_model")
print("Best model saved.")



 >> Epoch: 0.18 | Batch: 20 | Curr Loss: 5.2920
 >> Epoch: 0.37 | Batch: 40 | Curr Loss: 5.2192
 >> Epoch: 0.56 | Batch: 60 | Curr Loss: 5.2060
 >> Epoch: 0.75 | Batch: 80 | Curr Loss: 4.8901
 >> Epoch: 0.94 | Batch: 100 | Curr Loss: 4.9773
{'loss': 5.2121, 'grad_norm': 10.421049118041992, 'learning_rate': 0.0001650793650793651, 'epoch': 1.0}
{'eval_loss': 5.13273286819458, 'eval_accuracy': 0.015280135823429542, 'eval_runtime': 35.7113, 'eval_samples_per_second': 16.493, 'eval_steps_per_second': 2.072, 'epoch': 1.0}

 Training Loss:   5.2121 | Training Acc:   2.23%
 Validation Loss: 5.1327       | Validation Acc: 1.53%





 >> Epoch: 1.18 | Batch: 20 | Curr Loss: 4.7737
 >> Epoch: 1.37 | Batch: 40 | Curr Loss: 4.9461
 >> Epoch: 1.56 | Batch: 60 | Curr Loss: 4.8216
 >> Epoch: 1.75 | Batch: 80 | Curr Loss: 4.9474
 >> Epoch: 1.94 | Batch: 100 | Curr Loss: 4.9641
{'loss': 4.9344, 'grad_norm': 9.2814302444458, 'learning_rate': 0.00033174603174603175, 'epoch': 2.0}
{'eval_loss': 4.870268821716309, 'eval_accuracy': 0.035653650254668934, 'eval_runtime': 34.3857, 'eval_samples_per_second': 17.129, 'eval_steps_per_second': 2.152, 'epoch': 2.0}

 Training Loss:   4.9344 | Training Acc:   2.56%
 Validation Loss: 4.8703       | Validation Acc: 3.57%





 >> Epoch: 2.18 | Batch: 20 | Curr Loss: 4.5916
 >> Epoch: 2.37 | Batch: 40 | Curr Loss: 4.4413
 >> Epoch: 2.56 | Batch: 60 | Curr Loss: 4.8027
 >> Epoch: 2.75 | Batch: 80 | Curr Loss: 4.3070
 >> Epoch: 2.94 | Batch: 100 | Curr Loss: 4.5052
{'loss': 4.6218, 'grad_norm': 9.998618125915527, 'learning_rate': 0.0004984126984126984, 'epoch': 3.0}
{'eval_loss': 4.662646770477295, 'eval_accuracy': 0.04584040747028863, 'eval_runtime': 34.5556, 'eval_samples_per_second': 17.045, 'eval_steps_per_second': 2.141, 'epoch': 3.0}

 Training Loss:   4.6218 | Training Acc:   5.60%
 Validation Loss: 4.6626       | Validation Acc: 4.58%





 >> Epoch: 3.18 | Batch: 20 | Curr Loss: 4.0298
 >> Epoch: 3.37 | Batch: 40 | Curr Loss: 4.0914
 >> Epoch: 3.56 | Batch: 60 | Curr Loss: 3.9212
 >> Epoch: 3.75 | Batch: 80 | Curr Loss: 4.4908
 >> Epoch: 3.94 | Batch: 100 | Curr Loss: 3.8206
{'loss': 4.1612, 'grad_norm': 9.871589660644531, 'learning_rate': 0.000481657848324515, 'epoch': 4.0}
{'eval_loss': 4.229859352111816, 'eval_accuracy': 0.08488964346349745, 'eval_runtime': 33.6849, 'eval_samples_per_second': 17.486, 'eval_steps_per_second': 2.197, 'epoch': 4.0}

 Training Loss:   4.1612 | Training Acc:   9.60%
 Validation Loss: 4.2299       | Validation Acc: 8.49%





 >> Epoch: 4.18 | Batch: 20 | Curr Loss: 3.6918
 >> Epoch: 4.37 | Batch: 40 | Curr Loss: 3.4949
 >> Epoch: 4.56 | Batch: 60 | Curr Loss: 3.7689
 >> Epoch: 4.75 | Batch: 80 | Curr Loss: 3.7461
 >> Epoch: 4.94 | Batch: 100 | Curr Loss: 3.6884
{'loss': 3.7748, 'grad_norm': 12.703848838806152, 'learning_rate': 0.0004631393298059965, 'epoch': 5.0}
{'eval_loss': 4.140162467956543, 'eval_accuracy': 0.09168081494057725, 'eval_runtime': 33.5507, 'eval_samples_per_second': 17.556, 'eval_steps_per_second': 2.206, 'epoch': 5.0}

 Training Loss:   3.7748 | Training Acc:   13.78%
 Validation Loss: 4.1402       | Validation Acc: 9.17%





 >> Epoch: 5.18 | Batch: 20 | Curr Loss: 3.3414
 >> Epoch: 5.37 | Batch: 40 | Curr Loss: 3.5651
 >> Epoch: 5.56 | Batch: 60 | Curr Loss: 3.6964
 >> Epoch: 5.75 | Batch: 80 | Curr Loss: 3.9118
 >> Epoch: 5.94 | Batch: 100 | Curr Loss: 3.2535
{'loss': 3.4183, 'grad_norm': 12.894719123840332, 'learning_rate': 0.00044462081128747796, 'epoch': 6.0}
{'eval_loss': 3.8540635108947754, 'eval_accuracy': 0.14261460101867574, 'eval_runtime': 33.7321, 'eval_samples_per_second': 17.461, 'eval_steps_per_second': 2.194, 'epoch': 6.0}

 Training Loss:   3.4183 | Training Acc:   19.82%
 Validation Loss: 3.8541       | Validation Acc: 14.26%





 >> Epoch: 6.18 | Batch: 20 | Curr Loss: 2.8811
 >> Epoch: 6.37 | Batch: 40 | Curr Loss: 3.0321
 >> Epoch: 6.56 | Batch: 60 | Curr Loss: 2.7750
 >> Epoch: 6.75 | Batch: 80 | Curr Loss: 3.1262
 >> Epoch: 6.94 | Batch: 100 | Curr Loss: 2.4794
{'loss': 3.0809, 'grad_norm': 15.745542526245117, 'learning_rate': 0.00042610229276895945, 'epoch': 7.0}
{'eval_loss': 3.767374277114868, 'eval_accuracy': 0.15110356536502548, 'eval_runtime': 33.8883, 'eval_samples_per_second': 17.381, 'eval_steps_per_second': 2.184, 'epoch': 7.0}

 Training Loss:   3.0809 | Training Acc:   24.85%
 Validation Loss: 3.7674       | Validation Acc: 15.11%





 >> Epoch: 7.18 | Batch: 20 | Curr Loss: 2.3737
 >> Epoch: 7.37 | Batch: 40 | Curr Loss: 3.2167
 >> Epoch: 7.56 | Batch: 60 | Curr Loss: 3.0045
 >> Epoch: 7.75 | Batch: 80 | Curr Loss: 2.4818
 >> Epoch: 7.94 | Batch: 100 | Curr Loss: 2.5648
{'loss': 2.7353, 'grad_norm': 15.025665283203125, 'learning_rate': 0.00040758377425044093, 'epoch': 8.0}
{'eval_loss': 3.752377510070801, 'eval_accuracy': 0.17147707979626486, 'eval_runtime': 34.2292, 'eval_samples_per_second': 17.208, 'eval_steps_per_second': 2.162, 'epoch': 8.0}

 Training Loss:   2.7353 | Training Acc:   30.90%
 Validation Loss: 3.7524       | Validation Acc: 17.15%





 >> Epoch: 8.18 | Batch: 20 | Curr Loss: 2.2275
 >> Epoch: 8.37 | Batch: 40 | Curr Loss: 2.2859
 >> Epoch: 8.56 | Batch: 60 | Curr Loss: 2.7036
 >> Epoch: 8.75 | Batch: 80 | Curr Loss: 2.3940
 >> Epoch: 8.94 | Batch: 100 | Curr Loss: 2.4389
{'loss': 2.3561, 'grad_norm': 19.126829147338867, 'learning_rate': 0.0003890652557319224, 'epoch': 9.0}
{'eval_loss': 3.7386717796325684, 'eval_accuracy': 0.1833616298811545, 'eval_runtime': 34.0099, 'eval_samples_per_second': 17.318, 'eval_steps_per_second': 2.176, 'epoch': 9.0}

 Training Loss:   2.3561 | Training Acc:   38.10%
 Validation Loss: 3.7387       | Validation Acc: 18.34%





 >> Epoch: 9.18 | Batch: 20 | Curr Loss: 2.2218
 >> Epoch: 9.37 | Batch: 40 | Curr Loss: 1.6501
 >> Epoch: 9.56 | Batch: 60 | Curr Loss: 2.2966
 >> Epoch: 9.75 | Batch: 80 | Curr Loss: 2.0740
 >> Epoch: 9.94 | Batch: 100 | Curr Loss: 1.9161
{'loss': 2.0143, 'grad_norm': 21.23949432373047, 'learning_rate': 0.00037054673721340385, 'epoch': 10.0}
{'eval_loss': 3.689183235168457, 'eval_accuracy': 0.20203735144312393, 'eval_runtime': 34.2984, 'eval_samples_per_second': 17.173, 'eval_steps_per_second': 2.158, 'epoch': 10.0}

 Training Loss:   2.0143 | Training Acc:   46.40%
 Validation Loss: 3.6892       | Validation Acc: 20.20%





 >> Epoch: 10.18 | Batch: 20 | Curr Loss: 1.6331
 >> Epoch: 10.37 | Batch: 40 | Curr Loss: 1.5368
 >> Epoch: 10.56 | Batch: 60 | Curr Loss: 1.3764
 >> Epoch: 10.75 | Batch: 80 | Curr Loss: 1.7003
 >> Epoch: 10.94 | Batch: 100 | Curr Loss: 1.2809
{'loss': 1.6257, 'grad_norm': 14.597826957702637, 'learning_rate': 0.0003520282186948854, 'epoch': 11.0}
{'eval_loss': 3.713890790939331, 'eval_accuracy': 0.20882852292020374, 'eval_runtime': 35.9622, 'eval_samples_per_second': 16.378, 'eval_steps_per_second': 2.058, 'epoch': 11.0}

 Training Loss:   1.6257 | Training Acc:   56.59%
 Validation Loss: 3.7139       | Validation Acc: 20.88%





 >> Epoch: 11.18 | Batch: 20 | Curr Loss: 0.8422
 >> Epoch: 11.37 | Batch: 40 | Curr Loss: 1.3648
 >> Epoch: 11.56 | Batch: 60 | Curr Loss: 1.3099
 >> Epoch: 11.75 | Batch: 80 | Curr Loss: 1.3702
 >> Epoch: 11.94 | Batch: 100 | Curr Loss: 1.0156
{'loss': 1.2597, 'grad_norm': 13.238887786865234, 'learning_rate': 0.0003335097001763668, 'epoch': 12.0}
{'eval_loss': 3.6844232082366943, 'eval_accuracy': 0.23259762308998302, 'eval_runtime': 34.0985, 'eval_samples_per_second': 17.274, 'eval_steps_per_second': 2.17, 'epoch': 12.0}

 Training Loss:   1.2597 | Training Acc:   66.14%
 Validation Loss: 3.6844       | Validation Acc: 23.26%





 >> Epoch: 12.18 | Batch: 20 | Curr Loss: 0.7881
 >> Epoch: 12.37 | Batch: 40 | Curr Loss: 0.7608
 >> Epoch: 12.56 | Batch: 60 | Curr Loss: 0.8122
 >> Epoch: 12.75 | Batch: 80 | Curr Loss: 0.8678
 >> Epoch: 12.94 | Batch: 100 | Curr Loss: 0.8266
{'loss': 0.9233, 'grad_norm': 26.633113861083984, 'learning_rate': 0.00031499118165784836, 'epoch': 13.0}
{'eval_loss': 3.8944525718688965, 'eval_accuracy': 0.23089983022071306, 'eval_runtime': 33.9419, 'eval_samples_per_second': 17.353, 'eval_steps_per_second': 2.18, 'epoch': 13.0}

 Training Loss:   0.9233 | Training Acc:   75.72%
 Validation Loss: 3.8945       | Validation Acc: 23.09%





 >> Epoch: 13.18 | Batch: 20 | Curr Loss: 0.4332
 >> Epoch: 13.37 | Batch: 40 | Curr Loss: 0.4664
 >> Epoch: 13.56 | Batch: 60 | Curr Loss: 0.8525
 >> Epoch: 13.75 | Batch: 80 | Curr Loss: 0.9066
 >> Epoch: 13.94 | Batch: 100 | Curr Loss: 0.8918
{'loss': 0.6604, 'grad_norm': 19.621997833251953, 'learning_rate': 0.0002964726631393298, 'epoch': 14.0}
{'eval_loss': 3.9498450756073, 'eval_accuracy': 0.23769100169779286, 'eval_runtime': 33.9914, 'eval_samples_per_second': 17.328, 'eval_steps_per_second': 2.177, 'epoch': 14.0}

 Training Loss:   0.6604 | Training Acc:   83.49%
 Validation Loss: 3.9498       | Validation Acc: 23.77%





 >> Epoch: 14.18 | Batch: 20 | Curr Loss: 0.4451
 >> Epoch: 14.37 | Batch: 40 | Curr Loss: 0.4983
 >> Epoch: 14.56 | Batch: 60 | Curr Loss: 0.4781
 >> Epoch: 14.75 | Batch: 80 | Curr Loss: 0.5167
 >> Epoch: 14.94 | Batch: 100 | Curr Loss: 0.4029
{'loss': 0.4541, 'grad_norm': 22.221426010131836, 'learning_rate': 0.00027795414462081133, 'epoch': 15.0}
{'eval_loss': 4.094743728637695, 'eval_accuracy': 0.22241086587436332, 'eval_runtime': 34.1712, 'eval_samples_per_second': 17.237, 'eval_steps_per_second': 2.166, 'epoch': 15.0}

 Training Loss:   0.4541 | Training Acc:   89.97%
 Validation Loss: 4.0947       | Validation Acc: 22.24%





 >> Epoch: 15.18 | Batch: 20 | Curr Loss: 0.2231
 >> Epoch: 15.37 | Batch: 40 | Curr Loss: 0.2235
 >> Epoch: 15.56 | Batch: 60 | Curr Loss: 0.3530
 >> Epoch: 15.75 | Batch: 80 | Curr Loss: 0.4397
 >> Epoch: 15.94 | Batch: 100 | Curr Loss: 0.2933
{'loss': 0.3313, 'grad_norm': 38.03312301635742, 'learning_rate': 0.00025943562610229276, 'epoch': 16.0}
{'eval_loss': 4.138522624969482, 'eval_accuracy': 0.22580645161290322, 'eval_runtime': 34.2763, 'eval_samples_per_second': 17.184, 'eval_steps_per_second': 2.159, 'epoch': 16.0}

 Training Loss:   0.3313 | Training Acc:   92.71%
 Validation Loss: 4.1385       | Validation Acc: 22.58%





 >> Epoch: 16.18 | Batch: 20 | Curr Loss: 0.2428
 >> Epoch: 16.37 | Batch: 40 | Curr Loss: 0.1354
 >> Epoch: 16.56 | Batch: 60 | Curr Loss: 0.3285
 >> Epoch: 16.75 | Batch: 80 | Curr Loss: 0.3511
 >> Epoch: 16.94 | Batch: 100 | Curr Loss: 0.4046
{'loss': 0.2256, 'grad_norm': 11.726457595825195, 'learning_rate': 0.00024091710758377424, 'epoch': 17.0}
{'eval_loss': 4.2865800857543945, 'eval_accuracy': 0.21731748726655348, 'eval_runtime': 34.649, 'eval_samples_per_second': 16.999, 'eval_steps_per_second': 2.136, 'epoch': 17.0}

 Training Loss:   0.2256 | Training Acc:   95.70%
 Validation Loss: 4.2866       | Validation Acc: 21.73%





 >> Epoch: 17.18 | Batch: 20 | Curr Loss: 0.2826


KeyboardInterrupt: 

In [None]:
# test