In [None]:
from pathlib import Path
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder

# Root folder handed to ImageFolder later on (one sub-directory per class).
# NOTE(review): the "Orignal" spelling is kept as-is; presumably it matches the
# directory name on disk — confirm before "fixing" it.
DATA_PATH = Path("../data/raw/soil-classification/Orignal-Dataset")

# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using: {device}")

In [None]:
IMG_SIZE = 224

# Values shared by both pipelines, named once so they cannot drift apart.
# NOTE(review): these look like the usual ImageNet channel statistics —
# confirm they match the backbone that will consume the images.
TARGET_SIZE = (IMG_SIZE, IMG_SIZE)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random flip / rotation / colour jitter,
# then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.Resize(TARGET_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Evaluation pipeline (val/test): same resize and normalization, no randomness.
val_transform = transforms.Compose([
    transforms.Resize(TARGET_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

In [None]:
# Index every image under DATA_PATH; ImageFolder derives the label of each
# sample from the name of the sub-directory it lives in.
full_dataset = ImageFolder(root=DATA_PATH, transform=train_transform)

# Quick summary: class names, name -> index mapping, and sample count.
for heading, value in (
    ("Classes", full_dataset.classes),
    ("Mapping", full_dataset.class_to_idx),
):
    print(f"{heading}: {value}")
print(f"Total: {len(full_dataset)} images")

In [None]:
# Split dataset: 70% train, 15% validation, remainder test.
n = len(full_dataset)
train_size = int(0.7 * n)
val_size = int(0.15 * n)
test_size = n - train_size - val_size  # absorbs rounding so the sizes sum to n

# Fixed generator seed keeps the split identical across kernel restarts.
train_dataset, val_split, test_split = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(24)
)

# random_split returns Subset views that all wrap `full_dataset`, so they would
# all inherit its augmenting `train_transform`. Validation and test must be
# deterministic, so re-wrap their indices around a second view of the same
# folder that uses `val_transform` instead.
eval_dataset = ImageFolder(DATA_PATH, transform=val_transform)
# Both ImageFolder instances must list files in the same order, otherwise the
# indices drawn above would point at different images.
assert eval_dataset.samples == full_dataset.samples, "dataset changed between scans"

val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_split.indices)

print(f"Train: {train_size} | Val: {val_size} | Test: {test_size}")

In [None]:
# Per-class sample counts, restricted to the training split.
train_targets = (full_dataset.targets[sample_idx] for sample_idx in train_dataset.indices)
class_counts = Counter(train_targets)
for label, class_name in enumerate(full_dataset.classes):
    print(f"{class_name}: {class_counts[label]}")

# Inverse-frequency weights to counter class imbalance: rarer classes get
# larger weights.
total = sum(class_counts.values())
num_classes = len(full_dataset.classes)
weights = [total / class_counts[label] for label in range(num_classes)]
class_weights = torch.tensor(weights, dtype=torch.float32)
# Rescale so the weights sum to the number of classes (i.e. average to 1).
class_weights = class_weights / class_weights.sum() * len(class_weights)
rounded_weights = [round(w, 2) for w in class_weights.tolist()]
print(f"\nClass weights: {rounded_weights}")

In [None]:
BATCH_SIZE = 32

# Only the training loader shuffles; val/test keep the subset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader, test_loader = (
    DataLoader(dataset=split, batch_size=BATCH_SIZE)
    for split in (val_dataset, test_dataset)
)

# Pull one training batch as a sanity check and for the preview plot.
batch_iter = iter(train_loader)
images, labels = next(batch_iter)
print(images.shape, labels.shape)

In [None]:
# Visualize batch
def denormalize(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo a channel-wise ``(x - mean) / std`` normalization.

    Args:
        img: Tensor with the channel axis third from the end, e.g.
            ``(C, H, W)`` or a batch ``(N, C, H, W)``.
        mean: Per-channel means that were subtracted. Defaults to the values
            used by the transforms in this notebook.
        std: Per-channel standard deviations that were divided by. Defaults
            to the values used by the transforms in this notebook.

    Returns:
        A new tensor equal to ``img * std + mean``; ``img`` is not modified.
    """
    # Create the statistics on the input's device so the function also works
    # on tensors that have been moved to an accelerator.
    mean = torch.as_tensor(mean, device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(std, device=img.device).view(-1, 1, 1)
    return img * std + mean

# Preview the first eight images of the batch on a 2x4 grid, titled by class.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(12, 6))
for idx, panel in enumerate(axes.flat):
    restored = denormalize(images[idx])
    # Move the channel axis last, which is the layout imshow expects.
    pixels = restored.permute(1, 2, 0).numpy()
    # Clip to [0, 1] so imshow gets values inside its valid float range.
    panel.imshow(np.clip(pixels, 0, 1))
    panel.set_title(full_dataset.classes[labels[idx]])
    panel.axis('off')
plt.tight_layout()
plt.show()