### Importing Necessary Libraries

In [14]:
# Core Python Libraries
import os  # Operating system interactions, such as reading and writing files.
import shutil  # High-level file operations like copying and moving files.
import random  # Random number generation for various tasks.
import textwrap  # Formatting text into paragraphs of a specified width.
import warnings  # Warning control context manager.
import zipfile  # Work with ZIP archives.
import platform  # Access to underlying platform’s identifying data.
import itertools  # Functions creating iterators for efficient looping.
from dataclasses import dataclass  # Class decorator for adding special methods to classes.

# PyTorch-related Libraries (Deep Learning)
import torch  # Core PyTorch library for tensor computations.
import torch.nn as nn  # Neural network module for defining layers and architectures.
import torch.optim as optim  # Optimizer module for training models (SGD, Adam, etc.).
from torch.utils.data import Dataset, DataLoader, Subset, random_split  # Dataset and DataLoader for managing and batching data.
import torchvision # PyTorch's computer vision library.
from torchvision import datasets, transforms  # Datasets and transformations for image processing.
import torchvision.datasets as datasets  # Datasets for computer vision tasks.
import torchvision.transforms as transforms  # Transformations for image preprocessing.
from torchvision.utils import make_grid  # Make grid for displaying images.
import torchvision.models as models  # Pretrained models for transfer learning.
import torchvision.transforms.functional as TF  # Functional transformations for image preprocessing.
import torchsummary # PyTorch model summary for Keras-like model summary.
from torchvision.ops import sigmoid_focal_loss  # Focal loss for handling class imbalance in object detection.
from torchmetrics import MeanMetric  # Intersection over Union (IoU) metric for object detection.
from torchmetrics.classification import MultilabelF1Score, MultilabelRecall, MultilabelPrecision, MultilabelAccuracy  # Multilabel classification metrics.

import lightning.pytorch as pl
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

# Geospatial Data Processing Libraries
import rasterio  # Library for reading and writing geospatial raster data.
from rasterio.warp import calculate_default_transform, reproject  # Reprojection and transformation functions.
from rasterio.enums import Resampling  # Resampling methods used for resizing raster data.
from rasterio.plot import show  # Visualization of raster data.

# Data Manipulation and Analysis Libraries
import pandas as pd  # Data analysis and manipulation library for DataFrames and CSVs.
import numpy as np  # Numpy for array operations and numerical computations.
from sklearn.metrics import confusion_matrix, accuracy_score  # Evaluation metrics for classification models.

# Visualization Libraries
import matplotlib.pyplot as plt  # Plotting library for creating static and interactive visualizations.
import seaborn as sns  # High-level interface for drawing attractive statistical graphics.

# Utilities
from tqdm import tqdm  # Progress bar for loops and processes.
from PIL import Image  # Image handling, opening, manipulating, and saving.
import ast  # Abstract Syntax Trees for parsing Python code.
import requests  # HTTP library for sending requests.
import zstandard as zstd  # Zstandard compression for fast compression and decompression.
from collections import Counter # Counter for counting hashable objects.
import certifi  # Certificates for verifying HTTPS requests.
import ssl  # Secure Sockets Layer for secure connections.
import urllib.request  # URL handling for requests.
import kaggle # Kaggle API for downloading datasets.
import zipfile # Work with ZIP archives.

In [30]:
pl.seed_everything(42)  # Set seed for reproducibility.
main_path = 'D:\Datasets'
dataset_path = os.path.join(main_path, 'eurosat')


Seed set to 42


In [48]:
@dataclass
class Config:
    num_epochs = 10  # Number of epochs for training.
    batch_size = 64  # Batch size for training.
    learning_rate = 1e-3  # Learning rate for training.
    input_size = 224  # Input size for the model.
    resize = 256  # Resize size for the images.
    num_classes = 10  # Number of classes in the dataset.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Device for training.
    num_workers = 4  # Number of workers for the DataLoader.
    optimizer = optim.Adam # Optimizer for training.
    channels = 3 # Number of channels in the images.

    image_mean = [0.485, 0.456, 0.406]  # Image mean for normalization.
    image_std = [0.229, 0.224, 0.225]  # Image standard deviation for normalization.

    train_path = os.path.join(dataset_path, 'train.csv')
    val_path = os.path.join(dataset_path, 'validation.csv')
    test_path = os.path.join(dataset_path, 'test.csv')


In [10]:
# Disable SSL verification (not recommended for production)
ssl._create_default_https_context = ssl._create_unverified_context

# Download the EuroSAT dataset
euro_sat_dataset = datasets.EuroSAT(root="D:\Datasets", download=True, transform=transforms)

# Get the full path of the dataset
dataset_root_path = os.path.abspath(euro_sat_dataset.root)
dataset_full_path = os.path.join(dataset_root_path, 'eurosat', '2750')  # '2750' is the subdirectory where the dataset files are stored

