In [1]:
!pip install torchinfo



# Error Analysis

1. Computational Complexity

In [2]:
import torch
import torch.nn as nn
import timm
import os
import shutil
from transformers import AutoModelForImageClassification
from torchinfo import summary
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from datasets import load_from_disk
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torch.utils.data import DataLoader

In [3]:
# We dont need the trained models, just look at the architectures
# COPIED from the cnn scratch model:
class BirdCNN(nn.Module):
    def __init__(self, num_classes=200):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1,1))
        )
        self.classifier = nn.Sequential(
            nn.Linear(256, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
    def forward(self, x):
        x = self.features(x).view(x.size(0), -1)
        return self.classifier(x)

# ResNet model
class BasicBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.downsample = None
        if stride != 1 or in_ch != out_ch:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride),
                nn.BatchNorm2d(out_ch)
            )
    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample: identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

class ResNetScratch(nn.Module):
    def __init__(self, num_classes=200):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 7, stride=2, padding=3),
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        self.layer1 = BasicBlock(32, 64, stride=2)
        self.layer2 = BasicBlock(64, 128, stride=2)
        self.layer3 = BasicBlock(128, 256, stride=2)
        self.layer4 = BasicBlock(256, 256, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(256, num_classes)
    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

# ViT MAE model
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size//patch_size
        self.num_patches = self.grid_size**2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim=192, num_heads=3, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=drop, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Dropout(drop),
            nn.Linear(hidden_dim, embed_dim), nn.Dropout(drop),
        )
    def forward(self, x):
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

