In [None]:
from datasets import load_dataset

# Whole DatasetDict (train / validation / test splits) from the Hugging Face Hub
ds = load_dataset(
    "yashikota/birds-525-species-image-classification"
)

In [None]:
import pandas as pd
from PIL import Image
# Materialise only a 5-row preview: building a DataFrame from the whole split would
# decode every training image into memory just to display .head().
df = pd.DataFrame(ds['train'][:5])
df.head()

In [None]:
from IPython.display import display
# Show 5 reproducibly-shuffled training images together with their species name
label_feature = ds['train'].features['label']
preview = ds['train'].shuffle(seed=231).select(range(5))
for example in preview:
  display(example["image"])
  print("Label: ", label_feature.int2str(example['label']))

In [None]:
# Checking label and image for an American Wigeon (training row 3924)
wigeon = ds["train"][3924]
display(wigeon["image"])
display(ds['train'].features["label"].int2str(wigeon['label']))

In [None]:
# Display the feature schema; its 'label' ClassLabel carries the species (class) names
train_features = ds['train'].features
display(train_features)

In [None]:
# Checking for missing or corrupted training images.
# A missing image decodes to None. A corrupted/truncated file makes the decode itself
# raise (PIL errors such as UnidentifiedImageError are OSError subclasses), so the
# decode is guarded instead of letting one bad file abort the whole scan.
bad_images = []

for i in range(len(ds['train'])):
  try:
    image = ds['train'][i]['image']
  except OSError:
    bad_images.append(i)
    continue
  if image is None:
    bad_images.append(i)

print(f"Total number of corrupt/null images: {len(bad_images)}")
print("First Bad Ones: ", bad_images[:25])

In [None]:
# Collect indices of training images whose (width, height) is not the expected 224x224
size_unmatched = [
  idx for idx in range(len(ds['train']))
  if ds['train'][idx]['image'].size != (224, 224)
]
print(f"Amount of images not 224x224:{len(size_unmatched)}")
print(size_unmatched[:20])

In [None]:
# Eyeball up to 10 of the off-size images to understand how they differ from 224x224
for unmatched_idx in size_unmatched[:10]:
  unmatched_image = ds['train'][unmatched_idx]['image']
  display(unmatched_image)

In [None]:
# Resizing Images using Hugging Face Map and PILLOW Resampling
def resize_image(examples):
  """Return `examples` with its "image" resized to 224x224 (bilinear) unless it already is.

  Used with `Dataset.map` without `batched=True`, so `examples` is a single row dict;
  the dict is updated in place and returned.
  """
  original = examples["image"]
  if original.size == (224, 224):
    examples["image"] = original
  else:
    examples["image"] = original.resize((224, 224), Image.Resampling.BILINEAR)
  return examples

ds = ds.map(resize_image)

# Re-run the size check: after the map every training image should be 224x224
size_unmatched = [
  idx for idx in range(len(ds['train']))
  if ds['train'][idx]['image'].size != (224, 224)
]
print(f"Amount of images not 224x224:{len(size_unmatched)}")
print(size_unmatched[:20])

In [None]:
!pip install faiss-cpu datasets transformers torch torchvision tqdm

import numpy as np
import torch
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm
from collections import defaultdict
import faiss
import os

In [None]:
# Near-duplicate detection on the training split with CLIP image embeddings + FAISS.
# Uses its own names (train_ds, clip_model, clip_processor) so the resized DatasetDict `ds`
# from the cells above, and the classifier `model` defined later, are not overwritten.
train_ds = load_dataset("yashikota/birds-525-species-image-classification", split="train")
print(f"Dataset loaded: {len(train_ds)} images")
# Running on GPU is much faster than CPU; embeddings are moved back to the CPU for storage
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)  # pretrained OpenAI CLIP
# CLIP turns each image into a vector (a long list of numbers) describing its content
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model.eval()


