In [19]:
!pip install transformers datasets accelerate



In [20]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_from_disk
from torch.utils.data import DataLoader
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForImageClassification,
)

In [21]:
# --- Configuration: every tunable value of the notebook lives in this cell ---
DATA_PATH = "processed_bird_data"            # on-disk dataset with "train"/"validation" splits
TEST_DATA_PATH = "processed_bird_test_data"  # on-disk test dataset
MODEL_NAME = "google/mobilenet_v2_1.0_224"   # Hugging Face checkpoint to fine-tune
BATCH_SIZE = 32
EPOCHS = 10
SEED = 42

# Seed torch's global RNG so that the freshly initialised classifier head and the
# shuffled training DataLoader are repeatable from one run to the next.
torch.manual_seed(SEED)

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")

Using device: cpu


In [22]:
# Read the preprocessed dataset from DATA_PATH; when the directory is missing,
# report it and let the original exception propagate so the run stops here.
print("Loading data...")
try:
    dataset = load_from_disk(DATA_PATH)
except FileNotFoundError:
    not_found_msg = f"Error: {DATA_PATH} not found."
    print(not_found_msg)
    raise

Loading data...


In [23]:
# Preprocessor that turns raw images into the model-ready `pixel_values` tensors.
# AutoImageProcessor is the current API for vision checkpoints; AutoFeatureExtractor
# is the deprecated spelling. The variable keeps its original name because
# `transform` looks it up as `feature_extractor`.
feature_extractor = AutoImageProcessor.from_pretrained(MODEL_NAME)

In [24]:
def transform(batch):
    """Preprocess a batch of examples for the model.

    Runs every entry of ``batch["image"]`` through the module-level
    ``feature_extractor`` (requesting PyTorch tensors) and carries
    ``batch["label"]`` over unchanged under the key ``"label"``.
    """
    images = list(batch["image"])
    encoded = feature_extractor(images, return_tensors="pt")
    encoded["label"] = batch["label"]
    return encoded

# Register `transform` on the dataset. Per the `datasets` documentation,
# with_transform applies the function lazily when rows are accessed, so nothing
# is preprocessed up front here.
dataset = dataset.with_transform(transform)


def collate_fn(examples):
    """Merge a list of per-example dicts into one batch dict.

    Each example must provide a ``"pixel_values"`` tensor and a ``"label"``.
    Returns ``{"pixel_values": <stacked tensor>, "labels": <tensor of labels>}``;
    note the plural ``"labels"`` key on the way out.
    """
    stacked_images = torch.stack(tuple(item["pixel_values"] for item in examples))
    targets = torch.tensor([item["label"] for item in examples])
    return dict(pixel_values=stacked_images, labels=targets)

# Build both loaders from the same recipe; only the training split is shuffled.
train_loader, val_loader = (
    DataLoader(
        dataset[split],
        batch_size=BATCH_SIZE,
        shuffle=(split == "train"),
        collate_fn=collate_fn,
    )
    for split in ("train", "validation")
)

# Quick sanity check on the number of batches per split.
print(f"Train batches: {len(train_loader)}")
print(f"Val batches:   {len(val_loader)}")

Train batches: 105
Val batches:   19


In [25]:
print("Initializing Baseline Model (MobileNetV2)...")

# Load the pretrained backbone with a 200-way head. ignore_mismatched_sizes lets
# the checkpoint's differently-shaped classifier weights be dropped and
# re-initialised instead of raising. The model is moved to DEVICE right away.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME, num_labels=200, ignore_mismatched_sizes=True
).to(DEVICE)

# Loss and optimizer used by the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=2e-4)

print("Model ready.")

Initializing Baseline Model (MobileNetV2)...


Some weights of MobileNetV2ForImageClassification were not initialized from the model checkpoint at google/mobilenet_v2_1.0_224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1001]) in the checkpoint and torch.Size([200]) in the model instantiated
- classifier.weight: found shape torch.Size([1001, 1280]) in the checkpoint and torch.Size([200, 1280]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model ready.


In [11]:
# Fine-tune the model for EPOCHS epochs, validating after each one and keeping the
# weights of the epoch with the best validation accuracy in `save_path`.
best_val_acc = 0.0
save_path = "baseline_best_model.pth"

print("Starting training loop...\n")

for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    train_loss_sum = 0.0  # sum of per-sample losses over the epoch
    correct_train = 0
    total_train = 0

    for i, batch in enumerate(train_loader):
        pixel_values = batch["pixel_values"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        batch_size = labels.size(0)

        optimizer.zero_grad()

        # forward pass
        logits = model(pixel_values=pixel_values).logits

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # stats: the criterion returns the batch *mean*, so weight it by the batch
        # size; otherwise a smaller final batch is over-represented in the epoch loss.
        batch_loss = loss.item()
        train_loss_sum += batch_loss * batch_size
        total_train += batch_size
        correct_train += (logits.argmax(dim=1) == labels).sum().item()

        if i % 20 == 0:
            print(f"[Epoch {epoch+1}] Batch {i}/{len(train_loader)} | loss={batch_loss:.4f}")

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    train_epoch_loss = train_loss_sum / max(total_train, 1)
    train_epoch_acc = correct_train / max(total_train, 1)

    # ---- validation pass (no gradients, eval mode) ----
    model.eval()
    val_loss_sum = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for batch in val_loader:
            pixel_values = batch["pixel_values"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            batch_size = labels.size(0)

            logits = model(pixel_values=pixel_values).logits
            loss = criterion(logits, labels)

            val_loss_sum += loss.item() * batch_size
            total_val += batch_size
            correct_val += (logits.argmax(dim=1) == labels).sum().item()

    val_epoch_loss = val_loss_sum / max(total_val, 1)
    val_epoch_acc = correct_val / max(total_val, 1)

    # per-epoch report
    print(f"Train: loss={train_epoch_loss:.4f}, acc={train_epoch_acc:.4f}")
    print(f"Val:   loss={val_epoch_loss:.4f}, acc={val_epoch_acc:.4f}")

    # Checkpoint only when validation accuracy strictly improves.
    if val_epoch_acc > best_val_acc:
        best_val_acc = val_epoch_acc
        torch.save(model.state_dict(), save_path)
        print("Best model saved!")

    print("-" * 30)

print(f"\nTraining complete. Best Validation Accuracy: {best_val_acc:.4f}")

Starting training loop...

[Epoch 1] Batch 0/105 | loss=5.2987
[Epoch 1] Batch 20/105 | loss=5.0515
[Epoch 1] Batch 40/105 | loss=4.9252
[Epoch 1] Batch 60/105 | loss=4.1871
[Epoch 1] Batch 80/105 | loss=3.5144
[Epoch 1] Batch 100/105 | loss=3.7427
Train: loss=4.4456, acc=0.1417
Val:   loss=3.6121, acc=0.2683
Best model saved!
------------------------------
[Epoch 2] Batch 0/105 | loss=3.0475
[Epoch 2] Batch 20/105 | loss=2.8099
[Epoch 2] Batch 40/105 | loss=2.1361
[Epoch 2] Batch 60/105 | loss=2.4891
[Epoch 2] Batch 80/105 | loss=2.3831
[Epoch 2] Batch 100/105 | loss=1.7832
Train: loss=2.2788, acc=0.5391
Val:   loss=2.4471, acc=0.4177
Best model saved!
------------------------------
[Epoch 3] Batch 0/105 | loss=1.0102
[Epoch 3] Batch 20/105 | loss=1.5037
[Epoch 3] Batch 40/105 | loss=1.3224
[Epoch 3] Batch 60/105 | loss=1.2309
[Epoch 3] Batch 80/105 | loss=1.0523
[Epoch 3] Batch 100/105 | loss=1.1437
Train: loss=1.2398, acc=0.7675
Val:   loss=2.1713, acc=0.4533
Best model saved!
-----

In [14]:
# Evaluate the best checkpoint on the held-out test data.
# NOTE(review): if the test split only carries placeholder labels (a "blind" test
# set), the loss/accuracy printed below are not meaningful — confirm before
# reporting them.
print("\n Now, running on the test set")

test_loader = None
if os.path.exists(TEST_DATA_PATH):
    print(f"Loading test data from {TEST_DATA_PATH}...")
    try:
        test_dataset_raw = load_from_disk(TEST_DATA_PATH)
        # The saved object may be a dict of splits or a single dataset.
        if isinstance(test_dataset_raw, dict) and "test" in test_dataset_raw:
            test_ds = test_dataset_raw["test"]
        else:
            test_ds = test_dataset_raw

        test_ds = test_ds.with_transform(transform)
        test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    except Exception as e:
        # Best effort: report the problem and skip the evaluation below.
        print(f"Error loading test data: {e}")
        test_loader = None
else:
    print("Test data folder not found.")

# Compare against None explicitly instead of relying on DataLoader truthiness.
if test_loader is not None:
    print(f"Loading best weights from {save_path}...")
    # weights_only=True restricts unpickling to tensors/containers; a state_dict
    # needs nothing more.
    model.load_state_dict(torch.load(save_path, map_location=DEVICE, weights_only=True))
    model.to(DEVICE)
    model.eval()

    test_loss_sum = 0.0  # sum of per-sample losses
    correct_test = 0
    total_test = 0

    print("Testing...")
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            pixel_values = batch["pixel_values"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            batch_size = labels.size(0)

            logits = model(pixel_values=pixel_values).logits
            loss = criterion(logits, labels)

            # The criterion returns a batch mean; weight it by the batch size so a
            # smaller last batch does not skew the average.
            test_loss_sum += loss.item() * batch_size
            total_test += batch_size
            correct_test += (logits.argmax(dim=1) == labels).sum().item()

            if i % 20 == 0:
                print(f"[Test] Batch {i}/{len(test_loader)}")

    if total_test == 0:
        print("Test set is empty; nothing to evaluate.")
    else:
        final_loss = test_loss_sum / total_test
        final_acc = correct_test / total_test

        print("\n" + "="*30)
        print("FINAL TEST RESULTS:")
        print(f"Test Loss:     {final_loss:.4f}")
        print(f"Test Accuracy: {final_acc:.4f} ({final_acc*100:.2f}%)")
        print("="*30)


 Now, running on the test set
Loading test data from processed_bird_test_data...
Loading best weights from baseline_best_model.pth...
Testing...
[Test] Batch 0/125
[Test] Batch 20/125
[Test] Batch 40/125
[Test] Batch 60/125
[Test] Batch 80/125
[Test] Batch 100/125
[Test] Batch 120/125

FINAL TEST RESULTS:
Test Loss:     8.0590
Test Accuracy: 0.0030 (0.30%)


In [15]:
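# Open question: is the test set "blind" (i.e. shipped without real ground-truth labels)?
# If so, the test loss/accuracy reported above are not meaningful.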

In [None]:
# Run the cell below only after the first part of the notebook (above) has finished running.
# Note: this code is not optimized.

In [28]:
# Build the submission CSV: predict a class for every test image with the best
# checkpoint and pair each prediction with the example's id.
import pandas as pd  # only needed for writing the CSV

save_path = "baseline_best_model.pth"
TEST_DATA_PATH = "processed_bird_test_data"

if os.path.exists(save_path):
    print(f"Loading best weights from {save_path}...")
    # weights_only=True restricts unpickling to tensors/containers.
    model.load_state_dict(torch.load(save_path, map_location=DEVICE, weights_only=True))
else:
    print(f"WARNING: File '{save_path}' not found! Using current model weights.")

model.eval()
model.to(DEVICE)

# Load the test data here instead of reusing a `test_loader` left over from another
# cell: that variable may be None or undefined, and building the loader from the same
# dataset object the ids are read from guarantees both are in the same order.
clean_test_dataset = load_from_disk(TEST_DATA_PATH)

if isinstance(clean_test_dataset, dict) and "test" in clean_test_dataset:
    clean_ds = clean_test_dataset["test"]
else:
    clean_ds = clean_test_dataset

# with_transform returns a new dataset; `clean_ds` itself stays untransformed so the
# raw "id" column can still be read from it below.
submission_loader = DataLoader(
    clean_ds.with_transform(transform),
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep dataset order so predictions line up with ids
    collate_fn=collate_fn,
)

all_preds = []

print(f"Predicting classes for test images...")
with torch.no_grad():
    for i, batch in enumerate(submission_loader):
        pixel_values = batch["pixel_values"].to(DEVICE)

        logits = model(pixel_values=pixel_values).logits
        # .tolist() yields plain Python ints rather than numpy scalars.
        all_preds.extend(logits.argmax(dim=1).cpu().tolist())

        if i % 20 == 0:
            print(f"[Submission] Processing batch {i}/{len(submission_loader)}")

print("Retrieving IDs from disk...")

if "id" in clean_ds.column_names:
    submission_ids = list(clean_ds["id"])
else:
    print("Warning: 'id' column missing in raw data. Generating index.")
    submission_ids = list(range(len(all_preds)))

if len(submission_ids) != len(all_preds):
    print(f"ERROR: ID count ({len(submission_ids)}) != Prediction count ({len(all_preds)})")
else:
    submission_df = pd.DataFrame({
        "id": submission_ids,
        "label": all_preds
    })

    csv_filename = "baseline_submission.csv"
    submission_df.to_csv(csv_filename, index=False)
    print(f"Saved {len(submission_df)} predictions to {csv_filename}")

Loading best weights from baseline_best_model.pth...
Predicting classes for test images...
[Submission] Processing batch 0/125
[Submission] Processing batch 20/125
[Submission] Processing batch 40/125
[Submission] Processing batch 60/125
[Submission] Processing batch 80/125
[Submission] Processing batch 100/125
[Submission] Processing batch 120/125
Retrieving IDs from disk...
