### Importing Of Libraries and Modules

In [31]:
# Core Library modules
import os  # Operating system interactions, such as reading and writing files.
import shutil  # High-level file operations like copying and moving files.
import random  # Random number generation for various tasks.
import textwrap  # Formatting text into paragraphs of a specified width.
import warnings  # Warning control context manager.
import zipfile  # Work with ZIP archives.
import platform  # Access to underlying platform’s identifying data.
import itertools  # Functions creating iterators for efficient looping.
from dataclasses import dataclass  # Class decorator for adding special methods to classes.

# PyTorch and Deep Learning Libaries
import torch  # Core PyTorch library for tensor computations.
import torch.nn as nn  # Neural network module for defining layers and architectures.
from torch.nn import functional as F  # Functional module for defining functions and loss functions.
import torch.optim as optim  # Optimizer module for training models (SGD, Adam, etc.).
from torch.utils.data import Dataset, DataLoader, Subset, random_split  # Data handling and batching
import torchvision  # PyTorch's computer vision library.
from torchvision import datasets, transforms  # Image datasets and transformations.
import torchvision.datasets as datasets  # Specific datasets for vision tasks.
import torchvision.transforms as transforms  # Transformations for image preprocessing.
from torchvision.utils import make_grid  # Grid for displaying images.
import torchvision.models as models  # Pretrained models for transfer learning.
from torchvision.datasets import MNIST, EuroSAT  # Standard datasets.
import torchvision.transforms.functional as TF  # Functional transformations.
from torchvision.models import ResNet18_Weights  # ResNet-18 model with pretrained weights.
from torchsummary import summary  # Model summary.
import torchsummary  # Model summaries.
import torchmetrics  # Model evaluation metrics.
from torchmetrics import MeanMetric, Accuracy  # Accuracy metrics.
from torchmetrics.classification import (
    MulticlassF1Score, MulticlassRecall, MulticlassPrecision, MulticlassAccuracy
)  # Classification metrics.
from torchviz import make_dot  # Model visualization.
from torchvision.ops import sigmoid_focal_loss  # Focal loss for class imbalance.
from torchcam.methods import GradCAM  # Grad-CAM for model interpretability.
from torchcam.utils import overlay_mask  # Overlay mask for visualizations.
import pytorch_lightning as pl  # Training management.
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping, Callback  # Callbacks.
from pytorch_lightning.loggers import TensorBoardLogger  # Logger for TensorBoard.

# Geospatial Data Processing Libraries
import rasterio  # Reading and writing geospatial raster data.
from rasterio.warp import calculate_default_transform, reproject  # Reprojection and transformation.
from rasterio.enums import Resampling  # Resampling for raster resizing.
from rasterio.plot import show  # Visualization of raster data.

# Data Manipulation, Analysis and Visualization Libraries
import pandas as pd  # Data analysis and manipulation.
import numpy as np  # Array operations and computations.
from sklearn.metrics import confusion_matrix, accuracy_score  # Evaluation metrics.
import matplotlib.pyplot as plt  # Static and interactive plotting.
import seaborn as sns  # High-level interface for statistical graphics.

# Utility Libraries
from tqdm import tqdm  # Progress bar for loops.
from PIL import Image  # Image handling and manipulation.
import ast  # Parsing Python code.
import requests  # HTTP requests.
import zstandard as zstd  # Compression and decompression.
from collections import Counter  # Counting hashable objects.
import certifi  # Certificates for HTTPS.
import ssl  # Secure connections.
import urllib.request  # URL handling.
import kaggle  # Kaggle API for datasets.
from IPython.display import Image  # Display images in notebooks.
from pathlib import Path # File system path handling.


### Setting Seed and Device

In [5]:
# Set seed for reproducibility
SEED = 42  
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device} {'(GPU: ' + torch.cuda.get_device_name(0) + ')' if device.type == 'cuda' else ''}")

Device: cuda (GPU: NVIDIA GeForce RTX 3050)


In [18]:
@dataclass
class DatasetConfig:
    dataset_path: str = r'C:\Users\isaac\Desktop\BigEarthTests\Subsets\50%'
    combined_path: str = r'C:\Users\isaac\Desktop\BigEarthTests\Subsets\50%\CombinedImages'
    metadata_path: str =r'C:\Users\isaac\Desktop\BigEarthTests\Subsets\metadata_50_percent.csv'
    metadata_csv = pd.read_csv(metadata_path)
    img_size: int = 120
    img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    num_classes: int = 19
    band_channels: int = 13

@dataclass
class ModelConfig:
    batch_size: int = 32
    num_epochs: int = 10
    model_name: str = 'resnet18'
    num_workers: int = os.cpu_count() // 2

    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=DatasetConfig.img_mean, std=DatasetConfig.img_std)
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=DatasetConfig.img_mean, std=DatasetConfig.img_std)
    ])

    test_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=DatasetConfig.img_mean, std=DatasetConfig.img_std)
    ])