def get_embeddings_batched(images, batch_size=32):
    """Embed images with CLIP in batches and return L2-normalised float32 vectors.

    Args:
        images: a sliceable sequence of images (e.g. a list of PIL images), or a Hugging
            Face `Dataset` with an "image" column. Slicing a `Dataset` yields a dict of
            columns, so only `batch_size` images are decoded at a time instead of the
            whole split up front.
        batch_size: number of images per forward pass.

    Returns:
        np.ndarray with one float32 row per image; each row has unit L2 norm.
    """
    all_embeds = []
    for start in tqdm(range(0, len(images), batch_size), desc="Embedding batches"):
        batch = images[start:start + batch_size]
        if isinstance(batch, dict):  # Dataset slice -> {"image": [...], "label": [...]}
            batch = batch["image"]
        inputs = clip_processor(images=batch, return_tensors="pt").to(device)
        with torch.no_grad():
            embeds = clip_model.get_image_features(**inputs)
            # unit length, so inner product == cosine similarity in the FAISS index below
            embeds = embeds / embeds.norm(dim=-1, keepdim=True)
        all_embeds.append(embeds.cpu())
    return torch.cat(all_embeds, dim=0).numpy().astype("float32")


def _similarity_graph(scores, neighbors, threshold):
    """Adjacency sets linking each row to its kNN hits whose similarity exceeds `threshold`."""
    graph = defaultdict(set)
    for row, col in np.argwhere(scores > threshold):
        neighbor = int(neighbors[row, col])
        # Skip the self-match (not guaranteed to be column 0 when exact duplicates tie)
        # and FAISS's -1 padding for missing neighbours.
        if neighbor == row or neighbor < 0:
            continue
        graph[int(row)].add(neighbor)
        graph[neighbor].add(int(row))
    return graph


def _duplicate_groups(graph, num_nodes):
    """Connected components of size > 1 in `graph`, found by iterative depth-first search.

    Nodes are started in ascending order, so each group's first element is its lowest index.
    """
    visited = set()
    groups = []
    for node in range(num_nodes):
        if node in visited or node not in graph:
            continue
        stack = [node]
        group = []
        while stack:
            current = stack.pop()
            if current not in visited:
                visited.add(current)
                group.append(current)
                stack.extend(graph[current])  # visit all connected neighbours
        if len(group) > 1:
            groups.append(group)
    return groups


embeddings = get_embeddings_batched(train_ds, batch_size=64)
print(f"Embeddings shape: {embeddings.shape}")

# Exact inner-product index over all embeddings
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(embeddings)
print(f"FAISS index built with {index.ntotal} embeddings.")

# k nearest neighbours of every image (the image itself is normally one of them)
k = 5
D, I = index.search(embeddings, k)  # D = similarity scores, I = neighbour indices
threshold = 0.92  # neighbours with cosine similarity above this count as near-duplicates

components = _duplicate_groups(_similarity_graph(D, I, threshold), len(embeddings))
print(f"Found {len(components)} duplicate groups")

# Keep the first (lowest-index) image of each duplicate group, remove the others
remove = {i for comp in components for i in comp[1:]}
keep = [i for i in range(len(train_ds)) if i not in remove]

# Two new datasets: one clean and one with the removed duplicates
# NOTE(review): the training cells further down build their datasets from ds['train'],
# not from cleaned_ds — confirm whether deduplication is meant to feed training.
cleaned_ds = train_ds.select(keep)
duplicates_ds = train_ds.select(sorted(remove))  # sorted -> deterministic order on disk

# Save both datasets to disk so later cells can reload them
os.makedirs("cleaned_data", exist_ok=True)
cleaned_path = "cleaned_data/cleaned_birds"
dup_path = "cleaned_data/duplicates_birds"
cleaned_ds.save_to_disk(cleaned_path)
duplicates_ds.save_to_disk(dup_path)

