In [1]:
#!pip install timm
#!pip install ipywidgets
#!pip install optuna plotly
#!pip install kaleido
#!pip install matplotlib

# New model: CoAtNet
A hybrid architecture that combines convolution and attention layers.

In [3]:
import torch
import numpy as np
import pandas as pd
import evaluate
import timm
import shutil
import os
import safetensors.torch
from torch.utils.data import DataLoader
from datasets import load_from_disk
from transformers import TrainingArguments, Trainer, set_seed
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from torchvision.transforms import (
    Compose, Resize, CenterCrop, ToTensor, Normalize, 
    RandomHorizontalFlip, RandomResizedCrop, TrivialAugmentWide, InterpolationMode
)

In [None]:
# --- Experiment identity and schedule ---
SEED = 100
MODEL_NAME = "coatnet_0_rw_224"
DATA_PATH = "processed_bird_data"
OUTPUT_DIR = "coatnet_aug_experiment"
NUM_EPOCHS = 30
BATCH_SIZE = 32

# --- Optimiser hyperparameters (already tuned) ---
LEARNING_RATE = 5.9e-4
WEIGHT_DECAY = 0.065

# --- Batch-level augmentation settings passed to timm's Mixup ---
mixup_args = dict(
    mixup_alpha=0.8,       # strength of blending
    cutmix_alpha=1.0,      # strength of cutting/pasting
    prob=1.0,              # apply to 100% of batches
    switch_prob=0.5,       # 50% chance of Mixup vs 50% CutMix
    mode='batch',          # one set of mixing parameters for the whole batch
    label_smoothing=0.1,   # smoothing applied to the generated soft targets
    num_classes=200,
)

mixup_fn = Mixup(**mixup_args)

# Start from a clean slate: any previous run's checkpoints are removed.
if os.path.exists(OUTPUT_DIR):
    shutil.rmtree(OUTPUT_DIR)

set_seed(SEED)

# Prefer Apple MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device} | Seed: {SEED}")

Using device: mps | Seed: 100


Applying aggressive data augmentation to prevent overfitting:

1. Random resized cropping (scale 0.8–1.0). This forces the model to recognize a bird by its specific features (e.g. the head or wings) instead of the background (e.g. trees).
2. Random horizontal flipping, which effectively doubles the variety of the training data.

In [4]:
dataset = load_from_disk(DATA_PATH)

# Per-channel normalisation (the standard ImageNet mean/std) shared by both pipelines.
normalize = Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: random crop/flip/TrivialAugment on the PIL image, then tensor + normalise.
_train_transforms = Compose(
    transforms=[
        RandomResizedCrop(
            size=224,
            scale=(0.8, 1.0),
            interpolation=InterpolationMode.BICUBIC,
        ),
        RandomHorizontalFlip(),
        TrivialAugmentWide(),
        ToTensor(),
        normalize,
    ]
)

# Validation pipeline: deterministic resize + centre crop, no augmentation.
_val_transforms = Compose(
    transforms=[
        Resize(size=256),
        CenterCrop(size=224),
        ToTensor(),
        normalize,
    ]
)

def train_transforms(batch):
    """Turn a batch's PIL images into augmented training tensors.

    Each entry of ``batch["image"]`` is converted to RGB and passed through
    ``_train_transforms``; the results are stored under ``"pixel_values"`` and
    the original ``"image"`` column is removed. The same (mutated) batch
    mapping is returned.
    """
    augmented = []
    for image in batch["image"]:
        augmented.append(_train_transforms(image.convert("RGB")))
    batch["pixel_values"] = augmented
    # The raw images are no longer needed once tensors exist.
    batch.pop("image", None)
    return batch

def val_transforms(batch):
    """Turn a batch's PIL images into deterministic validation tensors.

    Mirrors ``train_transforms`` but uses ``_val_transforms`` (no random
    augmentation). Writes ``"pixel_values"``, drops ``"image"`` and returns the
    same batch mapping.
    """
    rgb_images = (image.convert("RGB") for image in batch["image"])
    batch["pixel_values"] = list(map(_val_transforms, rgb_images))
    # Drop the raw images so only tensors (and the remaining columns) are collated.
    batch.pop("image", None)
    return batch

# Attach the on-the-fly transforms: they run when rows are accessed, so the
# random training augmentation is re-sampled every time an image is read.
for _split, _transform in (("train", train_transforms), ("validation", val_transforms)):
    dataset[_split].set_transform(_transform)

