In [1]:
!pip install scikit-learn matplotlib

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
[0m

In [3]:
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import time
import shutil
from torchvision.models import ResNet18_Weights

DATA_DIR = '/root/Aerial_Landscapes'
BATCH_SIZE = 32
NUM_CLASSES = 15
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
MODEL_SAVE_PATH = './model_output/resnet18.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


illegal_dir = os.path.join(DATA_DIR, '.ipynb_checkpoints')
if os.path.exists(illegal_dir):
    shutil.rmtree(illegal_dir)
else:
    _ = print("")


transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

from torchvision.datasets import ImageFolder

def is_valid_image(filename):
    return filename.lower().endswith(('.jpg', '.jpeg', '.png'))

full_dataset = ImageFolder(root=DATA_DIR, transform=transform_train, is_valid_file=is_valid_image)
class_names = full_dataset.classes

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
val_dataset.dataset.transform = transform_test  # Use test transform in validation

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

def train():
    model.train()
    os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
    
    acc_list = []
    loss_list = []

    for epoch in range(NUM_EPOCHS):
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        acc = 100. * correct / total
        acc_list.append(acc)
        loss_list.append(running_loss)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {running_loss:.4f}, Accuracy: {acc:.2f}%")

    np.save("acc_resnet18.npy", np.array(acc_list))
    np.save("loss_resnet18.npy", np.array(loss_list))
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"Model saved to {MODEL_SAVE_PATH}")

def evaluate():
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(DEVICE)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    print("Classification Report:")
    print(classification_report(all_labels, all_preds, target_names=class_names))
    print("Confusion Matrix:")
    print(confusion_matrix(all_labels, all_preds))

start = time.time()
train()
evaluate()
print(f"Total time: {(time.time() - start):.2f}s")



Epoch 1/10, Loss: 231.3830, Accuracy: 76.60%
Epoch 2/10, Loss: 136.9895, Accuracy: 86.18%
Epoch 3/10, Loss: 96.6853, Accuracy: 90.31%
Epoch 4/10, Loss: 75.7803, Accuracy: 91.94%
Epoch 5/10, Loss: 68.6252, Accuracy: 92.99%
Epoch 6/10, Loss: 70.3728, Accuracy: 93.14%
Epoch 7/10, Loss: 40.5753, Accuracy: 95.75%
Epoch 8/10, Loss: 38.4033, Accuracy: 96.17%
Epoch 9/10, Loss: 51.2616, Accuracy: 94.53%
Epoch 10/10, Loss: 29.4431, Accuracy: 96.95%
Model saved to ./model_output/resnet18.pth
Classification Report:
              precision    recall  f1-score   support

 Agriculture       0.97      0.95      0.96       156
     Airport       0.94      0.66      0.77       157
       Beach       0.83      0.99      0.90       156
        City       0.89      0.88      0.89       155
      Desert       0.92      0.88      0.90       162
      Forest       0.95      0.94      0.95       178
   Grassland       0.92      0.93      0.92       150
     Highway       0.88      0.89      0.89       147
   