print("\nCleaning done")
print(f"Original size: {len(train_ds)}")
print(f"Cleaned dataset size: {len(cleaned_ds)}")
print(f"Duplicates removed: {len(duplicates_ds)}")
print(f"Percentage removed: {100*len(duplicates_ds)/len(train_ds):.2f}%")

In [None]:
from datasets import load_from_disk
from PIL import Image
import matplotlib.pyplot as plt

# Load the saved near-duplicate images and show up to the first 20 in a 4x5 grid
duplicates_ds = load_from_disk("cleaned_data/duplicates_birds")

num_images = min(20, len(duplicates_ds))
plt.figure(figsize=(15, 12))

for position, example in enumerate(duplicates_ds.select(range(num_images)), start=1):
    plt.subplot(4, 5, position)
    plt.imshow(example["image"])
    plt.axis('off')
    plt.title(f"Label: {example.get('label', 'N/A')}")

plt.tight_layout()
plt.show()

In [None]:
from datasets import load_from_disk
from PIL import Image
import matplotlib.pyplot as plt

# Load the cleaned (deduplicated) training set and show up to its first 20 images.
# It gets its own name so `duplicates_ds` from the previous cell keeps holding the duplicates.
cleaned_ds = load_from_disk("cleaned_data/cleaned_birds")

num_images = min(20, len(cleaned_ds))
plt.figure(figsize=(15, 12))

for i in range(num_images):
    img = cleaned_ds[i]["image"]
    plt.subplot(4, 5, i+1)
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"Label: {cleaned_ds[i].get('label', 'N/A')}")

plt.tight_layout()
plt.show()

In [None]:
# Class imbalance: count the training images of each species.
# Only the training split is counted, since these counts feed the training-loss weights.
from collections import Counter  # counts how many images carry each label

values = ds["train"]["label"]
counts = Counter(values)  # label id -> number of training images
labels = ds["train"].features["label"].names  # species names, indexed by label id
# Species name -> image count for EVERY class, in label-id order. Iterating the Counter
# instead (first-seen order) would silently drop classes that have zero images.
countsSpecies = {name: counts.get(i, 0) for i, name in enumerate(labels)}

In [None]:
# Class weights for weighted cross-entropy, suited to a medium level of imbalance like this one:
# classes with fewer samples get larger weights. Undersampling is avoided because it
# permanently removes data from the dataset.

# for PyTorch
import torch

# Per-class training image counts in label-id order. A class with no images would give a
# count of 0 and an infinite weight (n / 0), so counts are floored at 1.
classCounts = torch.tensor([counts[i] for i in range(len(labels))], dtype=torch.float).clamp(min=1)

n = classCounts.sum()
c = len(classCounts)
classWeights = n / (c * classCounts)  # n_samples / (n_classes * samples_in_class)

In [None]:
# Weighted cross-entropy: mistakes on species with fewer pictures cost more than mistakes
# on well-represented species. The weights shape the training loss; accuracy itself is
# still measured the usual way. Define this before creating and training the model.
from torch import nn

criterion = nn.CrossEntropyLoss(weight=classWeights)

In [None]:
# EfficientNet model training starts here
!pip install timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import timm
from tqdm import tqdm
import numpy as np
from torchvision import transforms

In [None]:
# fix class weights (class 381 is empty)

from collections import Counter

labels = ds['train'].features['label'].names  # species names, indexed by label id
label_values = ds['train']['label']
counts = Counter(label_values)

# counts.get(i, 0): classes with no training images are absent from the Counter
classCounts = torch.tensor([counts.get(i,0) for i in range(len(labels))], dtype=torch.float)
print(f'original classcounts: {classCounts[381]}')

# floor empty classes at 1 so their weight is finite instead of n / 0 = inf
classCounts[classCounts == 0] = 1
print(f'fixed classcounts: {classCounts[381]}')

n = classCounts.sum()
c = len(labels)
classWeights = n / (c * classCounts)  # n_samples / (n_classes * samples_in_class)