# Custom trainer 
class SteroidTrainer(Trainer):
    """Trainer for a plain logits-returning classifier trained with Mixup/CutMix.

    The wrapped model is called as ``model(pixel_values)`` and must return a
    logits tensor. During training the module-level ``mixup_fn`` (if not None)
    mixes each batch and turns the labels into soft targets; evaluation always
    uses the unmodified inputs with a standard cross-entropy loss.

    Running per-epoch training statistics are kept in ``epoch_train_loss``,
    ``epoch_train_acc`` and ``epoch_steps``; they are printed and reset every
    time ``evaluate`` runs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The loss modules are stateless, so build them once instead of on every step.
        self._soft_loss_fct = SoftTargetCrossEntropy()
        self._hard_loss_fct = torch.nn.CrossEntropyLoss()
        self._reset_epoch_stats()

    def _reset_epoch_stats(self):
        """Zero the running training-loss/accuracy accumulators."""
        self.epoch_train_loss = 0.0
        self.epoch_train_acc = 0.0
        self.epoch_steps = 0

    def get_eval_dataloader(self, eval_dataset=None):
        """Build the evaluation dataloader without dropping the last partial batch.

        ``dataloader_drop_last=True`` is wanted for training (Mixup in batch
        mode needs an even batch size), but Trainer applies the same flag to
        evaluation loaders, which would silently exclude the trailing
        validation samples from every metric. The flag is switched off only
        while the eval loader is constructed and restored afterwards.
        """
        saved_drop_last = self.args.dataloader_drop_last
        self.args.dataloader_drop_last = False
        try:
            return super().get_eval_dataloader(eval_dataset)
        finally:
            self.args.dataloader_drop_last = saved_drop_last

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss, applying Mixup/CutMix when possible.

        Args:
            model: Module called as ``model(pixel_values)`` that returns logits.
            inputs: Mapping with ``"pixel_values"`` and integer ``"labels"``.
            return_outputs: When True, return ``(loss, logits)`` instead of the loss.
            num_items_in_batch: Accepted for Trainer API compatibility; unused.

        Returns:
            The loss tensor, or ``(loss, logits)`` if ``return_outputs`` is True.
        """
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")

        # Keep the integer labels: Mixup replaces `labels` with soft targets below.
        original_labels = labels.clone().detach()

        # timm's Mixup in 'batch' mode asserts an even batch size, so an odd-sized
        # batch (e.g. a trailing partial batch) falls back to the plain loss.
        mixup_is_active = (
            model.training
            and mixup_fn is not None
            and pixel_values.shape[0] % 2 == 0
        )

        # Mixup/CutMix Transformation (labels become soft targets)
        if mixup_is_active:
            pixel_values, labels = mixup_fn(pixel_values, labels)
            loss_fct = self._soft_loss_fct
        else:
            loss_fct = self._hard_loss_fct

        logits = model(pixel_values)
        loss = loss_fct(logits, labels)

        if model.training:
            with torch.no_grad():
                preds = torch.argmax(logits, dim=-1)

                # Accuracy is measured against the original hard labels even though the
                # images may be mixed, so it under-reports compared to clean inputs.
                current_acc = (preds == original_labels).float().mean().item()

                self.epoch_train_loss += loss.item()
                self.epoch_train_acc += current_acc
                self.epoch_steps += 1

                if self.epoch_steps % 20 == 0:
                    current_epoch_float = self.state.epoch if self.state.epoch is not None else 0
                    print(f" >> Epoch: {current_epoch_float:.2f} | Batch: {self.epoch_steps} | Curr Loss: {loss.item():.4f}")

        return (loss, logits) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run one evaluation step on un-augmented inputs.

        Args:
            model: Module called as ``model(pixel_values)`` that returns logits.
            inputs: Mapping with ``"pixel_values"`` and optionally ``"labels"``.
            prediction_loss_only: When True only the loss is returned.
            ignore_keys: Accepted for Trainer API compatibility; unused.

        Returns:
            ``(loss, logits, labels)``; ``loss`` is None when no labels are
            given, and logits/labels are None when ``prediction_loss_only``.
        """
        inputs = self._prepare_inputs(inputs)
        labels = inputs.pop("labels", None)

        with torch.no_grad():
            outputs = model(inputs["pixel_values"])
            loss = None
            # The loss must be available whenever labels exist, in particular when the
            # caller asked for the loss only.
            if labels is not None:
                loss = self._hard_loss_fct(outputs, labels).detach()

        if prediction_loss_only:
            return (loss, None, None)
        return (loss, outputs, labels)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Evaluate, print a train-vs-validation summary, and reset the train stats.

        Returns the metrics dict produced by ``Trainer.evaluate`` unchanged.
        """
        metrics = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)

        avg_train_loss = self.epoch_train_loss / self.epoch_steps if self.epoch_steps > 0 else 0
        avg_train_acc = self.epoch_train_acc / self.epoch_steps if self.epoch_steps > 0 else 0

        val_loss = metrics.get(f"{metric_key_prefix}_loss", 0.0)
        val_acc = metrics.get(f"{metric_key_prefix}_accuracy", 0.0)

        # state.epoch is None until training has started (e.g. a standalone evaluate()).
        current_epoch = self.state.epoch if self.state.epoch is not None else 0

        print("\n" + "="*80)
        print(f" EPOCH {current_epoch:.2f} RESULTS:")
        print(f" Training Loss:   {avg_train_loss:.4f} | Training Acc:   {avg_train_acc*100:.2f}% (Hard Labels)")
        print(f" Validation Loss: {val_loss:.4f}       | Validation Acc: {val_acc*100:.2f}% (Actual performance)")
        print("="*80 + "\n")

        # Reset counters for the next epoch
        self._reset_epoch_stats()
        return metrics

