In [1]:
# Dependencies for this notebook. Versions are unpinned; pin them for reproducible runs,
# and prefer `%pip` so packages are installed into the running kernel's environment.
# NOTE(review): later cells also import `evaluate`, `datasets`, `transformers`,
# `torchvision` and `safetensors`, which are not listed here — confirm they are installed.
!pip install timm
!pip install ipywidgets
!pip install optuna plotly
!pip install kaleido
!pip install matplotlib



# New model: CoAtNet
A hybrid architecture that combines convolution and attention.

In [2]:
import torch
import optuna
import evaluate
import sys
import shutil
import safetensors.torch
import numpy as np
import pandas as pd
import timm
import matplotlib.pyplot as plt
import safetensors.torch
from torch.utils.data import DataLoader
from optuna.visualization.matplotlib import plot_optimization_history, plot_param_importances
from optuna.visualization import plot_optimization_history, plot_param_importances
from datasets import load_from_disk
from transformers import TrainingArguments, Trainer, set_seed
from torchvision.transforms import (
    Compose, Resize, CenterCrop, ToTensor, Normalize, 
    RandomHorizontalFlip, RandomResizedCrop
)

In [3]:
# Paths and model identifier shared by the later cells.
DATA_PATH = "processed_bird_data"       # saved dataset with "train" and "validation" splits
OUTPUT_DIR = "new_model_checkpoints"    # root directory for Trainer checkpoints
MODEL_NAME = "coatnet_0_rw_224"         # timm architecture name

# Use the first available backend in priority order: Apple MPS, then CUDA, otherwise CPU.
device = next(
    (
        backend
        for backend, is_available in (
            ("mps", torch.backends.mps.is_available),
            ("cuda", torch.cuda.is_available),
        )
        if is_available()
    ),
    "cpu",
)
print(f"Using device: {device}")

print("Loading and transforming data.")
dataset = load_from_disk(DATA_PATH)

Using device: cpu
Loading and transforming data.


Applying Aggressive Data Augmentation to prevent overfitting:

1. Random resized crop (scale 0.8–1.0): a random zoom that encourages the model to recognise a bird by its specific features (e.g. the head or wings) rather than the background (e.g. trees).
2. Random horizontal flips, so the model sees birds facing in both directions.

In [4]:
# ImageNet channel statistics, the usual normalisation for timm/torchvision models.
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: random crop covering 80-100% of the image, then a random horizontal flip.
_train_transforms = Compose([
    RandomResizedCrop(224, scale=(0.8, 1.0)),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic resize + centre crop, so scores are repeatable.
_val_transforms = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    normalize,
])

def train_transforms(batch):
    """Augment a batch of PIL images on the fly.

    Replaces the ``image`` column with ``pixel_values`` (one tensor per image) and leaves
    every other column (e.g. ``label``) untouched. ``image`` is removed because the Trainer
    runs with ``remove_unused_columns=False``, so anything left in the batch is handed to
    the data collator, which should only receive tensors / numbers.
    """
    images = batch.pop("image")
    batch["pixel_values"] = [_train_transforms(image.convert("RGB")) for image in images]
    return batch