print(f'Any infinite weights: {torch.isinf(classWeights).any()}')

In [None]:
# filter out class 381 completely (has 0 training images, but has images in validation set)

EXCLUDED_LABEL = 381  # PARAKEET AUKLET
SPLIT_NAMES = ('train', 'validation', 'test')

# Record sizes before filtering so the "removed" counts below are computed, not hard-coded
# (and stay correct if this cell is re-run).
sizes_before = {split: len(ds[split]) for split in SPLIT_NAMES}

print("Original dataset sizes:")
print(f"  Train: {sizes_before['train']}")
print(f"  Validation: {sizes_before['validation']}")
print(f"  Test: {sizes_before['test']}")

# Filter out the excluded class from all splits
for split in SPLIT_NAMES:
    ds[split] = ds[split].filter(lambda x: x['label'] != EXCLUDED_LABEL)

print(f"\nAfter removing PARAKEET AUKLET (class {EXCLUDED_LABEL}):")
print(f"  Train: {len(ds['train'])} (-{sizes_before['train'] - len(ds['train'])} images)")
print(f"  Validation: {len(ds['validation'])} (-{sizes_before['validation'] - len(ds['validation'])} images)")
print(f"  Test: {len(ds['test'])} (-{sizes_before['test'] - len(ds['test'])} images)")
print(f"\n✓ Class {EXCLUDED_LABEL} removed from all splits")

In [None]:
# convert dataset to pytorch format:
# PIL image -> float tensor, then normalise each channel with the ImageNet mean/std
# (the statistics the ImageNet-pretrained backbone expects)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# custom class
class BirdDataset(torch.utils.data.Dataset):
    """PyTorch map-style wrapper around a Hugging Face split with "image"/"label" columns.

    Items are `(image, label)`. Images whose PIL `mode` is not "RGB" (e.g. greyscale "L"
    or "RGBA") are converted first, because the 3-channel `Normalize` in `transform`
    fails on tensors with a different channel count.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset  # indexable; item i is a dict with 'image' and 'label'
        self.transform = transform  # optional callable applied to the (RGB) image

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        label = item['label']
        # Objects without a `mode` attribute (non-PIL inputs) are passed through untouched.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Wrap each split with the same tensor transform
train_dataset, val_dataset, test_dataset = (
    BirdDataset(ds[split_name], transform=transform)
    for split_name in ('train', 'validation', 'test')
)

for caption, split_dataset in (("Train", train_dataset), ("Validation", val_dataset), ("Test", test_dataset)):
    print(f"{caption} size: {len(split_dataset)}")

In [None]:
# create dataloaders (help us batch process and randomize order)

BATCH_SIZE = 32
# Shared settings; only the training loader shuffles, evaluation order stays fixed
loader_settings = dict(batch_size=BATCH_SIZE, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_settings)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_settings)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_settings)
print(f"Number of training batches: {len(train_loader)}")

In [None]:
# load model

# One output unit per ClassLabel entry, so the logits line up with classWeights
# (built from the same label list) inside the weighted CrossEntropyLoss.
num_classes = ds['train'].features['label'].num_classes
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=num_classes)
# Fall back to CPU instead of crashing on machines without CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(f"Using device: {device}")
print(f"Model loaded with {sum(p.numel() for p in model.parameters())} parameters")

In [None]:
# set up training weights

# Move the class weights to the training device before building the weighted loss
classWeights = classWeights.to(device)
criterion = nn.CrossEntropyLoss(weight=classWeights)

# Adam at lr=1e-3. ReduceLROnPlateau in 'min' mode halves the LR (factor=0.5) once the
# monitored loss has stopped improving for more than `patience`=2 epochs.
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

In [None]:
# training definitions

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over `train_loader`.

    Args:
        model: network to train (switched to train mode here).
        train_loader: iterable of (images, labels) batches.
        criterion: loss called as criterion(outputs, labels).
        optimizer: optimiser over the model's parameters.
        device: device each batch is moved to.

    Returns:
        (epoch_loss, epoch_acc): loss averaged per sample and accuracy in percent;
        (0.0, 0.0) if the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader, desc="training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Assuming a mean-reduced criterion (nn.CrossEntropyLoss default), weight each
        # batch by its size so a short final batch does not count like a full one.
        running_loss += loss.item() * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        return 0.0, 0.0
    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

def validate(model, val_loader, criterion, device):
    """Evaluate `model` on `val_loader` without updating any weights.

    Args:
        model: network to evaluate (switched to eval mode here).
        val_loader: iterable of (images, labels) batches.
        criterion: loss called as criterion(outputs, labels).
        device: device each batch is moved to.

    Returns:
        (epoch_loss, epoch_acc): loss averaged per sample and accuracy in percent;
        (0.0, 0.0) if the loader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="validation"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            # weight the (mean-reduced) batch loss by batch size -> per-sample average
            running_loss += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        return 0.0, 0.0
    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc



