In [None]:
# !pip install torch torchvision tqdm

import os
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from tqdm.notebook import tqdm
from google.colab import drive

# Check for GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Print GPU info
!nvidia-smi

# Mount Google Drive
drive.mount('/content/drive')

# Define a directory in your Google Drive to save checkpoints
checkpoint_dir = '/content/drive/MyDrive/bird_classification_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)


# Shared normalisation step (per-channel mean/std applied after ToTensor).
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training transform: resize, light random augmentation, tensor conversion, normalisation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    _normalize,
])

# Evaluation transform: same resize/normalisation but WITHOUT the random flip and
# rotation. Applying random augmentation to the validation/test images makes the
# reported accuracy vary from run to run and measures the model on distorted inputs.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    _normalize,
])

# Load datasets (one sub-folder per class, as expected by ImageFolder).
train_dataset = ImageFolder(root='/content/Upload/Upload/train', transform=transform)
val_dataset = ImageFolder(root='/content/Upload/Upload/valid', transform=eval_transform)
test_dataset = ImageFolder(root='/content/Upload/Upload/test', transform=eval_transform)

# Create data loaders; only the training loader shuffles.
batch_size = 128  # Increased for A100
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# The number of output classes is taken from the training folder structure.
num_classes = len(train_dataset.classes)
print(f"Number of classes: {num_classes}")

def create_model(num_classes):
    """Build a ResNet-50 with a frozen backbone and a new classification head.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer.

    Returns:
        The ResNet-50 module, created with ``pretrained=True``, whose existing
        parameters have ``requires_grad`` disabled and whose ``fc`` layer is a
        freshly initialised (and therefore trainable) ``nn.Linear``.
    """
    backbone = torchvision.models.resnet50(pretrained=True)

    # Freeze every parameter that exists at this point; the head created
    # below is added afterwards and so stays trainable.
    backbone.requires_grad_(False)

    head_inputs = backbone.fc.in_features
    backbone.fc = nn.Linear(head_inputs, num_classes)
    return backbone

# Build the classifier for the detected number of classes, then move it to the selected device.
model = create_model(num_classes)
model = model.to(device)

def save_checkpoint(model, optimizer, epoch, loss, filename):
    """Write a training checkpoint into ``checkpoint_dir``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch value recorded under the ``'epoch'`` key.
        loss: Loss value recorded under the ``'loss'`` key.
        filename: File name of the checkpoint, relative to ``checkpoint_dir``.
    """
    target_path = os.path.join(checkpoint_dir, filename)
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(state, target_path)
    print(f"Checkpoint saved: {target_path}")

