In [2]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.models import efficientnet_b0
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance, ImageOps

DATA_DIR = '/root/Aerial_Landscapes'
BATCH_SIZE = 32
NUM_CLASSES = 15
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def augmix(image, severity=1, width=3, depth=-1, alpha=1.):
    ws = np.float32(np.random.dirichlet([alpha] * width))
    m = np.float32(np.random.beta(alpha, alpha))
    mix = torch.zeros_like(transforms.ToTensor()(image))
    for i in range(width):
        image_aug = image.copy()
        d = depth if depth > 0 else np.random.randint(1, 4)
        for _ in range(d):
            op = random.choice([
                lambda x: ImageOps.autocontrast(x),
                lambda x: ImageOps.equalize(x),
                lambda x: ImageEnhance.Brightness(x).enhance(1 + 0.1 * severity * (random.random() - 0.5)),
                lambda x: ImageEnhance.Contrast(x).enhance(1 + 0.1 * severity * (random.random() - 0.5)),
                lambda x: x.rotate(10 * (random.random() - 0.5)),
                lambda x: x.transpose(Image.FLIP_LEFT_RIGHT)
            ])
            image_aug = op(image_aug)
        mix += ws[i] * transforms.ToTensor()(image_aug)

    mixed = (1 - m) * transforms.ToTensor()(image) + m * mix
    return mixed

class AugMixTransform:
    def __call__(self, img):
        img = img.resize((224, 224))
        return augmix(img)

def get_transform(augment_type):
    if augment_type == 'randomcrop':
        return transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    elif augment_type == 'colorjitter':
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ColorJitter(0.3, 0.3, 0.3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    elif augment_type == 'augmix':
        return AugMixTransform()
    else:
        raise ValueError("Unknown augment_type")

def train_efficientnet(augment_type):
    print(f"Training EfficientNet with: {augment_type}")
    transform_train = get_transform(augment_type)
    transform_val = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    dataset = ImageFolder(DATA_DIR, transform=transform_train)
    class_names = dataset.classes
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    val_dataset.dataset.transform = transform_val
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = efficientnet_b0(weights='IMAGENET1K_V1')
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    epoch_acc_list = []
    epoch_loss_list = []

    for epoch in range(NUM_EPOCHS):
        model.train()
        correct, total, running_loss = 0, 0, 0.0
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        acc = 100. * correct / total
        epoch_acc_list.append(acc)
        epoch_loss_list.append(running_loss)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Loss: {running_loss:.2f} | Accuracy: {acc:.2f}%")

    np.save(f"acc_{augment_type}.npy", np.array(epoch_acc_list))
    np.save(f"loss_{augment_type}.npy", np.array(epoch_loss_list))
    torch.save(model.state_dict(), f'./efficientnet_{augment_type}.pth')

    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(DEVICE)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    print(f"Report for: {augment_type}")
    print(classification_report(all_labels, all_preds, target_names=class_names))
    print(confusion_matrix(all_labels, all_preds))
    
    from sklearn.metrics import accuracy_score, f1_score
    final_acc = accuracy_score(all_labels, all_preds)
    final_f1 = f1_score(all_labels, all_preds, average='macro')
    print(f"Final Accuracy: {final_acc:.4f}, Macro F1: {final_f1:.4f}")
    with open(f"score_{augment_type}.txt", "w") as f:
        f.write(f"Accuracy: {final_acc:.4f}\n")
        f.write(f"Macro F1: {final_f1:.4f}\n")


if __name__ == '__main__':
    train_efficientnet('randomcrop')
    train_efficientnet('colorjitter')
    train_efficientnet('augmix')


Training EfficientNet with: randomcrop
Epoch 1/10 | Loss: 135.61 | Accuracy: 86.80%
Epoch 2/10 | Loss: 60.78 | Accuracy: 93.96%
Epoch 3/10 | Loss: 41.79 | Accuracy: 95.67%
Epoch 4/10 | Loss: 39.95 | Accuracy: 96.04%
Epoch 5/10 | Loss: 27.43 | Accuracy: 97.20%
Epoch 6/10 | Loss: 27.52 | Accuracy: 97.11%
Epoch 7/10 | Loss: 27.30 | Accuracy: 97.14%
Epoch 8/10 | Loss: 19.34 | Accuracy: 97.94%
Epoch 9/10 | Loss: 23.19 | Accuracy: 97.60%
Epoch 10/10 | Loss: 27.91 | Accuracy: 97.11%
Report for: randomcrop
              precision    recall  f1-score   support

 Agriculture       0.97      0.96      0.97       141
     Airport       0.97      0.83      0.90       162
       Beach       1.00      0.95      0.97       162
        City       0.93      0.97      0.95       147
      Desert       0.96      1.00      0.98       154
      Forest       0.99      0.96      0.98       168
   Grassland       0.93      0.98      0.95       158
     Highway       0.95      0.95      0.95       177
        L