In [None]:
# run training

NUM_EPOCHS = 10
CHECKPOINT_PATH = 'best_efficientnet_b0_birds.pth'
best_val_acc = 0.0
saved_checkpoint = False
print('Starting training: \n')
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")

    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Train loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")

    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f"Val loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    scheduler.step(val_loss)  # the plateau scheduler monitors validation loss

    # checkpoint whenever validation accuracy improves
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        saved_checkpoint = True
        print(f"Saved new best model ({best_val_acc:.2f}% val accuracy)")

print(f"Training complete, best val accuracy: {best_val_acc:.2f}%")

# Restore the best checkpoint so the analysis cells below use the best model,
# not whatever the final epoch produced.
if saved_checkpoint:
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

In [None]:
# analyze results

from collections import Counter
from tqdm import tqdm

def analyze_confusion(dataset, top_n=15, batch_size=64):
    """Find which species pairs the model confuses most often on `dataset`.

    Runs the global `model` (on the global `device`) over `dataset` in order, in
    batches, and counts every (true_label, predicted_label) pair where they differ.

    Args:
        dataset: map-style dataset yielding (image_tensor, int_label), e.g. val_dataset.
        top_n: how many of the most frequent confusions to print and return.
        batch_size: images per forward pass (inference batching only).

    Returns:
        list of ((true_idx, pred_idx), count) tuples, most frequent first, as given by
        Counter.most_common.
    """
    print("Analyzing errors on validation set...")
    model.eval()

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    confusion_pairs = []

    with torch.no_grad():
        for images, true_batch in tqdm(loader):
            predicted = model(images.to(device)).argmax(dim=1).cpu()
            # only track errors
            confusion_pairs.extend(
                (true_label, pred_label)
                for true_label, pred_label in zip(true_batch.tolist(), predicted.tolist())
                if true_label != pred_label
            )

    # count most common confusions
    confusion_counts = Counter(confusion_pairs)
    top_confusions = confusion_counts.most_common(top_n)

    print(f"\n{'='*80}")
    print(f"TOP {top_n} MOST CONFUSED SPECIES PAIRS")
    print(f"{'='*80}\n")

    # `labels` here is the global list of species names, indexed by label id
    for i, ((true_idx, pred_idx), count) in enumerate(top_confusions, 1):
        true_name = labels[true_idx]
        pred_name = labels[pred_idx]
        print(f"{i}. {count} times: '{true_name}' → predicted as → '{pred_name}'")

    return top_confusions

top_errors = analyze_confusion(val_dataset, top_n=15)

In [None]:
# visually show species that were confused