class SimpleViTWithAttributes(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=200, num_attr=312, embed_dim=192, depth=6, num_heads=3, mlp_ratio=4.0, drop=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
        self.blocks = nn.ModuleList([TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, drop) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head_class = nn.Linear(embed_dim, num_classes)
        self.head_attr  = nn.Linear(embed_dim, num_attr)
    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B,-1,-1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        for blk in self.blocks: x = blk(x)
        x = self.norm(x)
        cls = x[:, 0]
        return self.head_class(cls), self.head_attr(cls)

device = "cpu"

# Simple CNN
cnn_model = BirdCNN(num_classes=200)

# Our hugging face baseline:
baseline_model = AutoModelForImageClassification.from_pretrained(
    "google/mobilenet_v2_1.0_224", num_labels=200, ignore_mismatched_sizes=True
)

# Our main model (coatnet, same for pre/post training and augmentation versions)
coatnet_model = timm.create_model("coatnet_0_rw_224", pretrained=False, num_classes=200)

# ResNet
resnet_model = ResNetScratch(num_classes=200)

# ConvNeXt
convnext_model = timm.create_model("convnext_tiny", pretrained=False, num_classes=200)

# ViT MAE
vit_model = SimpleViTWithAttributes(img_size=224, patch_size=16, in_chans=3, num_classes=200, num_attr=312, embed_dim=192, depth=6, num_heads=3)

def get_stats(model, model_name):
    stats = summary(model, input_size=(1, 3, 224, 224), verbose=0)
    params = stats.total_params
    # FLOPPPPP
    flops = stats.total_mult_adds
    
    print(f"--- {model_name} ---")
    print(f"Parameters: {params:,}")
    print(f"FLOPs (Approx): {flops:,}")
    print(f"Size (MB): {stats.to_megabytes(params):.2f}")
    return params, flops

print("COMPUTATIONAL COMPLEXITY ANALYSIS\n")
cnn_p, cnn_f = get_stats(cnn_model, "Simple CNN")
base_p, base_f = get_stats(baseline_model, "Baseline (MobileNetV2)")
coat_p, coat_f = get_stats(coatnet_model, "CoAtNet (Hybrid/Aug/Pre-tune)")
res_p, res_f = get_stats(resnet_model, "ResNet (CNN)")
conv_p, conv_f = get_stats(convnext_model, "CoNeXt (Hybrid)")
vit_p, vit_f = get_stats(vit_model, "ViT MAE (Transformer)")

print("\n--- COMPARISON ---")
print(f"CoAtNet vs CNN Params: {coat_p / cnn_p:.1f}x larger")
print(f"CoAtNet vs Baseline Params: {coat_p / base_p:.1f}x larger")

Some weights of MobileNetV2ForImageClassification were not initialized from the model checkpoint at google/mobilenet_v2_1.0_224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1001]) in the checkpoint and torch.Size([200]) in the model instantiated
- classifier.weight: found shape torch.Size([1001, 1280]) in the checkpoint and torch.Size([200, 1280]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


COMPUTATIONAL COMPLEXITY ANALYSIS

--- Simple CNN ---
Parameters: 508,104
FLOPs (Approx): 817,184,136
Size (MB): 0.51
--- Baseline (MobileNetV2) ---
Parameters: 2,480,072
FLOPs (Approx): 299,784,584
Size (MB): 2.48
--- CoAtNet (Hybrid/Aug/Pre-tune) ---
Parameters: 26,820,362
FLOPs (Approx): 4,213,652,296
Size (MB): 26.82
--- ResNet (CNN) ---
Parameters: 2,511,944
FLOPs (Approx): 214,535,816
Size (MB): 2.51
--- CoNeXt (Hybrid) ---
Parameters: 27,973,928
FLOPs (Approx): 321,756,392
Size (MB): 27.97
--- ViT MAE (Transformer) ---
Parameters: 2,954,048
FLOPs (Approx): 30,818,048
Size (MB): 2.95

--- COMPARISON ---
CoAtNet vs CNN Params: 52.8x larger
CoAtNet vs Baseline Params: 10.8x larger


2. Confusion Matrix

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
DATA_PATH = "processed_bird_data"
MODEL_PATH = "final_new_model/model.safetensors"

# Load data
dataset = load_from_disk(DATA_PATH)
val_ds = dataset["validation"]

# Transforms
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
val_transforms = Compose([Resize(256), CenterCrop(224), ToTensor(), normalize])

def transform_fn(batch):
    batch["pixel_values"] = [val_transforms(img.convert("RGB")) for img in batch["image"]]
    return batch

val_ds.set_transform(transform_fn)
val_loader = DataLoader(val_ds, batch_size=32, collate_fn=lambda x: {
    "pixel_values": torch.stack([i["pixel_values"] for i in x]), 
    "labels": torch.tensor([i["label"] for i in x])
})

# Load Model
model = timm.create_model("coatnet_0_rw_224", pretrained=False, num_classes=200)
from safetensors.torch import load_file
model.load_state_dict(load_file(MODEL_PATH))
model.to(device)
model.eval()

true_labels = []
pred_labels = []
confidences = []
images_for_plot = []

print("Running Inference for Error Analysis.")
with torch.no_grad():
    for batch in val_loader:
        inputs = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        
        outputs = model(inputs)
        probs = torch.softmax(outputs, dim=1)
        
        # Get max prob and predicted class
        max_probs, preds = torch.max(probs, dim=1)
        
        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(preds.cpu().numpy())
        confidences.extend(max_probs.cpu().numpy())

print("Done.")

Running Inference for Error Analysis.
Done.


First right and wrong predictions:

In [5]:
raw_val_ds = load_from_disk(DATA_PATH)["validation"]

df_results = pd.DataFrame({
    "True": true_labels,
    "Pred": pred_labels,
    "Confidence": confidences,
    "Index": range(len(true_labels))
})

# flops
wrong_preds = df_results[df_results["True"] != df_results["Pred"]]
top_wrong = wrong_preds.sort_values(by="Confidence", ascending=False).head(3)

# successes
correct_preds = df_results[df_results["True"] == df_results["Pred"]]
top_correct = correct_preds.sort_values(by="Confidence", ascending=False).head(3)

print("Top 3 Failures (Use 2 for Poster):")
print(top_wrong)
print("\nTop 3 Successes (Use 2 for Poster):")
print(top_correct)

output_dir = "poster_images"
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

print(f"\nSaving images to folder: {output_dir}...")

def save_crops(df, prefix):
    for i, (_, row) in enumerate(df.iterrows()):
        idx = int(row["Index"])
        true_cls = int(row["True"])
        pred_cls = int(row["Pred"])
        conf = row["Confidence"]
        
        img = raw_val_ds[idx]["image"]
        fname = f"{output_dir}/{prefix}_{i+1}_true{true_cls}_pred{pred_cls}.png"
        img.save(fname)
        print(f"Saved: {fname} (Conf: {conf:.4f})")

save_crops(top_wrong, "fail")
save_crops(top_correct, "success")

Top 3 Failures (Use 2 for Poster):
     True  Pred  Confidence  Index
94     68    36    0.999438     94
284    43     9    0.998038    284
14     26    28    0.997386     14

Top 3 Successes (Use 2 for Poster):
     True  Pred  Confidence  Index
335    11    11    0.999415    335
126    33    33    0.999210    126
58      6     6    0.999071     58

Saving images to folder: poster_images...
Saved: poster_images/fail_1_true68_pred36.png (Conf: 0.9994)
Saved: poster_images/fail_2_true43_pred9.png (Conf: 0.9980)
Saved: poster_images/fail_3_true26_pred28.png (Conf: 0.9974)
Saved: poster_images/success_1_true11_pred11.png (Conf: 0.9994)
Saved: poster_images/success_2_true33_pred33.png (Conf: 0.9992)
Saved: poster_images/success_3_true6_pred6.png (Conf: 0.9991)


Now actual confusion matrix (200x200 is hard to had to get creative), so its better to show which pairs of birds are confused most often:

In [6]:
cm = confusion_matrix(true_labels, pred_labels)

np.fill_diagonal(cm, 0)

# Find max confusion
flat_indices = np.argsort(cm.flatten())[-5:] # Top 5 confused pairs
rows, cols = np.unravel_index(flat_indices, cm.shape)

print("\nTop 5 Most Confused Pairs (Model predicts B when it is actually A):")
for r, c in zip(rows, cols):
    print(f"True Class {r} -> Predicted as Class {c} (Count: {cm[r, c]})")


Top 5 Most Confused Pairs (Model predicts B when it is actually A):
True Class 28 -> Predicted as Class 8 (Count: 2)
True Class 64 -> Predicted as Class 61 (Count: 2)
True Class 78 -> Predicted as Class 92 (Count: 2)
True Class 97 -> Predicted as Class 11 (Count: 2)
True Class 55 -> Predicted as Class 34 (Count: 3)