Downloading https://madm.dfki.de/files/sentinel/EuroSAT.zip to D:\Datasets\eurosat\EuroSAT.zip


100%|██████████| 94280567/94280567 [00:04<00:00, 20527229.85it/s]


Extracting D:\Datasets\eurosat\EuroSAT.zip to D:\Datasets\eurosat


### Creating a Custom Dataset Class
This is a custom PyTorch Dataset class designed to load images and labels (if available) for each set. The PyTorch "Dataset" class is essential for efficient and organized data handling in machine learning tasks. It provides a standardized interface to load and preprocess data samples from various sources. Encapsulating the dataset into a single object simplifies data management and enables seamless integration with other PyTorch components like data loaders and models. 

In [32]:
class EuroSATDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        if self.transform:
            img = self.transform(img)

        return img, label

### Custom LightningDataModule class to Load dataset
This class inherits from Lightning’s LightningDataModule class and encapsulates the following steps:

1) Download the dataset.
2) Create train and validation splits.
3) Create a Dataset class object for each split with appropriate transformations.
4) Create DataLoader objects for each split.

The class methods are defined to do the following tasks:
1) prepare_data(..): This method is used for data preparation, like downloading and one-time preprocessing with the dataset. When training on a distributed GPU, this will be called from a single GPU.
2) setup(...):  When you want to perform data operations on every GPU, this method is apt for it will call from every GPU. For example, perform train/val/test splits.
3) train_dataloader(...): This method returns the train dataloader.
4) val_dataloader(...): This method returns validation dataloader(s).
5) test_dataloader(...):  This method returns test dataloader(s).

In [44]:
class EuroSATDataModule(pl.LightningDataModule):
    def __init__(self, *, num_classes=10, valid_pct=0.1, resize_to=(384, 384), batch_size=32, num_workers=0, pin_memory=False, shuffle_validation=False,):
        super().__init__()
        self.num_classes = num_classes
        self.valid_pct = valid_pct
        self.resize_to = resize_to
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle_validation = shuffle_validation

        self.train_tfms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(Config.input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(Config.image_mean, Config.image_std)
        ])

        self.valid_tfms = transforms.Compose([
            transforms.Resize(Config.resize),
            transforms.CenterCrop(Config.input_size),
            transforms.ToTensor(),
            transforms.Normalize(Config.image_mean, Config.image_std)
        ])
        
        self.test_tfms = transforms.Compose([
            transforms.Resize(Config.resize),
            transforms.CenterCrop(Config.input_size),
            transforms.ToTensor(),
            transforms.Normalize(Config.image_mean, Config.image_std)
        ])

    def prepare_data(self):
        if not os.path.exists(dataset_full_path):
            kaggle.api.dataset_download_files('apollo2506/eurosat-dataset', path=dataset_full_path, unzip=True)
    
    def setup(self, stage=None):
        pass

    def train_dataloader(self):
        train_loader = DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        return train_loader
    
    def val_dataloader(self):
        val_loader = DataLoader(self.valid_ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return val_loader

    def test_dataloader(self):
        test_loader = DataLoader(self.test_ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return test_loader

In [46]:
%%time

dm = EuroSATDataModule(
    num_classes=Config.num_classes,
    batch_size=32, 
    num_workers=0
)

# Donwload dataset.
dm.prepare_data()

# Split datset into training, validation set.
dm.setup()

CPU times: total: 0 ns
Wall time: 0 ns


In [49]:
data_df = pd.read_csv(Config.train_path)
data_df.head()

Unnamed: 0.1,Unnamed: 0,Filename,Label,ClassName
0,16257,AnnualCrop/AnnualCrop_142.jpg,0,AnnualCrop
1,3297,HerbaceousVegetation/HerbaceousVegetation_2835...,2,HerbaceousVegetation
2,17881,PermanentCrop/PermanentCrop_1073.jpg,6,PermanentCrop
3,2223,Industrial/Industrial_453.jpg,4,Industrial
4,4887,HerbaceousVegetation/HerbaceousVegetation_1810...,2,HerbaceousVegetation


In [51]:
data_df = pd.read_csv(Config.train_path)
category_counts = data_df['ClassName'].value_counts()
print(category_counts)

ClassName
AnnualCrop              2100
HerbaceousVegetation    2100
Residential             2100
SeaLake                 2100
Forest                  2100
PermanentCrop           1750
Industrial              1750
Highway                 1750
River                   1750
Pasture                 1400
Name: count, dtype: int64


### Custom Lightning Module Class for the Model
Create a class that contains the following main methods: 

In [None]:
class EuroSATModel(pl.LightningModule):
    def __init__(self, num_classes=10):
        super(EuroSATModel, self).__init__()
        self.model = models.resnet18(pretrained=True)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return loss
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=Config.learning_rate)
        return optimizer
    
    def on_train_epoch_end(self) -> None:
        return super().on_train_epoch_end()
    def on_validation_epoch_end(self) -> None:
        return super().on_validation_epoch_end()

### Defining Transformations