In [1]:
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.models import efficientnet_b0
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

DATA_DIR = '/root/Aerial_Landscapes'
BATCH_SIZE = 32
NUM_CLASSES = 15
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
MODEL_SAVE_PATH = './efficientnet_b0.pth'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = ImageFolder(DATA_DIR, transform=transform_train)
class_names = dataset.classes

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
val_dataset.dataset.transform = transform_val

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = efficientnet_b0(weights='IMAGENET1K_V1')
model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

def train():
    model.train()
    for epoch in range(NUM_EPOCHS):
        running_loss, correct, total = 0.0, 0, 0
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        acc = 100. * correct / total
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {running_loss:.4f}, Accuracy: {acc:.2f}%")

    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"Model saved to {MODEL_SAVE_PATH}")

def evaluate():
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(DEVICE)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    print("Classification Report:")
    print(classification_report(all_labels, all_preds, target_names=class_names))
    print("Confusion Matrix:")
    print(confusion_matrix(all_labels, all_preds))

if __name__ == '__main__':
    train()
    evaluate()


Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [01:21<00:00, 264kB/s]


Epoch 1/10, Loss: 143.1222, Accuracy: 86.03%
Epoch 2/10, Loss: 57.1420, Accuracy: 94.47%
Epoch 3/10, Loss: 44.3237, Accuracy: 95.73%
Epoch 4/10, Loss: 37.5180, Accuracy: 95.89%
Epoch 5/10, Loss: 34.5403, Accuracy: 96.62%
Epoch 6/10, Loss: 24.6164, Accuracy: 97.67%
Epoch 7/10, Loss: 31.2322, Accuracy: 96.96%
Epoch 8/10, Loss: 25.8587, Accuracy: 97.21%
Epoch 9/10, Loss: 17.4619, Accuracy: 98.31%
Epoch 10/10, Loss: 21.9135, Accuracy: 97.76%
Model saved to ./efficientnet_b0.pth
Classification Report:
              precision    recall  f1-score   support

 Agriculture       0.97      0.97      0.97       143
     Airport       1.00      0.90      0.95       149
       Beach       0.99      1.00      1.00       150
        City       0.97      0.96      0.96       152
      Desert       0.99      0.98      0.99       186
      Forest       0.99      0.96      0.98       168
   Grassland       0.94      0.99      0.96       162
     Highway       0.91      1.00      0.96       160
        Lak