def _predict_all(dataset, batch_size=64):
    """Run the global `model` over `dataset` in order; return (true_labels, predicted_labels) int lists."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    true_labels, predicted_labels = [], []
    model.eval()
    with torch.no_grad():
        for images, targets in loader:
            predicted_labels.extend(model(images.to(device)).argmax(dim=1).tolist())
            true_labels.extend(targets.tolist())
    return true_labels, predicted_labels


def _first_index(true_labels, predicted_labels, true_idx, pred_idx):
    """Index of the first item whose (true, predicted) labels equal (true_idx, pred_idx), else None."""
    target = (true_idx, pred_idx)
    return next(
        (i for i, pair in enumerate(zip(true_labels, predicted_labels)) if pair == target),
        None,
    )


def _style_panel(ax, title, color):
    """Hide ticks, add a bold colored title, and draw a thick colored frame around `ax`.

    Ticks are removed rather than calling ax.axis('off'), because axis('off') also hides
    the spines and the colored frame would never be drawn.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title, fontsize=12, color=color, weight='bold', pad=10)
    for spine in ax.spines.values():
        spine.set_color(color)
        spine.set_linewidth(5)


def show_top_confusions_comparison(dataset, top_errors, num_pairs=3):
    """Plot the most frequent confusions side by side, then save and show the figure.

    For each ((true_idx, pred_idx), count) entry of `top_errors`:
      left  = a correctly classified image of the true species (falls back to the
              misclassified image if the model never gets that species right),
      right = the first image of the true species that the model predicted as pred_idx.

    Args:
        dataset: BirdDataset-like object whose items are (tensor, label) and whose
            `.dataset[i]['image']` is the untransformed image used for display.
        top_errors: list of ((true_idx, pred_idx), count), most frequent first.
        num_pairs: number of rows to draw; capped at len(top_errors).

    Side effects: prints a summary per pair, saves 'top_{num_pairs}_confusions.png',
    and shows the figure. `labels` is the global list of species names.
    """
    num_pairs = min(num_pairs, len(top_errors))
    if num_pairs == 0:
        print("No confusions to show")
        return

    # One batched pass instead of re-running the model image by image for every pair
    true_labels, predicted_labels = _predict_all(dataset)

    # squeeze=False keeps `axes` 2-D even when only one row is drawn
    fig, axes = plt.subplots(num_pairs, 2, figsize=(12, 5*num_pairs), squeeze=False)

    for pair_idx, ((true_idx, pred_idx), count) in enumerate(top_errors[:num_pairs]):
        true_name = labels[true_idx]
        pred_name = labels[pred_idx]

        print(f"\n{'='*60}")
        print(f"CONFUSION #{pair_idx+1}: {count} times")
        print(f"TRUTH: {true_name}")
        print(f"MODEL PREDICTED: {pred_name}")
        print(f"{'='*60}")

        ax_left, ax_right = axes[pair_idx]

        # find one example of this confusion
        confused_idx = _first_index(true_labels, predicted_labels, true_idx, pred_idx)
        if confused_idx is None:
            print("No example found for this pair")
            ax_left.axis('off')
            ax_right.axis('off')
            continue
        confused_image = dataset.dataset[confused_idx]['image']

        # LEFT: what it actually is (a correctly classified example when one exists)
        correct_idx = _first_index(true_labels, predicted_labels, true_idx, true_idx)
        if correct_idx is None:
            reference_image = confused_image
        else:
            reference_image = dataset.dataset[correct_idx]['image']
        ax_left.imshow(reference_image)
        _style_panel(ax_left, f"✓ {true_name}\n(should've guessed)", 'green')

        # RIGHT: the misclassified image, titled with the model's wrong guess
        ax_right.imshow(confused_image)
        _style_panel(ax_right, f"✗ MODEL SAYS: {pred_name}\n(wrong guess)", 'red')

    fig.tight_layout()
    filename = f'top_{num_pairs}_confusions.png'
    fig.savefig(filename, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"\n saved as '{filename}'")

show_top_confusions_comparison(val_dataset, top_errors, num_pairs=3)