def val_transforms(batch):
    """Add deterministic ``pixel_values`` for a batch of PIL images (used with ``Dataset.map``)."""
    batch["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in batch["image"]]
    return batch

# Training: attach the augmentation as a lazy transform so a *new* random crop/flip is drawn
# every time an image is loaded, i.e. every epoch. Precomputing it with `.map` would freeze
# one augmented copy of each image for the whole run, which defeats the augmentation.
# NOTE: `with_transform` returns a new dataset; `set_transform` works in place and returns
# None, so `ds = ds.set_transform(...)` would leave `None` behind.
dataset["train"] = dataset["train"].with_transform(train_transforms)

# Validation is deterministic, so precomputing it once is fine. The guard keeps this cell
# re-runnable: after the first run the `image` column no longer exists.
if "image" in dataset["validation"].column_names:
    dataset["validation"] = dataset["validation"].map(
        val_transforms, batched=True, remove_columns=["image"]
    )

print("Data ready.")

Data ready.


# Training
Initialise a CoAtNet model with random weights (no pretraining).

It uses standard convolutional layers in the early stages to extract low-level features (edges, textures), and Transformer layers in the final stages to capture the global shape of the bird. The learning rate and weight decay are tuned with Bayesian optimisation (Optuna).

In [7]:
# Hugging Face `evaluate` accuracy metric, used by compute_metrics below.
accuracy = evaluate.load("accuracy")

def compute_metrics(p):
    """Convert Trainer logits into class predictions and score them against the labels.

    Args:
        p: the prediction object passed by ``Trainer``; ``p.predictions`` holds the logits
           (class scores along axis 1) and ``p.label_ids`` the true class indices.

    Returns:
        dict: the result of ``accuracy.compute`` (read as ``eval_accuracy`` by the Trainer).
    """
    logits, references = p.predictions, p.label_ids
    return accuracy.compute(predictions=np.argmax(logits, axis=1), references=references)

class TimmTrainer(Trainer):
    """``Trainer`` adapted to timm models, whose forward takes a bare image tensor.

    A timm model is called as ``model(pixel_values)`` and returns raw logits, so the
    training loss and the evaluation step are implemented here instead of relying on a
    Hugging Face model output object.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Running per-epoch statistics; not updated by this version of the class.
        self.epoch_train_loss = 0.0
        self.epoch_train_acc = 0.0
        self.epoch_steps = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Cross-entropy of the logits against ``inputs["labels"]``.

        ``num_items_in_batch`` is accepted for Trainer API compatibility but unused.
        """
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")

        logits = model(pixel_values)
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)

        return (loss, logits) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run one evaluation batch and return ``(loss, logits, labels)``.

        The loss is computed whenever labels are present, so ``evaluate`` reports
        ``eval_loss``; returning ``None`` instead silently drops it from the metrics.
        """
        # The overridden base method moves inputs to the model's device (GPU/MPS);
        # it must be done explicitly here.
        inputs = self._prepare_inputs(inputs)
        labels = inputs.get("labels")

        with torch.no_grad():
            if "pixel_values" in inputs:
                outputs = model(inputs["pixel_values"])
            else:
                outputs = model(**inputs)

            loss = None
            if labels is not None and isinstance(outputs, torch.Tensor):
                loss = torch.nn.functional.cross_entropy(outputs, labels)

        if prediction_loss_only:
            return (loss, None, None)
        return (loss, outputs, labels)

In [None]:
# Bayesian optimization with Optuna
from transformers import TrainerCallback

SEARCH_SEED = 42  # seeds the TPE sampler and every trial's model initialisation


class _OptunaPruningCallback(TrainerCallback):
    """Report each epoch's validation accuracy to an Optuna trial and prune weak trials.

    Reporting while training runs, rather than once after it has finished, is what lets the
    pruner save compute. ``optuna.TrialPruned`` raised here propagates out of
    ``trainer.train()``, and Optuna records the trial as pruned.
    """

    def __init__(self, trial):
        self.trial = trial
        self.last_accuracy = None  # validation accuracy from the most recent evaluation

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not metrics or "eval_accuracy" not in metrics:
            return
        self.last_accuracy = metrics["eval_accuracy"]
        # Use the epoch number as the step so trials are compared epoch by epoch.
        step = round(state.epoch) if state.epoch is not None else state.global_step
        self.trial.report(self.last_accuracy, step=step)
        # Pruning after the final evaluation would only throw away a finished run.
        finished = bool(state.max_steps) and state.global_step >= state.max_steps
        if not finished and self.trial.should_prune():
            raise optuna.TrialPruned()


def objective(trial):
    """Train a fresh CoAtNet for 5 epochs with sampled hyperparameters.

    Returns:
        float: validation accuracy of the model after the last epoch.

    Raises:
        optuna.TrialPruned: if the pruner stops the trial early.
    """
    # hyperparameters
    learning_rate = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True)

    # Same initial weights in every trial, so differences come from the hyperparameters.
    set_seed(SEARCH_SEED)
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=200)
    model.to(device)

    run_args = TrainingArguments(
        output_dir=f"{OUTPUT_DIR}/trial_{trial.number}",
        per_device_train_batch_size=32,
        num_train_epochs=5,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=0.1,
        disable_tqdm=True,
        logging_strategy="epoch",
        save_strategy="no",
        eval_strategy="epoch",
        dataloader_num_workers=0,
        remove_unused_columns=False,
        report_to="none"
    )

    pruning_callback = _OptunaPruningCallback(trial)
    trainer = TimmTrainer(
        model=model,
        args=run_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        compute_metrics=compute_metrics,
        callbacks=[pruning_callback],
    )

    trainer.train()

    # eval_strategy="epoch" already scored the final model; reuse that result instead of
    # running a second, identical evaluation pass.
    val_accuracy = pruning_callback.last_accuracy
    if val_accuracy is None:
        val_accuracy = trainer.evaluate()["eval_accuracy"]
    return val_accuracy


print("starting optimization.")
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=SEARCH_SEED),
    # Wait for 5 finished trials before comparing against the median; prune from epoch 2 on.
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2),
)
study.optimize(objective, n_trials=10)

print("\n" + "="*40)
print(f"Best accuracy: {study.best_value*100:.2f}%")
print(f"Best learning rate: {study.best_params['lr']:.6f}")
print(f"Best weight decay: {study.best_params['weight_decay']:.6f}")
print("="*40)

[I 2025-12-14 14:56:15,964] A new study created in memory with name: no-name-0b66506a-07e3-431f-8583-50be5c7d8312


starting optimization.




In [None]:
# The imports cell imports Optuna's plotly plotters *after* the matplotlib ones, so the bare
# names `plot_optimization_history` / `plot_param_importances` return plotly figures that
# matplotlib never sees, and `plt.savefig` would save only an empty `plt.figure()`.
# Use the matplotlib backend explicitly: each call creates its own figure and returns its Axes.
import optuna.visualization.matplotlib as optuna_mpl


def _save_study_plot(plot_fn, title, filename, label):
    """Draw ``plot_fn(study)``, set its title, save it to ``filename`` and display it.

    Failures are reported and skipped (e.g. importances need enough completed trials).
    Returns True if the figure was saved.
    """
    try:
        ax = plot_fn(study)
        ax.set_title(title)
        ax.figure.tight_layout()
        ax.figure.savefig(filename)
        plt.show()
        return True
    except Exception as e:
        print(f"Skipping {label} plot: {e}")
        return False


_save_study_plot(optuna_mpl.plot_optimization_history, "Optimization History", "optuna_history.png", "history")
_save_study_plot(optuna_mpl.plot_param_importances, "Hyperparameter Importance", "optuna_importance.png", "importance")

print("Plots saved successfully (or skipped if invalid).")

In [None]:
set_seed(42)  # makes the random weight initialisation below reproducible

model = timm.create_model(
    MODEL_NAME,
    pretrained=False,
    num_classes=200,
)

# Same priority as the config cell (MPS, then CUDA, then CPU). A bare
# `"cuda" if ... else "cpu"` would silently ignore Apple-silicon GPUs.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

class TimmTrainer(Trainer):
    """``Trainer`` for timm models that prints a train/validation summary at each evaluation.

    A timm model is called as ``model(pixel_values)`` and returns raw logits, so the training
    loss and the evaluation step are implemented here. Running training loss/accuracy are
    accumulated in ``compute_loss`` and printed (then reset) in ``evaluate``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Sums over the training batches seen since the last evaluate() call.
        self.epoch_train_loss = 0.0
        self.epoch_train_acc = 0.0
        self.epoch_steps = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Cross-entropy loss on the logits; also tracks running training statistics.

        ``num_items_in_batch`` is accepted for Trainer API compatibility but unused.
        """
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")

        logits = model(pixel_values)
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)

        if model.training:
            # Bookkeeping only: keep it out of the autograd graph.
            with torch.no_grad():
                preds = torch.argmax(logits, dim=-1)
                acc = (preds == labels).float().mean().item()

                self.epoch_train_loss += loss.item()
                self.epoch_train_acc += acc
                self.epoch_steps += 1

                if self.epoch_steps % 20 == 0:
                    current_epoch_float = self.state.epoch if self.state.epoch is not None else 0
                    print(f" >> Epoch: {current_epoch_float:.2f} | Batch: {self.epoch_steps} | Curr Loss: {loss.item():.4f}")

        return (loss, logits) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run one evaluation batch and return ``(loss, logits, labels)``.

        The loss is computed whenever labels are present. Without it the metrics contain
        no ``eval_loss``, and ``evaluate`` below would always print a validation loss of 0.0.
        """
        # The overridden base method moves inputs to the model's device; do it explicitly.
        inputs = self._prepare_inputs(inputs)
        labels = inputs.get("labels")

        with torch.no_grad():
            if "pixel_values" in inputs:
                outputs = model(inputs["pixel_values"])
            else:
                outputs = model(**inputs)

            loss = None
            if labels is not None and isinstance(outputs, torch.Tensor):
                loss = torch.nn.functional.cross_entropy(outputs, labels)

        if prediction_loss_only:
            return (loss, None, None)
        return (loss, outputs, labels)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Run the standard evaluation, print a train/validation summary, then reset counters."""
        metrics = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)

        avg_train_loss = self.epoch_train_loss / self.epoch_steps if self.epoch_steps > 0 else 0
        avg_train_acc = self.epoch_train_acc / self.epoch_steps if self.epoch_steps > 0 else 0

        val_loss = metrics.get(f"{metric_key_prefix}_loss", 0.0)
        val_acc = metrics.get(f"{metric_key_prefix}_accuracy", 0.0)

        print("\n" + "="*80)
        print(f" Training Loss:   {avg_train_loss:.4f} | Training Acc:   {avg_train_acc*100:.2f}%")
        print(f" Validation Loss: {val_loss:.4f}       | Validation Acc: {val_acc*100:.2f}%")
        print("="*80 + "\n")

        # Reset counters so the next summary covers only the next epoch.
        self.epoch_train_loss = 0.0
        self.epoch_train_acc = 0.0
        self.epoch_steps = 0
        return metrics

# Hugging Face `evaluate` accuracy metric, used by compute_metrics below.
accuracy = evaluate.load("accuracy")

def compute_metrics(p):
    """Score arg-max class predictions from the Trainer's logits against the true labels.

    Args:
        p: prediction object from ``Trainer`` (``predictions`` = logits, ``label_ids`` = targets).

    Returns:
        dict: the result of ``accuracy.compute``.
    """
    logits, references = p.predictions, p.label_ids
    return accuracy.compute(predictions=np.argmax(logits, axis=1), references=references)

# Final 30-epoch run. The lr / weight decay values are presumably the best parameters from the
# Optuna search above — TODO confirm. Only the best-accuracy checkpoint is kept.
training_args = TrainingArguments(
    # checkpointing / model selection
    output_dir=OUTPUT_DIR,
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # optimisation
    num_train_epochs=30,
    per_device_train_batch_size=32,
    learning_rate=5.9e-4,
    weight_decay=0.065,
    warmup_ratio=0.1,
    seed=42,
    # evaluation / logging
    eval_strategy="epoch",
    logging_strategy="epoch",
    disable_tqdm=True,
    report_to="none",
    # data loading: keep every column, because the timm forward has no matching keyword names
    dataloader_num_workers=0,
    remove_unused_columns=False,
)

trainer = TimmTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)
trainer.train()

# load_best_model_at_end=True, so the model being saved is the best-accuracy checkpoint.
trainer.save_model("final_new_model")
print("Best model (Seed 42) saved.")

In [None]:
# Hugging Face `evaluate` accuracy metric, used by compute_metrics below.
accuracy = evaluate.load("accuracy")

def compute_metrics(p):
    """Return the accuracy of the arg-max predictions in ``p.predictions`` vs ``p.label_ids``."""
    logits, references = p.predictions, p.label_ids
    return accuracy.compute(predictions=np.argmax(logits, axis=1), references=references)

class TimmTrainer(Trainer):
    """``Trainer`` for timm models: forward is ``model(pixel_values)`` and returns raw logits."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-epoch counters; only reset (never accumulated) by this version of the class.
        self.epoch_train_loss = 0.0
        self.epoch_train_acc = 0.0
        self.epoch_steps = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Cross-entropy of the logits against ``inputs["labels"]``."""
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")
        logits = model(pixel_values)
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        return (loss, logits) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run one evaluation batch and return ``(loss, logits, labels)``.

        The loss is computed whenever labels are present, so ``eval_loss`` is reported.

        Raises:
            ValueError: if the batch carries no ``pixel_values``.
        """
        # The overridden base method moves inputs to the model's device; do it explicitly.
        inputs = self._prepare_inputs(inputs)
        labels = inputs.get("labels")
        with torch.no_grad():
            if "pixel_values" in inputs:
                outputs = model(inputs["pixel_values"])
            else:
                raise ValueError(f"Batch is empty! Keys found: {inputs.keys()}")

            loss = None
            if labels is not None:
                loss = torch.nn.functional.cross_entropy(outputs, labels)

        if prediction_loss_only:
            return (loss, None, None)
        return (loss, outputs, labels)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Standard evaluation, then reset the per-epoch counters."""
        metrics = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
        self.epoch_train_loss = 0.0
        self.epoch_train_acc = 0.0
        self.epoch_steps = 0
        return metrics

# Baseline (seed 42): score the saved seed-42 model; the later seeds have to beat it.
print("\n=== Loading Seed 42 Model to Establish Baseline ===")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=200)

# Weights are expected in ./final_new_model; fall back to a copy under OUTPUT_DIR.
baseline_weights = "final_new_model/model.safetensors"
try:
    state_dict = safetensors.torch.load_file(baseline_weights, device=device)
except FileNotFoundError:
    state_dict = safetensors.torch.load_file(f"{OUTPUT_DIR}/{baseline_weights}", device=device)

model.load_state_dict(state_dict)
model.to(device)

# Evaluate Baseline with an evaluation-only Trainer (no training dataset).
baseline_args = TrainingArguments(
    output_dir="temp_eval",
    per_device_eval_batch_size=32,
    remove_unused_columns=False,
    report_to="none",
)
eval_trainer = TimmTrainer(
    model=model,
    args=baseline_args,
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)
base_metrics = eval_trainer.evaluate()
global_best_acc = base_metrics["eval_accuracy"]
print(f"Current Best Accuracy (Seed 42): {global_best_acc*100:.2f}%")

# Experiment: retrain from scratch with other seeds and keep whichever model scores best.
other_seeds = [1, 100]

for seed in other_seeds:
    print(f"\n{'='*20} Running Seed {seed} {'='*20}")
    set_seed(seed)

    # Fresh, randomly initialised model for this seed.
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=200)
    model.to(device)

    # Same hyperparameters as the seed-42 run; only the seed and the output folder change.
    training_args = TrainingArguments(
        output_dir=f"{OUTPUT_DIR}/seed_{seed}",
        seed=seed,
        save_strategy="epoch",
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        num_train_epochs=30,
        per_device_train_batch_size=32,
        learning_rate=5.9e-4,
        weight_decay=0.065,
        warmup_ratio=0.1,
        eval_strategy="epoch",
        logging_strategy="epoch",
        disable_tqdm=True,
        report_to="none",
        dataloader_num_workers=0,
        remove_unused_columns=False,
    )

    trainer = TimmTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        compute_metrics=compute_metrics,
    )
    trainer.train()

    # load_best_model_at_end=True, so this scores the best checkpoint of this run.
    metrics = trainer.evaluate()
    current_acc = metrics["eval_accuracy"]
    print(f"Seed {seed} Accuracy: {current_acc*100:.2f}%")

    if not current_acc > global_best_acc:
        print(f" >>> Did not beat current best ({global_best_acc*100:.2f}%). Keeping previous model.")
        continue

    print(f" >>> New best found! ({current_acc*100:.2f}% > {global_best_acc*100:.2f}%) Overwriting final model.")
    global_best_acc = current_acc
    trainer.save_model("final_new_model")

print(f"\nFinal Experiment Finished. Absolute Best Accuracy: {global_best_acc*100:.2f}%")

# Test

In [None]:
from datasets import DatasetDict

TEST_DATA_PATH = "processed_bird_test_data"
MODEL_PATH = "final_new_model"
MODEL_NAME = "coatnet_0_rw_224"
BATCH_SIZE = 32

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Same deterministic preprocessing as validation: resize, centre crop, ImageNet normalisation.
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
test_transforms = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    normalize,
])

def apply_test_transforms(batch):
    """Add ``pixel_values`` tensors for a batch of PIL images (used with ``Dataset.map``)."""
    batch["pixel_values"] = [test_transforms(image.convert("RGB")) for image in batch["image"]]
    return batch

print(f"Loading test data from '{TEST_DATA_PATH}'.")
try:
    test_dataset = load_from_disk(TEST_DATA_PATH)
    # Unwrap the split only when a DatasetDict was saved. On a plain Dataset, `"test" in ...`
    # is not a key lookup: it falls back to iterating (and decoding) every row.
    if isinstance(test_dataset, DatasetDict):
        if "test" not in test_dataset:
            raise ValueError(f"No 'test' split in '{TEST_DATA_PATH}' (found: {list(test_dataset)})")
        test_dataset = test_dataset["test"]

    print(f"Applying transforms to {len(test_dataset)} images.")
    test_dataset = test_dataset.map(apply_test_transforms, batched=True, batch_size=BATCH_SIZE)

    test_dataset.set_format(type="torch", columns=["pixel_values", "id"])

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    print("Data ready.")
except Exception as e:
    # Best effort: report the problem and skip prediction rather than abort the notebook.
    print(f"Error loading data: {e}")
    test_loader = None

if test_loader is not None:
    print(f"Loading best model weights from '{MODEL_PATH}'.")
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=200)

    try:
        state_dict = safetensors.torch.load_file(f"{MODEL_PATH}/model.safetensors", device=device)
        print("Loaded SafeTensors.")
    except FileNotFoundError:
        # weights_only=True restricts unpickling to tensors and plain containers, so a
        # tampered .bin file cannot run arbitrary code on load.
        state_dict = torch.load(f"{MODEL_PATH}/pytorch_model.bin", map_location=device, weights_only=True)
        print("Loaded PyTorch Bin.")

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    all_preds = []
    all_ids = []

    print("Generating predictions.")
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            pixel_values = batch["pixel_values"].to(device)
            ids = batch["id"]

            outputs = model(pixel_values)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()

            all_preds.extend(preds)
            # Numeric ids arrive as a tensor; non-numeric ids are not converted by the torch format.
            all_ids.extend(ids.tolist() if torch.is_tensor(ids) else list(ids))

            if i % 20 == 0:
                print(f"Processing batch {i}/{len(test_loader)}.")

    print("\nCreating CSV.")
    submission_df = pd.DataFrame({
        "id": all_ids,
        "label": all_preds
    })

    # Model classes are 0-based. NOTE(review): this assumes the submission expects 1-based labels — confirm.
    submission_df["label"] = submission_df["label"] + 1

    submission_df = submission_df.sort_values(by="id")

    csv_filename = "coatnet_submission.csv"
    submission_df.to_csv(csv_filename, index=False)

    print(f"Saved '{csv_filename}' successfully!")
    print("\nFirst 5 rows (Sanity Check):")
    print(submission_df.head())

In [11]:
# subset analysis - added later (without rerunning the rest)

In [15]:
import torch
import numpy as np
import pandas as pd
import timm
import safetensors.torch
from transformers import TrainingArguments, Trainer
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from datasets import load_from_disk

# config: analyse one saved checkpoint (from the seed-100 run) instead of retraining
checkpoint_path = "seedcheckpoints/seed_100/checkpoint-2835/model.safetensors"
MODEL_NAME = "coatnet_0_rw_224"
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

print(f"Using device: {device}")

# loading data
# Always start from the raw validation split on disk. The training section maps
# `dataset["validation"]` with remove_columns=["image"], so reusing an in-memory `dataset`
# (after Restart & Run All) makes the `batch["image"]` lookup below raise KeyError.
print("Loading dataset from disk...")
raw_validation = load_from_disk("processed_bird_data")["validation"]

# transformations: deterministic evaluation preprocessing (resize, centre crop, normalise)
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
val_transforms_func = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    normalize,
])

def val_transforms(batch):
    """Add ``pixel_values`` tensors for a batch of PIL images (used with ``Dataset.map``)."""
    batch["pixel_values"] = [val_transforms_func(image.convert("RGB")) for image in batch["image"]]
    return batch

# The raw images are not needed once pixel_values exist, so drop them from the mapped copy.
val_dataset = raw_validation.map(val_transforms, batched=True, remove_columns=["image"])
val_dataset.set_format(type="torch", columns=["pixel_values", "label"])

# Rebuild the architecture and load the saved checkpoint weights, so nothing is retrained.
print(f"Loading model weights from {checkpoint_path}...")
state_dict = safetensors.torch.load_file(checkpoint_path, device=device)
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=200)
model.load_state_dict(state_dict)
model.to(device)

# trainer
class TimmTrainer(Trainer):
    """``Trainer`` for timm models: forward is ``model(pixel_values)`` and returns raw logits."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Cross-entropy of the logits against ``inputs["labels"]``."""
        pixel_values = inputs.get("pixel_values")
        labels = inputs.get("labels")
        logits = model(pixel_values)
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        return (loss, logits) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run one batch and return ``(loss, logits, labels)``.

        The loss is computed when labels are present, so evaluation/prediction metrics
        include it.
        """
        # The overridden base method moves inputs to the model's device; do it explicitly.
        inputs = self._prepare_inputs(inputs)
        labels = inputs.get("labels")
        with torch.no_grad():
            if "pixel_values" in inputs:
                outputs = model(inputs["pixel_values"])
            else:
                outputs = model(**inputs)

            loss = None
            if labels is not None and isinstance(outputs, torch.Tensor):
                loss = torch.nn.functional.cross_entropy(outputs, labels)

        if prediction_loss_only:
            return (loss, None, None)
        return (loss, outputs, labels)

# Prediction only: the Trainer needs no datasets or metric function here.
eval_trainer = TimmTrainer(
    model=model,
    args=TrainingArguments(
        output_dir="temp_eval",
        report_to="none",
        remove_unused_columns=False,
    ),
)

# predictions: logits for every validation image, reduced to the arg-max class
predictions = eval_trainer.predict(val_dataset)
labels = predictions.label_ids
preds = np.argmax(predictions.predictions, axis=1)

# extended subset analysis: accuracy within four contiguous ranges of 50 class ids
df_results = pd.DataFrame({"label": labels, "pred": preds})
df_results["correct"] = df_results["label"] == df_results["pred"]

print("\n" + "="*60)
print("GRANULAR CLASS ANALYSIS (4 Groups)")
print("="*60)

# (start, end-exclusive, name): ids 0-49 -> A, 50-99 -> B, 100-149 -> C, 150-199 -> D
subsets = [
    (low, low + 50, f"Group {letter} (50)")
    for letter, low in zip("ABCD", range(0, 200, 50))
]

accuracies = []
for start, end, name in subsets:
    in_range = df_results["label"].ge(start) & df_results["label"].lt(end)
    subset = df_results.loc[in_range]
    acc = subset["correct"].mean()
    accuracies.append(acc)
    print(f"{name}: {acc*100:.2f}%")

print(f"\nMax Accuracy Spread: {(max(accuracies) - min(accuracies))*100:.2f} percentage points")

print("\n" + "="*60)
print("STABILITY ANALYSIS (5 Random Splits)")
print("="*60)

# Accuracy on five random halves of the validation predictions (seeds 42, 84, ..., 210).
random_accs = []
for i in range(1, 6):
    random_subset = df_results.sample(frac=0.5, random_state=42 * i)
    acc = random_subset["correct"].mean()
    random_accs.append(acc)
    print(f"Random Run {i}: {acc*100:.2f}%")

# np.std is the population standard deviation (ddof=0).
mean_rnd, std_rnd = np.mean(random_accs), np.std(random_accs)

print(f"\nMean Accuracy: {mean_rnd*100:.2f}%")
print(f"Standard Deviation: {std_rnd*100:.2f}%")

# Threshold: a spread below 1.5 percentage points is treated as "stable".
print(
    "CONCLUSION: Model is STABLE across data selection."
    if std_rnd < 0.015
    else "CONCLUSION: Variance detected across data selection."
)

Using device: cpu
Loading model weights from seedcheckpoints/seed_100/checkpoint-2835/model.safetensors...





GRANULAR CLASS ANALYSIS (4 Groups)
Group A (50): 35.94%
Group B (50): 30.34%
Group C (50): 19.05%
Group D (50): 11.76%

Max Accuracy Spread: 24.18 percentage points

STABILITY ANALYSIS (5 Random Splits)
Random Run 1: 28.57%
Random Run 2: 26.19%
Random Run 3: 29.25%
Random Run 4: 30.27%
Random Run 5: 29.25%

Mean Accuracy: 28.71%
Standard Deviation: 1.37%
CONCLUSION: Model is STABLE across data selection.