In [None]:
class BigEarthNetSubset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = list(Path(root_dir).rglob("*.tif"))
        self.metadata = pd.read_csv(DatasetConfig.metadata_path)

        # Create a mapping from patch_id to labels
        self.patch_to_labels = dict(zip(self.metadata['patch_id'], self.metadata['labels']))
        self.image_paths = list(Path(root_dir).rglob("*.tif"))
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, index):
        image_path = self.image_paths[index]
        with rasterio.open(image_path) as src:
            image = src.read() 
             
        label = self.get_label(image_path)  

        if self.transform:
            image = self.transform(image)
        
        return image, label
    
    def get_label(self, image_path):
        patch_id = image_path.stem
        labels = self.patch_to_labels.get(patch_id, None)
        return labels

In [74]:
dataset = BigEarthNetSubset(DatasetConfig.combined_path)
print(f"Dataset length: {len(dataset)}")

Dataset length: 2352


In [72]:
image, label = dataset[5]
print(f"Image shape: {image.size}, Label: {label}")

Image shape: 172800, Label: ['Broad-leaved forest', 'Coniferous forest', 'Mixed forest', 'Transitional woodland, shrub']


In [76]:
image

array([[[ 231,  231,  231, ...,  189,  189,  189],
        [ 231,  231,  231, ...,  189,  189,  189],
        [ 231,  231,  231, ...,  189,  189,  189],
        ...,
        [ 150,  150,  150, ...,  165,  165,  165],
        [ 150,  150,  150, ...,  165,  165,  165],
        [ 150,  150,  150, ...,  165,  165,  165]],

       [[ 210,  250,  310, ...,  183,  177,  209],
        [ 194,  130,  171, ...,  185,  203,  228],
        [ 189,  150,  183, ...,  184,  225,  256],
        ...,
        [ 133,  128,  144, ...,  105,  198,  285],
        [ 116,  131,  149, ...,  169,  290,  281],
        [ 109,  141,  159, ...,  233,  257,  274]],

       [[ 477,  522,  557, ...,  469,  482,  532],
        [ 473,  404,  428, ...,  446,  520,  618],
        [ 453,  432,  469, ...,  463,  590,  644],
        ...,
        [ 319,  317,  373, ...,  318,  479,  636],
        [ 343,  330,  382, ...,  485,  630,  672],
        [ 314,  341,  384, ...,  584,  634,  646]],

       ...,

       [[3047, 3047, 304

In [75]:
dataset[5]

(array([[[ 231,  231,  231, ...,  189,  189,  189],
         [ 231,  231,  231, ...,  189,  189,  189],
         [ 231,  231,  231, ...,  189,  189,  189],
         ...,
         [ 150,  150,  150, ...,  165,  165,  165],
         [ 150,  150,  150, ...,  165,  165,  165],
         [ 150,  150,  150, ...,  165,  165,  165]],
 
        [[ 210,  250,  310, ...,  183,  177,  209],
         [ 194,  130,  171, ...,  185,  203,  228],
         [ 189,  150,  183, ...,  184,  225,  256],
         ...,
         [ 133,  128,  144, ...,  105,  198,  285],
         [ 116,  131,  149, ...,  169,  290,  281],
         [ 109,  141,  159, ...,  233,  257,  274]],
 
        [[ 477,  522,  557, ...,  469,  482,  532],
         [ 473,  404,  428, ...,  446,  520,  618],
         [ 453,  432,  469, ...,  463,  590,  644],
         ...,
         [ 319,  317,  373, ...,  318,  479,  636],
         [ 343,  330,  382, ...,  485,  630,  672],
         [ 314,  341,  384, ...,  584,  634,  646]],
 
        ...,


In [77]:
class BigEarthNetSubsetDataModule(pl.LightningDataModule):
    def __init__(self):
        pass

    def setup(self, stage=None):
        pass

    def train_dataloader(self):
        pass

    def val_dataloader(self):
        pass

    def test_dataloader(self):
        pass

In [78]:
class BigEarthNetSubsetModel(pl.LightningModule):
    def __init__(self):
        pass

    def forward(self, x):
        pass

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

    def test_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        pass

    def training_epoch_end(self, outputs):
        pass

    def validation_epoch_end(self, outputs):
        pass

    def test_epoch_end(self, outputs):
        pass

    def predict(self, x):
        pass

In [None]:
class BigEarthNetSubsetCallback(Callback):
    def __init__(self):
        pass
    def on_epoch_end(self, trainer, pl_module):
        pass
    def on_train_end(self, trainer, pl_module):
        pass