Satellite Image Classification

In [None]:

!pip install kagglehub


In [None]:
!pip install torch torchvision torchaudio


In [None]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Root folder with one sub-directory per class (layout described below),
# resolved relative to the notebook's working directory.
data_dir = "data"

# Seed for the DataLoader's shuffling so the sample batches drawn in this
# notebook are identical on every Restart & Run All.
SEED = 42

# Preprocessing applied to every image as it is loaded.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # one fixed size so images can be stacked into a batch
    transforms.ToTensor(),  # image -> float tensor (ToTensor scales pixels to [0, 1])
    # Per channel (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]; plotting code must undo this.
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# ImageFolder labels each image by the sub-directory it lives in.
dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# Batched, shuffled loader; the seeded generator makes the shuffle order
# reproducible instead of depending on global RNG state.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

print("Classes:", dataset.classes)

# Pull one batch as a smoke test of the whole loading pipeline.
images, labels = next(iter(dataloader))
print("Batch shape:", images.shape)


1️⃣ Understanding the Dataset Structure

```
data/
├── cloudy/       # cloudy sky images
├── desert/       # desert landscape images
├── green_area/   # green vegetation images
└── water/        # water body images
```

- Each sub-folder is one class.
- `ImageFolder` labels every image with the folder it sits in; class indices follow the alphabetical order of the folder names (see `dataset.classes`).

2️⃣ Check Dataset Distribution

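- Before training, check whether the classes are balanced (i.e., each category has a similar number of images). A heavily imbalanced dataset biases the classifier toward the majority classes and makes plain accuracy misleading.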
- We can count the number of images in each category.

In [None]:
from collections import Counter

# Count images per class from the label list ImageFolder built while
# scanning the directories. Iterating `dataset` itself would open, resize
# and normalize every image on disk just to read its label.
class_counts = Counter(dataset.targets)

# Look each count up by class index so every name gets its own count.
# Counter.values() is in first-seen order, which is not guaranteed to
# match the order of dataset.classes.
for class_idx, class_name in enumerate(dataset.classes):
    print(f"{class_name}: {class_counts[class_idx]} images")


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Draw one shuffled batch to eyeball a few examples and their labels.
images, labels = next(iter(dataloader))

# (batch, C, H, W) -> (batch, H, W, C): imshow expects channels last.
images = images.permute(0, 2, 3, 1).numpy()

# 2 x 4 grid -> at most 8 sample images.
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.flatten()

# Hide axes on every panel up front, so panels left unused (batch smaller
# than the grid) render blank instead of as empty framed boxes.
for ax in axes:
    ax.axis("off")

for img, label, ax in zip(images, labels, axes):
    # Invert Normalize(mean=0.5, std=0.5) from the loading cell, then clip:
    # float round-off can land marginally outside [0, 1], which makes imshow
    # emit a clipping warning.
    ax.imshow(np.clip(img * 0.5 + 0.5, 0.0, 1.0))
    # int() turns the label tensor element into a plain Python list index.
    ax.set_title(f"Class: {dataset.classes[int(label)]}")

plt.tight_layout()
plt.show()