# Training
Initializing CoAtNet model with random weights, no pretraining. 

CoAtNet uses standard convolutional layers in the early stages to extract low-level features (edges, textures), and then uses Transformer layers in the final stages to capture the global shape of the bird. The learning rate and weight decay were chosen using Bayesian optimisation.

In [None]:
# Build the network with randomly initialised weights (pretrained=False) and
# move it to the selected device in one go.
print(f"Initializing {MODEL_NAME} from scratch.")
model = timm.create_model(
    model_name=MODEL_NAME,
    pretrained=False,
    num_classes=200,
).to(device)

# Accuracy metric object from the `evaluate` library.
accuracy = evaluate.load("accuracy")
def compute_metrics(p):
    """Compute accuracy from an evaluation prediction object.

    ``p.predictions`` holds the per-class scores and ``p.label_ids`` the
    reference labels; the predicted class is the arg-max along axis 1.
    Returns whatever ``accuracy.compute`` returns.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    scores = accuracy.compute(
        predictions=predicted_classes,
        references=p.label_ids,
    )
    return scores

# Collect the settings by theme first, then hand them to TrainingArguments.
_optimisation = dict(
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=0.1,
)
_data_loading = dict(
    # Mixup in 'batch' mode needs an even batch size, so partial batches are dropped.
    dataloader_drop_last=True,
    # Keep the raw "image" column: the dataset transforms read it themselves.
    remove_unused_columns=False,
)
_eval_and_logging = dict(
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    disable_tqdm=False,
    report_to="none",
)
_checkpointing = dict(
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    save_total_limit=1,
)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    seed=SEED,
    **_optimisation,
    **_data_loading,
    **_eval_and_logging,
    **_checkpointing,
)

trainer = SteroidTrainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)

print("Starting Training (Mixup + CutMix + TrivialAugment).")
trainer.train()

# With load_best_model_at_end=True the trainer holds the best checkpoint here.
final_path = f"{OUTPUT_DIR}/best_aug_model"
trainer.save_model(final_path)

Initializing coatnet_0_rw_224 from scratch.
Starting Training (Mixup + CutMix + TrivialAugment).




Epoch,Training Loss,Validation Loss,Accuracy
1,5.3275,5.228301,0.010417
2,5.2719,5.146622,0.019097
3,5.2215,5.099582,0.027778
4,5.1493,5.016776,0.015625
5,5.1234,4.921926,0.03125
6,5.0852,4.838113,0.022569
7,5.0608,4.768239,0.038194
8,4.9903,4.743011,0.041667
9,4.9218,4.586624,0.043403
10,4.9438,4.590762,0.055556


 >> Epoch: 0.18 | Batch: 20 | Curr Loss: 5.3284
 >> Epoch: 0.38 | Batch: 40 | Curr Loss: 5.1422
 >> Epoch: 0.57 | Batch: 60 | Curr Loss: 5.2223
 >> Epoch: 0.76 | Batch: 80 | Curr Loss: 5.2385
 >> Epoch: 0.95 | Batch: 100 | Curr Loss: 5.4777

 EPOCH 1.00 RESULTS:
 Training Loss:   5.3406 | Training Acc:   1.14% (Hard Labels)
 Validation Loss: 5.2283       | Validation Acc: 1.04% (Actual performance)





 >> Epoch: 1.18 | Batch: 20 | Curr Loss: 5.0941
 >> Epoch: 1.38 | Batch: 40 | Curr Loss: 5.3676
 >> Epoch: 1.57 | Batch: 60 | Curr Loss: 5.2051
 >> Epoch: 1.76 | Batch: 80 | Curr Loss: 5.2424
 >> Epoch: 1.95 | Batch: 100 | Curr Loss: 5.2321

 EPOCH 2.00 RESULTS:
 Training Loss:   5.2603 | Training Acc:   1.14% (Hard Labels)
 Validation Loss: 5.1466       | Validation Acc: 1.91% (Actual performance)





 >> Epoch: 2.18 | Batch: 20 | Curr Loss: 5.3741
 >> Epoch: 2.38 | Batch: 40 | Curr Loss: 5.1970
 >> Epoch: 2.57 | Batch: 60 | Curr Loss: 5.0385
 >> Epoch: 2.76 | Batch: 80 | Curr Loss: 5.0909
 >> Epoch: 2.95 | Batch: 100 | Curr Loss: 5.2380

 EPOCH 3.00 RESULTS:
 Training Loss:   5.2254 | Training Acc:   0.93% (Hard Labels)
 Validation Loss: 5.0996       | Validation Acc: 2.78% (Actual performance)





 >> Epoch: 3.18 | Batch: 20 | Curr Loss: 5.1811
 >> Epoch: 3.38 | Batch: 40 | Curr Loss: 5.2030
 >> Epoch: 3.57 | Batch: 60 | Curr Loss: 5.1232
 >> Epoch: 3.76 | Batch: 80 | Curr Loss: 5.0403
 >> Epoch: 3.95 | Batch: 100 | Curr Loss: 5.1120

 EPOCH 4.00 RESULTS:
 Training Loss:   5.1618 | Training Acc:   1.53% (Hard Labels)
 Validation Loss: 5.0168       | Validation Acc: 1.56% (Actual performance)





 >> Epoch: 4.18 | Batch: 20 | Curr Loss: 5.2615
 >> Epoch: 4.38 | Batch: 40 | Curr Loss: 4.9808
 >> Epoch: 4.57 | Batch: 60 | Curr Loss: 5.2089
 >> Epoch: 4.76 | Batch: 80 | Curr Loss: 5.1676
 >> Epoch: 4.95 | Batch: 100 | Curr Loss: 4.8812

 EPOCH 5.00 RESULTS:
 Training Loss:   5.1165 | Training Acc:   2.19% (Hard Labels)
 Validation Loss: 4.9219       | Validation Acc: 3.12% (Actual performance)





 >> Epoch: 5.18 | Batch: 20 | Curr Loss: 4.9225
 >> Epoch: 5.38 | Batch: 40 | Curr Loss: 4.9580
 >> Epoch: 5.57 | Batch: 60 | Curr Loss: 5.2457
 >> Epoch: 5.76 | Batch: 80 | Curr Loss: 5.1631
 >> Epoch: 5.95 | Batch: 100 | Curr Loss: 5.0968

 EPOCH 6.00 RESULTS:
 Training Loss:   5.0691 | Training Acc:   2.31% (Hard Labels)
 Validation Loss: 4.8381       | Validation Acc: 2.26% (Actual performance)





 >> Epoch: 6.18 | Batch: 20 | Curr Loss: 5.1279
 >> Epoch: 6.38 | Batch: 40 | Curr Loss: 5.0826
 >> Epoch: 6.57 | Batch: 60 | Curr Loss: 4.8957
 >> Epoch: 6.76 | Batch: 80 | Curr Loss: 4.9634
 >> Epoch: 6.95 | Batch: 100 | Curr Loss: 5.1164

 EPOCH 7.00 RESULTS:
 Training Loss:   5.0316 | Training Acc:   2.52% (Hard Labels)
 Validation Loss: 4.7682       | Validation Acc: 3.82% (Actual performance)





 >> Epoch: 7.18 | Batch: 20 | Curr Loss: 4.7849
 >> Epoch: 7.38 | Batch: 40 | Curr Loss: 4.9159
 >> Epoch: 7.57 | Batch: 60 | Curr Loss: 5.2700
 >> Epoch: 7.76 | Batch: 80 | Curr Loss: 4.8832
 >> Epoch: 7.95 | Batch: 100 | Curr Loss: 5.1187

 EPOCH 8.00 RESULTS:
 Training Loss:   4.9977 | Training Acc:   3.00% (Hard Labels)
 Validation Loss: 4.7430       | Validation Acc: 4.17% (Actual performance)





 >> Epoch: 8.18 | Batch: 20 | Curr Loss: 5.0009
 >> Epoch: 8.38 | Batch: 40 | Curr Loss: 5.1282
 >> Epoch: 8.57 | Batch: 60 | Curr Loss: 4.8713
 >> Epoch: 8.76 | Batch: 80 | Curr Loss: 4.6895
 >> Epoch: 8.95 | Batch: 100 | Curr Loss: 5.0385

 EPOCH 9.00 RESULTS:
 Training Loss:   4.9429 | Training Acc:   3.31% (Hard Labels)
 Validation Loss: 4.5866       | Validation Acc: 4.34% (Actual performance)





 >> Epoch: 9.18 | Batch: 20 | Curr Loss: 4.7641
 >> Epoch: 9.38 | Batch: 40 | Curr Loss: 4.8847
 >> Epoch: 9.57 | Batch: 60 | Curr Loss: 4.9598
 >> Epoch: 9.76 | Batch: 80 | Curr Loss: 5.0098
 >> Epoch: 9.95 | Batch: 100 | Curr Loss: 4.6824

 EPOCH 10.00 RESULTS:
 Training Loss:   4.9445 | Training Acc:   3.73% (Hard Labels)
 Validation Loss: 4.5908       | Validation Acc: 5.56% (Actual performance)





 >> Epoch: 10.18 | Batch: 20 | Curr Loss: 4.9774
 >> Epoch: 10.38 | Batch: 40 | Curr Loss: 4.5868
 >> Epoch: 10.57 | Batch: 60 | Curr Loss: 4.9537
 >> Epoch: 10.76 | Batch: 80 | Curr Loss: 4.8887
 >> Epoch: 10.95 | Batch: 100 | Curr Loss: 4.7312

 EPOCH 11.00 RESULTS:
 Training Loss:   4.8987 | Training Acc:   3.28% (Hard Labels)
 Validation Loss: 4.5497       | Validation Acc: 5.38% (Actual performance)





 >> Epoch: 11.18 | Batch: 20 | Curr Loss: 5.0499
 >> Epoch: 11.38 | Batch: 40 | Curr Loss: 5.1224
 >> Epoch: 11.57 | Batch: 60 | Curr Loss: 5.1964
 >> Epoch: 11.76 | Batch: 80 | Curr Loss: 5.1522
 >> Epoch: 11.95 | Batch: 100 | Curr Loss: 4.7833

 EPOCH 12.00 RESULTS:
 Training Loss:   4.9244 | Training Acc:   3.55% (Hard Labels)
 Validation Loss: 4.4977       | Validation Acc: 4.86% (Actual performance)





 >> Epoch: 12.18 | Batch: 20 | Curr Loss: 4.6203
 >> Epoch: 12.38 | Batch: 40 | Curr Loss: 4.5190
 >> Epoch: 12.57 | Batch: 60 | Curr Loss: 4.7520
 >> Epoch: 12.76 | Batch: 80 | Curr Loss: 4.8804
 >> Epoch: 12.95 | Batch: 100 | Curr Loss: 4.9350

 EPOCH 13.00 RESULTS:
 Training Loss:   4.8090 | Training Acc:   4.72% (Hard Labels)
 Validation Loss: 4.4222       | Validation Acc: 5.73% (Actual performance)





 >> Epoch: 13.18 | Batch: 20 | Curr Loss: 4.6505
 >> Epoch: 13.38 | Batch: 40 | Curr Loss: 4.7663
 >> Epoch: 13.57 | Batch: 60 | Curr Loss: 4.7928
 >> Epoch: 13.76 | Batch: 80 | Curr Loss: 4.5841
 >> Epoch: 13.95 | Batch: 100 | Curr Loss: 5.1248

 EPOCH 14.00 RESULTS:
 Training Loss:   4.8211 | Training Acc:   4.45% (Hard Labels)
 Validation Loss: 4.4001       | Validation Acc: 5.73% (Actual performance)





 >> Epoch: 14.18 | Batch: 20 | Curr Loss: 4.9954
 >> Epoch: 14.38 | Batch: 40 | Curr Loss: 4.5087
 >> Epoch: 14.57 | Batch: 60 | Curr Loss: 4.8031
 >> Epoch: 14.76 | Batch: 80 | Curr Loss: 5.0589
 >> Epoch: 14.95 | Batch: 100 | Curr Loss: 4.8296

 EPOCH 15.00 RESULTS:
 Training Loss:   4.8034 | Training Acc:   4.84% (Hard Labels)
 Validation Loss: 4.3946       | Validation Acc: 6.08% (Actual performance)





 >> Epoch: 15.18 | Batch: 20 | Curr Loss: 4.5038
 >> Epoch: 15.38 | Batch: 40 | Curr Loss: 5.1793
 >> Epoch: 15.57 | Batch: 60 | Curr Loss: 4.9657
 >> Epoch: 15.76 | Batch: 80 | Curr Loss: 4.6481
 >> Epoch: 15.95 | Batch: 100 | Curr Loss: 4.9058

 EPOCH 16.00 RESULTS:
 Training Loss:   4.8242 | Training Acc:   5.26% (Hard Labels)
 Validation Loss: 4.3666       | Validation Acc: 6.77% (Actual performance)





 >> Epoch: 16.18 | Batch: 20 | Curr Loss: 5.0342
 >> Epoch: 16.38 | Batch: 40 | Curr Loss: 4.8528
 >> Epoch: 16.57 | Batch: 60 | Curr Loss: 4.4190
 >> Epoch: 16.76 | Batch: 80 | Curr Loss: 5.0489
 >> Epoch: 16.95 | Batch: 100 | Curr Loss: 4.8498

 EPOCH 17.00 RESULTS:
 Training Loss:   4.7599 | Training Acc:   5.11% (Hard Labels)
 Validation Loss: 4.3237       | Validation Acc: 6.25% (Actual performance)





 >> Epoch: 17.18 | Batch: 20 | Curr Loss: 4.8055
 >> Epoch: 17.38 | Batch: 40 | Curr Loss: 4.6764
 >> Epoch: 17.57 | Batch: 60 | Curr Loss: 4.6231
 >> Epoch: 17.76 | Batch: 80 | Curr Loss: 4.9354
 >> Epoch: 17.95 | Batch: 100 | Curr Loss: 4.5179

 EPOCH 18.00 RESULTS:
 Training Loss:   4.7343 | Training Acc:   5.62% (Hard Labels)
 Validation Loss: 4.2337       | Validation Acc: 9.38% (Actual performance)





 >> Epoch: 18.18 | Batch: 20 | Curr Loss: 4.5837
 >> Epoch: 18.38 | Batch: 40 | Curr Loss: 4.9170
 >> Epoch: 18.57 | Batch: 60 | Curr Loss: 4.7589
 >> Epoch: 18.76 | Batch: 80 | Curr Loss: 4.7839
 >> Epoch: 18.95 | Batch: 100 | Curr Loss: 4.8482

 EPOCH 19.00 RESULTS:
 Training Loss:   4.7445 | Training Acc:   5.23% (Hard Labels)
 Validation Loss: 4.2151       | Validation Acc: 9.20% (Actual performance)





 >> Epoch: 19.18 | Batch: 20 | Curr Loss: 4.4944
 >> Epoch: 19.38 | Batch: 40 | Curr Loss: 4.7015
 >> Epoch: 19.57 | Batch: 60 | Curr Loss: 4.6164
 >> Epoch: 19.76 | Batch: 80 | Curr Loss: 4.9642
 >> Epoch: 19.95 | Batch: 100 | Curr Loss: 4.5378

 EPOCH 20.00 RESULTS:
 Training Loss:   4.6777 | Training Acc:   6.37% (Hard Labels)
 Validation Loss: 4.2110       | Validation Acc: 9.55% (Actual performance)





 >> Epoch: 20.18 | Batch: 20 | Curr Loss: 4.8858
 >> Epoch: 20.38 | Batch: 40 | Curr Loss: 4.7485
 >> Epoch: 20.57 | Batch: 60 | Curr Loss: 4.6491
 >> Epoch: 20.76 | Batch: 80 | Curr Loss: 4.9429
 >> Epoch: 20.95 | Batch: 100 | Curr Loss: 4.6736

 EPOCH 21.00 RESULTS:
 Training Loss:   4.6657 | Training Acc:   6.79% (Hard Labels)
 Validation Loss: 4.1237       | Validation Acc: 11.11% (Actual performance)





 >> Epoch: 21.18 | Batch: 20 | Curr Loss: 4.6809
 >> Epoch: 21.38 | Batch: 40 | Curr Loss: 5.1859
 >> Epoch: 21.57 | Batch: 60 | Curr Loss: 4.5069
 >> Epoch: 21.76 | Batch: 80 | Curr Loss: 4.0186
 >> Epoch: 21.95 | Batch: 100 | Curr Loss: 4.4441

 EPOCH 22.00 RESULTS:
 Training Loss:   4.6443 | Training Acc:   7.78% (Hard Labels)
 Validation Loss: 4.0702       | Validation Acc: 10.76% (Actual performance)





 >> Epoch: 22.18 | Batch: 20 | Curr Loss: 4.8921
 >> Epoch: 22.38 | Batch: 40 | Curr Loss: 4.6522
 >> Epoch: 22.57 | Batch: 60 | Curr Loss: 4.9438
 >> Epoch: 22.76 | Batch: 80 | Curr Loss: 4.1995
 >> Epoch: 22.95 | Batch: 100 | Curr Loss: 4.4394

 EPOCH 23.00 RESULTS:
 Training Loss:   4.5823 | Training Acc:   7.36% (Hard Labels)
 Validation Loss: 4.0845       | Validation Acc: 11.46% (Actual performance)





 >> Epoch: 23.18 | Batch: 20 | Curr Loss: 4.6631
 >> Epoch: 23.38 | Batch: 40 | Curr Loss: 4.9280
 >> Epoch: 23.57 | Batch: 60 | Curr Loss: 4.9148
 >> Epoch: 23.76 | Batch: 80 | Curr Loss: 4.8515
 >> Epoch: 23.95 | Batch: 100 | Curr Loss: 4.3205

 EPOCH 24.00 RESULTS:
 Training Loss:   4.6005 | Training Acc:   8.74% (Hard Labels)
 Validation Loss: 3.9971       | Validation Acc: 12.33% (Actual performance)





 >> Epoch: 24.18 | Batch: 20 | Curr Loss: 4.8596
 >> Epoch: 24.38 | Batch: 40 | Curr Loss: 4.8972
 >> Epoch: 24.57 | Batch: 60 | Curr Loss: 4.4425
 >> Epoch: 24.76 | Batch: 80 | Curr Loss: 4.7548
 >> Epoch: 24.95 | Batch: 100 | Curr Loss: 4.6263

 EPOCH 25.00 RESULTS:
 Training Loss:   4.6044 | Training Acc:   8.08% (Hard Labels)
 Validation Loss: 3.9945       | Validation Acc: 13.54% (Actual performance)





 >> Epoch: 25.18 | Batch: 20 | Curr Loss: 4.9159
 >> Epoch: 25.38 | Batch: 40 | Curr Loss: 4.0955
 >> Epoch: 25.57 | Batch: 60 | Curr Loss: 4.8268
 >> Epoch: 25.76 | Batch: 80 | Curr Loss: 4.4157
 >> Epoch: 25.95 | Batch: 100 | Curr Loss: 4.7687

 EPOCH 26.00 RESULTS:
 Training Loss:   4.5632 | Training Acc:   7.60% (Hard Labels)
 Validation Loss: 3.9268       | Validation Acc: 12.33% (Actual performance)





 >> Epoch: 26.18 | Batch: 20 | Curr Loss: 4.0765
 >> Epoch: 26.38 | Batch: 40 | Curr Loss: 4.8846
 >> Epoch: 26.57 | Batch: 60 | Curr Loss: 4.3076
 >> Epoch: 26.76 | Batch: 80 | Curr Loss: 4.4773
 >> Epoch: 26.95 | Batch: 100 | Curr Loss: 4.9845

 EPOCH 27.00 RESULTS:
 Training Loss:   4.5483 | Training Acc:   8.47% (Hard Labels)
 Validation Loss: 3.9603       | Validation Acc: 12.15% (Actual performance)





 >> Epoch: 27.18 | Batch: 20 | Curr Loss: 4.8756
 >> Epoch: 27.38 | Batch: 40 | Curr Loss: 4.6545
 >> Epoch: 27.57 | Batch: 60 | Curr Loss: 3.7881
 >> Epoch: 27.76 | Batch: 80 | Curr Loss: 4.3791
 >> Epoch: 27.95 | Batch: 100 | Curr Loss: 4.6723

 EPOCH 28.00 RESULTS:
 Training Loss:   4.5690 | Training Acc:   9.25% (Hard Labels)
 Validation Loss: 3.9250       | Validation Acc: 13.37% (Actual performance)





 >> Epoch: 28.18 | Batch: 20 | Curr Loss: 4.4772
 >> Epoch: 28.38 | Batch: 40 | Curr Loss: 4.1977
 >> Epoch: 28.57 | Batch: 60 | Curr Loss: 4.8222
 >> Epoch: 28.76 | Batch: 80 | Curr Loss: 4.7327
 >> Epoch: 28.95 | Batch: 100 | Curr Loss: 4.9990

 EPOCH 29.00 RESULTS:
 Training Loss:   4.5129 | Training Acc:   8.05% (Hard Labels)
 Validation Loss: 3.8950       | Validation Acc: 13.89% (Actual performance)





 >> Epoch: 29.18 | Batch: 20 | Curr Loss: 4.1156
 >> Epoch: 29.38 | Batch: 40 | Curr Loss: 3.9390
 >> Epoch: 29.57 | Batch: 60 | Curr Loss: 3.9203
 >> Epoch: 29.76 | Batch: 80 | Curr Loss: 4.7210
 >> Epoch: 29.95 | Batch: 100 | Curr Loss: 4.1404

 EPOCH 30.00 RESULTS:
 Training Loss:   4.4793 | Training Acc:   9.50% (Hard Labels)
 Validation Loss: 3.8869       | Validation Acc: 13.72% (Actual performance)



# Test

In [6]:
# Inference settings (repeated here so this cell is readable on its own).
TEST_DATA_PATH = "processed_bird_test_data"
MODEL_FOLDER = "coatnet_aug_experiment"
MODEL_NAME = "coatnet_0_rw_224"
BATCH_SIZE = 32

# Location where the training cell saved the final model.
MODEL_PATH = f"{MODEL_FOLDER}/best_aug_model"

# Device preference: Apple MPS, then CUDA, then CPU.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# Same deterministic preprocessing as the validation pipeline.
normalize = Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
test_transforms = Compose(
    transforms=[
        Resize(size=256),
        CenterCrop(size=224),
        ToTensor(),
        normalize,
    ]
)

def apply_test_transforms(batch):
    """Add a ``"pixel_values"`` column of preprocessed tensors to a batch.

    Every image in ``batch["image"]`` is converted to RGB and run through
    ``test_transforms``. The ``"image"`` column is left in place and the same
    batch mapping is returned.
    """
    tensors = []
    for image in batch["image"]:
        tensors.append(test_transforms(image.convert("RGB")))
    batch["pixel_values"] = tensors
    return batch

print(f"Loading test data from '{TEST_DATA_PATH}'.")
try:
    test_dataset = load_from_disk(TEST_DATA_PATH)
    # load_from_disk returns either a DatasetDict (a dict of splits) or a single
    # Dataset. Only look for a "test" split on the dict form: a plain Dataset has
    # no __contains__, so `"test" in dataset` would iterate over (and decode)
    # every row just to answer False.
    if isinstance(test_dataset, dict) and "test" in test_dataset:
        test_dataset = test_dataset["test"]

    print(f"Applying transforms to {len(test_dataset)} images.")
    # Eagerly preprocess all images; the tensors are stored in the dataset.
    test_dataset = test_dataset.map(apply_test_transforms, batched=True, batch_size=BATCH_SIZE)

    # Expose only the columns needed for inference, as torch tensors.
    test_dataset.set_format(type="torch", columns=["pixel_values", "id"])

    # shuffle=False keeps predictions aligned with the dataset order.
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    print("Data ready.")
except Exception as e:
    # Best-effort: report the problem and let the inference section skip itself.
    print(f"Error loading data: {e}")
    test_loader = None

if test_loader:
    print(f"Loading best model weights from '{MODEL_PATH}'.")
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=200)

    # Prefer the safetensors checkpoint; fall back to the legacy .bin file.
    try:
        state_dict = safetensors.torch.load_file(f"{MODEL_PATH}/model.safetensors", device=device)
    except FileNotFoundError:
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # produced by this notebook or another trusted source.
        state_dict = torch.load(f"{MODEL_PATH}/pytorch_model.bin", map_location=device)
        print("Loaded PyTorch Bin.")
    else:
        print("Loaded SafeTensors.")

    model.load_state_dict(state_dict)
    model.to(device).eval()

    all_preds, all_ids = [], []
    total_batches = len(test_loader)

    print("Generating predictions.")
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            logits = model(batch["pixel_values"].to(device))

            # Highest-scoring class index per image, moved back to the CPU.
            all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            all_ids.extend(batch["id"].numpy())

            if batch_idx % 20 == 0:
                print(f"Processing batch {batch_idx}/{total_batches}.")

    print("\nCreating CSV.")
    submission_df = (
        pd.DataFrame({"id": all_ids, "label": all_preds})
        # Shift class indices by one (presumably the submission format uses
        # 1-based labels — confirm against the competition spec).
        .assign(label=lambda frame: frame["label"] + 1)
        .sort_values(by="id")
    )

    csv_filename = "coatnet_augmented_submission.csv"
    submission_df.to_csv(csv_filename, index=False)

    print(f"Saved '{csv_filename}' successfully!")
    print("\nFirst 5 rows:")
    print(submission_df.head())

Using device: mps
Loading test data from 'processed_bird_test_data'.
Applying transforms to 4000 images.
Data ready.
Loading best model weights from 'coatnet_aug_experiment/best_aug_model'.
Loaded SafeTensors.
Generating predictions.
Processing batch 0/125.
Processing batch 20/125.
Processing batch 40/125.
Processing batch 60/125.
Processing batch 80/125.
Processing batch 100/125.
Processing batch 120/125.

Creating CSV.
Saved 'coatnet_augmented_submission.csv' successfully!

First 5 rows:
   id  label
0   1     17
1   2     35
2   3     40
3   4     12
4   5     32
