In [None]:
import torch
import torch.nn as nn
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import ImageFolder
import torchvision.models as models
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from pathlib import Path
import time

# Select the compute backend, in order of preference: Apple MPS, then CUDA,
# then plain CPU. The conditional expression evaluates lazily, so CUDA is only
# probed when MPS is unavailable.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using: {device}")

In [None]:
# --- Filesystem layout -------------------------------------------------------
# All locations are relative to the notebook's working directory.
DATA_PATH = Path("../data/raw/soil-classification/Orignal-Dataset")
OUTPUTS_PATH = Path("../outputs")
CHECKPOINT_PATH = OUTPUTS_PATH.joinpath("checkpoints")
# Create the checkpoint directory (and any missing parents) up front so that
# saving a model later cannot fail on a missing folder.
CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)

# --- Image preprocessing (ImageNet standard) ---------------------------------
IMG_DEFAULT_SIZE = 256  # size passed to Resize before cropping
IMG_CROP_SIZE = 224     # side length of the crop fed to the network
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# --- Training hyperparameters ------------------------------------------------
NUM_EPOCHS = 10
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_CLASSES = 7

In [None]:
# Steps shared by both pipelines: resize first, convert/normalize last.
resize_step = transforms.Resize(IMG_DEFAULT_SIZE)
to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]

# Training pipeline: random crop + random horizontal flip as augmentation.
train_transform = transforms.Compose(
    [
        resize_step,
        transforms.RandomCrop(IMG_CROP_SIZE),
        transforms.RandomHorizontalFlip(),
    ]
    + to_normalized_tensor
)

# Validation pipeline: deterministic center crop, no augmentation.
val_transform = transforms.Compose(
    [
        resize_step,
        transforms.CenterCrop(IMG_CROP_SIZE),
    ]
    + to_normalized_tensor
)

In [None]:
# Build two views of the same image folder, one per transform pipeline.
# Both Subsets returned by random_split wrap the SAME underlying dataset
# object, so assigning `val_dataset.dataset.transform = val_transform` would
# also replace the transform used by the training subset and silently disable
# the training augmentation. Keeping a separate ImageFolder per transform
# avoids that shared-state problem.
full_dataset = ImageFolder(root=DATA_PATH, transform=train_transform)
val_source = ImageFolder(root=DATA_PATH, transform=val_transform)

# The split indices are computed on `full_dataset` and reused on `val_source`,
# so both views must enumerate exactly the same files in the same order.
assert full_dataset.samples == val_source.samples, "dataset views are misaligned"

print(f"Total images: {len(full_dataset)}")
print(f"Classes: {full_dataset.classes}")

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# Seed the split so that train/validation membership is identical on every
# Restart & Run All (otherwise a re-run mixes former validation images into
# the training set).
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_split = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

# Re-wrap the validation indices around the view that uses val_transform.
val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

print(f"Training set: {len(train_dataset)}")
print(f"Validation set: {len(val_dataset)}")

In [None]:
# Settings shared by both loaders. Loading happens in the main process
# (num_workers=0) without pinned memory.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0, pin_memory=False)

# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# len(DataLoader) is the number of batches per epoch, not the batch size,
# so label it accordingly.
print(f"Train batches: {len(train_loader)} (batch size {BATCH_SIZE})")
print(f"Val batches: {len(val_loader)} (batch size {BATCH_SIZE})")

In [None]:
# Load EfficientNet-B0 with ImageNet-pretrained weights. The weights must be
# passed by keyword: supplying them positionally is deprecated in recent
# torchvision releases and emits a warning.
model = models.efficientnet_b0(weights="IMAGENET1K_V1")

# Freeze every pretrained parameter; only the layer added below will train.
model.requires_grad_(False)

# Replace the final classifier layer with a fresh one sized for the soil
# classes. A newly constructed nn.Linear has requires_grad=True by default,
# so it is the only trainable part of the network.
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, NUM_CLASSES)

model = model.to(device)

# Report how much of the network is actually being trained.
total_params = sum(p.numel() for p in model.parameters())
train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("== Model used: EfficientNet-B0")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {train_params:,}")
print(f"Percentage: {100*train_params/total_params:.2f}%")


In [None]:
def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy_pct)``.

    When ``optimizer`` is given the model is put in training mode and its
    weights are updated after every batch. When it is ``None`` the model is
    put in eval mode and gradients are disabled (validation pass).

    Args:
        model: network to run.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``;
            expected to return a per-batch mean, which is re-weighted by the
            batch size so a smaller final batch does not skew the average.
        device: device the batches are moved to before the forward pass.
        optimizer: optimizer to step, or ``None`` for evaluation only.

    Returns:
        Tuple of (sample-weighted mean loss, accuracy in percent).

    Raises:
        ValueError: if ``loader`` yields no samples.

    NOTE(review): ``model.train()`` also switches any BatchNorm layers of the
    frozen backbone to training mode, so their running statistics keep
    updating -- confirm that this is intended.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.set_grad_enabled(is_training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            batch_count = labels.size(0)
            loss_sum += loss.item() * batch_count
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_seen += batch_count

    if num_seen == 0:
        raise ValueError("loader yielded no samples; cannot compute metrics")

    return loss_sum / num_seen, 100 * num_correct / num_seen


# loss and optimizer
criterion = nn.CrossEntropyLoss()
# Hand the optimizer only the parameters that can actually receive gradients;
# the frozen backbone weights would otherwise be tracked for nothing.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)

# track training history (one entry per epoch)
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

best_val_acc = 0
start_time = time.time()

print("Training started...")
print("-" * 40)

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
    val_loss, val_acc = run_epoch(model, val_loader, criterion, device)

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # TODO: add learning rate scheduler

    # Checkpoint only when validation accuracy strictly improves.
    is_best = val_acc > best_val_acc
    if is_best:
        best_val_acc = val_acc
        torch.save(model.state_dict(), CHECKPOINT_PATH/'best_model.pth')

    print(
        f"Epoch {epoch+1:02d}/{NUM_EPOCHS} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%"
        + (" ** BEST **" if is_best else "")
    )

elapsed = time.time() - start_time
print("-" * 40)
print(f"Training done in {elapsed/60:.1f} min")
print(f"Best validation accuracy: {best_val_acc:.2f}%")

In [None]:
# Plot the per-epoch curves recorded in `history`: loss on the left panel,
# accuracy on the right, each with a train and a validation line.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (history key suffix, y-axis label, panel title) for each panel, left to right
panels = [
    ('loss', 'Loss', 'Training and validation loss'),
    ('acc', 'Accuracy (%)', 'Training and validation accuracy'),
]

for ax, (metric, y_label, panel_title) in zip(axes, panels):
    ax.plot(history[f'train_{metric}'], label='Train')
    ax.plot(history[f'val_{metric}'], label='Validation')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
# Save before show(): once shown, the figure may no longer be available to save.
plt.savefig(OUTPUTS_PATH/'training_curves.png')
plt.show()