def load_checkpoint(model, optimizer, filename):
    """Restore model and optimizer state from a checkpoint in ``checkpoint_dir``.

    Args:
        model: Module that receives the saved ``'model_state_dict'``.
        optimizer: Optimizer that receives the saved ``'optimizer_state_dict'``.
        filename: File name of the checkpoint, relative to ``checkpoint_dir``.

    Returns:
        A tuple ``(epoch, loss)`` as stored in the checkpoint, or ``(0, 0)``
        when the file does not exist (i.e. start from the beginning).
    """
    filepath = os.path.join(checkpoint_dir, filename)
    if not os.path.isfile(filepath):
        return 0, 0  # If no checkpoint found, start from beginning

    # Deserialize onto the CPU so a checkpoint written on a GPU runtime can
    # still be opened on a CPU-only one. The load_state_dict calls below copy
    # the values into the already-placed model parameters / optimizer state,
    # so the tensors end up on whatever device the live model uses.
    checkpoint = torch.load(filepath, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded: {filepath}")
    return epoch, loss

# Loss, optimizer and AMP gradient scaler. These are created BEFORE the resume
# step because load_checkpoint() restores state into an existing optimizer;
# creating the optimizer afterwards raised NameError whenever a checkpoint existed.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scaler = GradScaler()


def _checkpoint_order(filename):
    """Sort key ranking checkpoint files by training progress.

    Understands the two name patterns written by the loop below,
    ``checkpoint_epoch_<E>_<percent>.pth`` and
    ``checkpoint_epoch_<E>_complete.pth``, and returns
    ``(E, is_complete, percent)`` so comparison is numeric. A plain string
    comparison would rank epoch 2 above epoch 10. Names that do not match
    either pattern sort before every recognised checkpoint.
    """
    parts = filename[:-len('.pth')].split('_')
    if len(parts) != 4 or parts[:2] != ['checkpoint', 'epoch']:
        return (-1, 0, 0)
    try:
        epoch_number = int(parts[2])
        if parts[3] == 'complete':
            # An end-of-epoch checkpoint outranks any partial one of that epoch.
            return (epoch_number, 1, 0)
        return (epoch_number, 0, int(parts[3]))
    except ValueError:
        return (-1, 0, 0)


# Load the latest checkpoint if it exists
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
if checkpoint_files:
    latest_checkpoint = max(checkpoint_files, key=_checkpoint_order)
    start_epoch, start_loss = load_checkpoint(model, optimizer, latest_checkpoint)
else:
    start_epoch, start_loss = 0, 0

num_epochs = 10
total_batches = len(train_loader)
# Save roughly ten intermediate checkpoints per epoch. The max() keeps the
# interval at least 1 so the modulo below cannot divide by zero when an epoch
# has fewer than ten batches.
checkpoint_every = max(1, total_batches // 10)

for epoch in range(start_epoch, num_epochs):
    model.train()
    running_loss = 0.0

    with tqdm(total=total_batches, desc=f'Epoch {epoch+1}/{num_epochs}') as pbar:
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass and loss under autocast (mixed precision).
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            # Scaled backward pass / optimizer step, then update the scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

            pbar.update(1)
            pbar.set_postfix({'loss': running_loss / (i+1)})

            if (i + 1) % checkpoint_every == 0:
                # Integer arithmetic avoids float rounding (e.g. 28.999... -> 28)
                # in the percentage embedded in the file name.
                percent_done = (i + 1) * 100 // total_batches
                # Mid-epoch checkpoints store the current epoch index, so a
                # resume restarts this epoch from its first batch.
                save_checkpoint(
                    model,
                    optimizer,
                    epoch,
                    running_loss / (i+1),
                    f"checkpoint_epoch_{epoch+1}_{percent_done}.pth",
                )

    print(f'Epoch [{epoch + 1}/{num_epochs}] Loss: {running_loss / total_batches:.4f}')
    # End-of-epoch checkpoints store epoch + 1, so a resume continues with the next epoch.
    save_checkpoint(model, optimizer, epoch + 1, running_loss / total_batches, f"checkpoint_epoch_{epoch+1}_complete.pth")

print('Finished Training')


def evaluate(model, data_loader):
    """Compute the top-1 accuracy of ``model`` over ``data_loader``.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Module called on each input batch; its output is reduced with
            an arg-max over dimension 1 to obtain the predicted class.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Accuracy as a percentage (``100 * correct / total``).

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch, targets in tqdm(data_loader, desc="Evaluating"):
            batch = batch.to(device)
            targets = targets.to(device)
            logits = model(batch)
            predictions = logits.max(dim=1).indices
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()

    return 100 * n_correct / n_seen

# Accuracy on the validation split.
val_accuracy = evaluate(model, val_loader)
print(f'Validation Accuracy: {val_accuracy:.2f}%')

# Accuracy on the test split.
test_accuracy = evaluate(model, test_loader)
print(f'Test Accuracy: {test_accuracy:.2f}%')

# Store only the model weights (state_dict) at a fixed path on the mounted Drive.
final_model_path = '/content/drive/MyDrive/bird_classification_final_model.pth'
final_weights = model.state_dict()
torch.save(final_weights, final_model_path)
print("Final model saved to Google Drive")