In [8]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from torchvision.models import resnet50

In [9]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [10]:
#get dataset location
data_dir = "images"

In [11]:
# Enhanced data augmentation
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [12]:
# Load dataset
dataset = ImageFolder(root=data_dir, transform=train_transforms)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
val_dataset.dataset.transform = val_transforms

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

In [13]:
from torchvision.models import resnet50
import torch.nn as nn

# Load pretrained model
model = resnet50(pretrained=True).to(device)  # Move to CUDA 


model.fc = nn.Linear(model.fc.in_features, 202).to(device)  # Move this new layer to the same device

# Step 3: Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)  
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)  # Reduce LR every 5 epochs



In [14]:
# Step 4: Training Loop
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        scheduler.step()  
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}")

        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        print(f"Validation Accuracy: {accuracy:.2f}%")

In [15]:
# Train the model
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=12)

Epoch 1/12, Loss: 3.3004
Validation Accuracy: 55.13%
Epoch 2/12, Loss: 1.1947
Validation Accuracy: 70.19%
Epoch 3/12, Loss: 0.5371
Validation Accuracy: 74.98%
Epoch 4/12, Loss: 0.2265
Validation Accuracy: 80.41%
Epoch 5/12, Loss: 0.1617
Validation Accuracy: 80.79%
Epoch 6/12, Loss: 0.1258
Validation Accuracy: 81.59%
Epoch 7/12, Loss: 0.1053
Validation Accuracy: 81.68%
Epoch 8/12, Loss: 0.1019
Validation Accuracy: 81.04%
Epoch 9/12, Loss: 0.0990
Validation Accuracy: 81.55%
Epoch 10/12, Loss: 0.0964
Validation Accuracy: 81.55%
Epoch 11/12, Loss: 0.0944
Validation Accuracy: 81.26%
Epoch 12/12, Loss: 0.0947
Validation Accuracy: 81.47%


In [16]:
torch.save(model, "bird_model_v2.pt")