In [None]:
# Core Library modules
import os  # Operating system interactions, such as reading and writing files.
import shutil  # High-level file operations like copying and moving files.
import random  # Random number generation for various tasks.
import textwrap  # Formatting text into paragraphs of a specified width.
import warnings  # Warning control context manager.
import zipfile  # Work with ZIP archives.
import platform  # Access to underlying platform’s identifying data.
import itertools  # Functions creating iterators for efficient looping.
from dataclasses import dataclass, field  # Class decorator for adding special methods to classes.

# PyTorch and Deep Learning Libraries
import torch  # Core PyTorch library for tensor computations.
import torch.nn as nn  # Neural network module for defining layers and architectures.
from torch.nn import functional as F  # Functional module for defining functions and loss functions.
import torch.optim as optim  # Optimizer module for training models (SGD, Adam, etc.).
from torch.utils.data import Dataset, DataLoader, Subset, random_split  # Data handling and batching
import torchvision  # PyTorch's computer vision library.
from torchvision import datasets, transforms  # Image datasets and transformations.
import torchvision.datasets as datasets  # Specific datasets for vision tasks.
import torchvision.transforms as transforms  # Transformations for image preprocessing.
from torchvision.utils import make_grid  # Grid for displaying images.
import torchvision.models as models  # Pretrained models for transfer learning.
from torchvision.datasets import MNIST, EuroSAT  # Standard datasets.
import torchvision.transforms.functional as TF  # Functional transformations.
from torchvision.models import ResNet18_Weights  # ResNet-18 model with pretrained weights.
from torchsummary import summary  # Model summary.
import torchmetrics  # Model evaluation metrics.
from torchmetrics import MeanMetric, Accuracy  # Accuracy metrics.
from torchmetrics.classification import (
    MultilabelF1Score, MultilabelRecall, MultilabelPrecision, MultilabelAccuracy
)  # Classification metrics.
from torchviz import make_dot  # Model visualization.
from torchvision.ops import sigmoid_focal_loss  # Focal loss for class imbalance.
from torchcam.methods import GradCAM  # Grad-CAM for model interpretability.
from torchcam.utils import overlay_mask  # Overlay mask for visualizations.
import pytorch_lightning as pl  # Training management.
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping, Callback  # Callbacks.
from pytorch_lightning.loggers import TensorBoardLogger  # Logger for TensorBoard.

# Geospatial Data Processing Libraries
import rasterio  # Reading and writing geospatial raster data.
from rasterio.warp import calculate_default_transform, reproject  # Reprojection and transformation.
from rasterio.enums import Resampling  # Resampling for raster resizing.
from rasterio.plot import show  # Visualization of raster data.

# Data Manipulation, Analysis and Visualization Libraries
import pandas as pd  # Data analysis and manipulation.
import numpy as np  # Array operations and computations.
from sklearn.metrics import confusion_matrix, accuracy_score  # Evaluation metrics.
import matplotlib.pyplot as plt  # Static and interactive plotting.
import seaborn as sns  # High-level interface for statistical graphics.

# Utility Libraries
from tqdm import tqdm  # Progress bar for loops.
from PIL import Image  # Image handling and manipulation.
import ast  # Parsing Python code.
import requests  # HTTP requests.
import zstandard as zstd  # Compression and decompression.
from collections import Counter  # Counting hashable objects.
import certifi  # Certificates for HTTPS.
import ssl  # Secure connections.
import urllib.request  # URL handling.
import kaggle  # Kaggle API for datasets.
from IPython.display import Image  # Display images in notebooks.
from pathlib import Path # File system path handling.
from typing import Dict, List, Tuple  # Type hints.
import sys  # System-specific parameters and functions.
import time # Time access and conversions.
import logging # Logging facility for Python.
import json # JSON encoder and decoder.
from torch.optim.lr_scheduler import ReduceLROnPlateau
from contextlib import redirect_stdout
# Custom Libraries

In [None]:
# Set seed for reproducibility
# Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and every CUDA device) once at start-up.
SEED = 42  
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Permit reduced-precision internal computation for float32 matmuls (faster on supported GPUs);
# see the torch.set_float32_matmul_precision documentation for what 'medium' allows.
torch.set_float32_matmul_precision('medium')

# Set environment variables
# NOTE(review): protobuf usually reads this variable only when it is first imported; the
# imports cell above may already have imported protobuf indirectly — confirm this has any effect.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "upb"


# Render plots
%matplotlib inline
# Set device
# Prefer the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device} {'(GPU: ' + torch.cuda.get_device_name(0) + ')' if device.type == 'cuda' else ''}")

In [None]:
def set_random_seeds(seed=42):
    """Seed every RNG the notebook relies on and request deterministic cuDNN kernels.

    Args:
        seed: value passed to Python's ``random``, NumPy's global RNG and PyTorch
            (plus the current and all CUDA devices when CUDA is available).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Deterministic kernels on, autotuner off: trades speed for reproducibility.
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

def clean_and_parse_labels(label_string):
    """Parse a stringified label list from the metadata CSV into a Python list.

    Labels may be stored in NumPy-array repr form, where items are separated only by
    whitespace (possibly a line break), e.g. ``"['Arable land' 'Pastures']"``, or as an
    ordinary list repr ``"['Arable land', 'Pastures']"``. Both forms are accepted.

    Args:
        label_string: the raw string from the ``labels`` column.

    Returns:
        list of label strings.

    Raises:
        ValueError / SyntaxError: from ``ast.literal_eval`` if the string is not a list literal.
    """
    import re

    # Insert a comma only where a closing quote is followed by whitespace and an opening
    # quote; input that is already comma-separated is left untouched (a blanket
    # " '" -> ", '" replacement would turn "', '" into "',, '").
    cleaned_labels = re.sub(r"(['\"])\s+(['\"])", r"\1, \2", label_string)
    return ast.literal_eval(cleaned_labels)

def calculate_class_weights(metadata_path):
    """Compute inverse-frequency class weights from BigEarthNet metadata.

    Args:
        metadata_path: path to a metadata CSV, or an already-loaded DataFrame. It must
            have a ``labels`` column holding label lists or their string form (strings
            are parsed with ``clean_and_parse_labels``). A DataFrame argument is copied,
            never modified.

    Returns:
        tuple ``(class_labels, class_weights, class_weights_array, metadata_csv)``:
            class_labels: sorted list of the distinct labels. Sorting makes the order,
                and therefore ``class_weights_array``, identical across interpreter runs.
                Iterating a set of strings is not, because str hashing is randomised
                per process.
            class_weights: dict mapping label -> total label occurrences / occurrences
                of that label.
            class_weights_array: the weights as a NumPy array in ``class_labels`` order.
            metadata_csv: the DataFrame with ``labels`` parsed into lists.
    """
    if isinstance(metadata_path, pd.DataFrame):
        metadata_csv = metadata_path.copy()
    else:
        metadata_csv = pd.read_csv(metadata_path)

    # Only strings need parsing; rows that already hold lists are kept as they are.
    metadata_csv['labels'] = metadata_csv['labels'].apply(
        lambda value: clean_and_parse_labels(value) if isinstance(value, str) else value
    )

    # explode() gives one row per label occurrence; value_counts() drops the NaN rows
    # that empty label lists produce.
    label_counts = metadata_csv['labels'].explode().value_counts()
    total_counts = label_counts.sum()
    class_weights = {label: total_counts / count for label, count in label_counts.items()}

    class_labels = sorted(class_weights)
    class_weights_array = np.array([class_weights[label] for label in class_labels])

    return class_labels, class_weights, class_weights_array, metadata_csv

# Helper functions
def denormalize(tensors, *, mean, std):
    """Undo per-channel standardisation and clamp the result to [0, 1].

    Args:
        tensors: 4-D tensor indexed as (batch, channel, H, W). It is not modified.
        mean: per-channel means, one per channel of ``tensors``.
        std: per-channel standard deviations, one per channel of ``tensors``.

    Returns:
        A new tensor ``clamp(tensors * std + mean, 0, 1)``.

    Raises:
        ValueError: if ``mean``/``std`` do not have one entry per channel.
    """
    n_channels = tensors.shape[1]
    if len(mean) != n_channels or len(std) != n_channels:
        raise ValueError(
            f"mean/std need {n_channels} entries (one per channel), got {len(mean)} and {len(std)}"
        )
    # Reshape to (1, C, 1, 1) so the statistics broadcast over batch and spatial dims.
    mean_t = torch.as_tensor(mean, dtype=tensors.dtype, device=tensors.device).reshape(1, -1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensors.dtype, device=tensors.device).reshape(1, -1, 1, 1)
    return torch.clamp(tensors * std_t + mean_t, min=0.0, max=1.0)

def encode_label(label: list, num_classes=None, class_to_idx=None):
    """Multi-hot encode a list of label names.

    Args:
        label: label names for one sample. Names absent from the mapping are ignored.
        num_classes: length of the output vector. ``None`` means
            ``DatasetConfig.num_classes``, looked up at call time. A default argument of
            ``DatasetConfig.num_classes`` would be evaluated when this cell runs, before
            the cell defining ``DatasetConfig``, and fail on a fresh kernel.
        class_to_idx: mapping label name -> index. ``None`` means
            ``DatasetConfig.class_labels_dict``.

    Returns:
        float tensor of shape (num_classes,) with 1.0 at each present label's index.
    """
    if num_classes is None:
        num_classes = DatasetConfig.num_classes
    if class_to_idx is None:
        class_to_idx = DatasetConfig.class_labels_dict

    target = torch.zeros(num_classes)
    for name in label:
        if name in class_to_idx:
            target[class_to_idx[name]] = 1.0
    return target

def decode_target(
    target: list,
    text_labels: bool = False,
    threshold: float = 0.4,
    cls_labels: dict = None,
):
    """Turn a per-class score vector into a space-separated string of selected classes.

    Args:
        target: per-class scores, indexed by class id.
        text_labels: when True, render each class as ``"<name>(<id>)"`` using ``cls_labels``.
        threshold: scores ``>= threshold`` are selected.
        cls_labels: mapping class id -> name; required when ``text_labels`` is True.

    Returns:
        The selected classes joined by single spaces (empty string if none).
    """
    chosen = [idx for idx, score in enumerate(target) if score >= threshold]
    if text_labels:
        tokens = [cls_labels[idx] + "(" + str(idx) + ")" for idx in chosen]
    else:
        tokens = [str(idx) for idx in chosen]
    return " ".join(tokens)


def get_band_indices(band_names, all_band_names):
    """Return the position of each requested band within ``all_band_names``.

    Raises ValueError (from ``list.index``) if a requested band is not present.
    """
    return list(map(all_band_names.index, band_names))



def get_labels_for_image(image_path, model, transform, patch_to_labels, band_indices=None, threshold=0.5):
    """Predict multi-hot labels for one GeoTIFF patch and look up its ground truth.

    The image is read as a (bands, H, W) float32 tensor, like ``BigEarthNetDatasetTIF``
    does, so ``transform`` receives channel-first data and the model sees every band it
    was built for.

    Args:
        image_path: path to the ``<patch_id>.tif`` file.
        model: module returning raw per-class logits; a sigmoid is applied here.
        transform: callable applied to the (bands, H, W) tensor, or ``None``.
        patch_to_labels: mapping patch_id -> true labels.
        band_indices: optional 0-based band positions to keep (default: all bands).
        threshold: probability above which a class is predicted.

    Returns:
        (preds, true_labels, image): ``preds`` is a flat int NumPy array of 0/1 per class,
        ``true_labels`` comes from ``patch_to_labels``, and ``image`` is the model input
        tensor of shape (1, bands, H, W).

    Raises:
        KeyError: if the patch id is not in ``patch_to_labels``.
    """
    with rasterio.open(image_path) as src:
        image = src.read()
    if band_indices is not None:
        image = image[list(band_indices)]
    image = torch.from_numpy(image.astype(np.float32))
    if transform is not None:
        image = transform(image)
    image = image.unsqueeze(0).to(model.device)  # add batch dimension

    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(image))
        preds = (probs > threshold).cpu().numpy().astype(int).flatten()

    # Same patch-id rule as BigEarthNetDatasetTIF.get_label (file name without extension).
    patch_id = Path(image_path).stem
    true_labels = patch_to_labels[patch_id]

    return preds, true_labels, image

def display_image(image_path, bands=(4, 3, 2)):
    """Show a true-colour composite of a multi-band GeoTIFF patch.

    Args:
        image_path: path to the GeoTIFF.
        bands: 1-based rasterio band indexes mapped to (red, green, blue). With the file
            band order ``DatasetConfig.all_bands`` (B01, B02, B03, B04, ...), the default
            (4, 3, 2) is B04/B03/B02. The previous (2, 3, 4) showed the image as BGR.

    Raw reflectance values are far above imshow's valid range for RGB data, so they
    are contrast-stretched to [0, 1] using the 2nd/98th percentiles before display.
    """
    with rasterio.open(image_path) as src:
        image = np.stack([src.read(band) for band in bands], axis=-1).astype(np.float32)

    low, high = np.percentile(image, (2, 98))
    image = np.clip((image - low) / max(high - low, 1e-6), 0.0, 1.0)

    plt.imshow(image)
    plt.title(f"RGB composite (bands {', '.join(str(band) for band in bands)})")
    plt.axis("off")
    plt.show()

def display_image_and_labels(image_path, model, transform, patch_to_labels):
    """Plot the patch, then print the model's predicted labels next to the true labels."""
    display_image(image_path)

    predicted, actual, _unused_input = get_labels_for_image(image_path, model, transform, patch_to_labels)
    for caption, value in (("Predicted Labels", predicted), ("True Labels", actual)):
        print(f"{caption}: {value}")

def extract_number(string):
    """Return the number before the first '%' in ``string`` (e.g. "50%" -> 50, "0.5%" -> 0.5).

    Whole values come back as ``int``; others as ``float``.

    Raises:
        ValueError: if the text before '%' is not a number.
    """
    candidate = string.partition('%')[0]
    try:
        value = float(candidate)
    except ValueError:
        raise ValueError(f"Cannot extract a number from the string: {string}")
    return int(value) if value.is_integer() else value
    

# Shared store for forward-hook outputs, keyed by layer name (overwritten on every forward pass).
activations = {}

def get_activation(name):
    """Build a forward hook that saves a detached copy of the layer output under ``name``."""
    def _store_output(module, inputs, output):
        activations[name] = output.detach()
    return _store_output

# Visualize activations function
def visualize_activations(layer_names, activations):
    """Plot every channel of each requested layer's captured activation as a grid of tiles.

    Args:
        layer_names: keys of ``activations`` to plot; one figure per layer.
        activations: mapping layer name -> activation tensor (as stored by the
            ``get_activation`` hooks). A leading batch dimension is allowed; only the
            first sample is shown.

    Each channel is standardised (zero mean, unit std when std > 0), scaled by 64,
    shifted to 128 and clipped to the uint8 range before being placed in the grid.
    """
    images_per_row = 16
    for layer_name in layer_names:
        # Out-of-place maths only: for CPU tensors `.cpu().numpy()` shares memory, so
        # in-place updates would silently corrupt the stored activation.
        layer_activation = activations[layer_name].detach().float().cpu().numpy()
        if layer_activation.ndim == 4:  # (batch, C, H, W) -> first sample
            layer_activation = layer_activation[0]
        n_features, height, width = layer_activation.shape
        if n_features == 0:
            continue

        # Round up so channels past the last full row of 16 are still drawn.
        n_rows = (n_features + images_per_row - 1) // images_per_row
        display_grid = np.zeros((height * n_rows, images_per_row * width))
        for idx in range(n_features):
            grid_row, grid_col = divmod(idx, images_per_row)
            channel_image = layer_activation[idx] - layer_activation[idx].mean()
            channel_std = channel_image.std()
            if channel_std > 0:  # constant (e.g. all-zero ReLU) channels would become NaN
                channel_image = channel_image / channel_std
            channel_image = np.clip(channel_image * 64 + 128, 0, 255).astype('uint8')
            display_grid[grid_row * height:(grid_row + 1) * height,
                         grid_col * width:(grid_col + 1) * width] = channel_image

        scale = 1.0 / width
        plt.figure(figsize=(scale * display_grid.shape[1],
                            scale * display_grid.shape[0]))
        plt.title(layer_name)
        plt.grid(False)
        plt.imshow(display_grid, aspect='auto', cmap='viridis')
        plt.show()


In [None]:
class BandNormalisation:
    """Per-band standardisation: ``image[i] = (image[i] - mean[i]) / std[i]``.

    Args:
        mean: per-band means, indexed by channel position.
        std: per-band standard deviations, indexed by channel position.

    ``__call__`` normalises the channel-first ``image`` in place and returns it; only the
    first ``image.shape[0]`` entries of ``mean``/``std`` are used.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        # No per-channel printing: this runs for every sample the DataLoader fetches, and
        # dumping whole channel tensors floods the output and slows data loading.
        for i in range(image.shape[0]):
            image[i] = (image[i] - self.mean[i]) / self.std[i]
        return image
    
class BandUnnormalisation:
    """Inverse of per-band standardisation: ``image[i] = image[i] * std[i] + mean[i]``.

    The channel-first ``image`` is updated in place and returned; statistics are
    indexed by channel position.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, image):
        n_channels = image.shape[0]
        for band in range(n_channels):
            image[band] = (image[band] * self.std[band]) + self.mean[band]
        return image

In [None]:
# Description: Configuration file for the project
@dataclass
class DatasetConfig:
    """Dataset locations, label vocabulary, band lists and per-band statistics.

    Read as class attributes throughout the notebook. The class body reads the
    unwanted-patches parquet and the 50% metadata CSV eagerly, so both files must
    exist when this cell runs.
    """
    # Same file as metadata_paths["50"]; used to build the label vocabulary below.
    metadata_path = r"C:\Users\isaac\Desktop\BigEarthTests\50%_BigEarthNet\metadata_50_percent.csv"
    dataset_paths = {
        "0.5": r"C:\Users\isaac\Desktop\BigEarthTests\0.5%_BigEarthNet\CombinedImages",
        "1": r"C:\Users\isaac\Desktop\BigEarthTests\1%_BigEarthNet\CombinedImages",
        "5": r"C:\Users\isaac\Desktop\BigEarthTests\5%_BigEarthNet\CombinedImages",
        "10": r"C:\Users\isaac\Desktop\BigEarthTests\10%_BigEarthNet\CombinedImages",
        "50": r"C:\Users\isaac\Desktop\BigEarthTests\50%_BigEarthNet\CombinedImages",
        "100": r"C:\Users\isaac\Desktop\BigEarthTests\100%_BigEarthNet\CombinedImages"
    }
    metadata_paths = {
        "0.5": r"C:\Users\isaac\Desktop\BigEarthTests\0.5%_BigEarthNet\metadata_0.5_percent.csv",
        "1": r"C:\Users\isaac\Desktop\BigEarthTests\1%_BigEarthNet\metadata_1_percent.csv",
        "5": r"C:\Users\isaac\Desktop\BigEarthTests\5%_BigEarthNet\metadata_5_percent.csv",
        "10": r"C:\Users\isaac\Desktop\BigEarthTests\10%_BigEarthNet\metadata_10_percent.csv",
        "50": r"C:\Users\isaac\Desktop\BigEarthTests\50%_BigEarthNet\metadata_50_percent.csv",
        "100": r"C:\Users\isaac\Desktop\BigEarthTests\100%_BigEarthNet\metadata_100_percent.csv"
    }
    unwanted_metadata_file: str = r'C:\Users\isaac\Downloads\metadata_for_patches_with_snow_cloud_or_shadow.parquet'
    unwanted_metadata_csv = pd.read_parquet(unwanted_metadata_file)

    # Label vocabulary taken from the 50% metadata. It is sorted so the label -> index
    # mapping (used by encode_label and for pos_weight) is identical on every run.
    class_labels = sorted(calculate_class_weights(metadata_path)[0])
    class_labels_dict = {label: idx for idx, label in enumerate(class_labels)}
    reversed_class_labels_dict = {idx: label for label, idx in class_labels_dict.items()}

    # Output size of the model / label vectors. It must be >= len(class_labels), because
    # encode_label writes class_labels_dict indices into a vector of this length.
    num_classes: int = 19
    band_channels: int = 12
    valid_pct: float = 0.1
    img_size: int = 120

    rgb_bands = ["B04", "B03", "B02"]
    rgb_nir_bands = ["B04", "B03", "B02", "B08"]
    rgb_swir_bands = ["B04", "B03", "B02", "B11", "B12"]
    rgb_nir_swir_bands = ["B04", "B03", "B02", "B08", "B11", "B12"]
    all_imp_bands = [ "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"]
    # Band order of the GeoTIFF files; band selection indexes into this list.
    all_bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B11", "B12"]

    # Per-band mean / std used by BandNormalisation.
    band_stats = {
        "mean": {
            "B01": 359.93681858037576,
            "B02": 437.7795146920668,
            "B03": 626.9061237185847,
            "B04": 605.0589129818594,
            "B05": 971.6512098450492,
            "B06": 1821.9817358749056,
            "B07": 2108.096240315571,
            "B08": 2256.3215618504346,
            "B8A": 2310.6351913265307,
            "B09": 2311.6085833217353,
            "B11": 1608.6865167942176,
            "B12": 1017.1259618291762
        },
        "std": {
            "B01": 583.5085769396974,
            "B02": 648.4384481402268,
            "B03": 639.2669907766995,
            "B04": 717.5748664544205,
            "B05": 761.8971822905785,
            "B06": 1090.758232889144,
            "B07": 1256.5524552734478,
            "B08": 1349.2050493390414,
            "B8A": 1287.1124261320342,
            "B09": 1297.654379610044,
            "B11": 1057.3350765979644,
            "B12": 802.1790763840752
        }
    }

@dataclass
class ModelConfig:
    """Training hyper-parameters, read as class attributes (e.g. ``ModelConfig.batch_size``)."""
    num_epochs: int = 10
    batch_size: int = 32
    # os.cpu_count() can return None. Keep at least one worker, because the notebook's
    # DataLoaders use persistent_workers=True, which DataLoader rejects when num_workers == 0.
    num_workers: int = max(1, (os.cpu_count() or 2) // 2)
    learning_rate: float = 0.0001
    momentum: float = 0.9
    weight_decay: float = 1e-4
    lr_step_size: int = 7
    lr_factor: float = 0.1            # ReduceLROnPlateau multiplicative factor
    patience: int = 5                 # early-stopping patience (epochs)
    lr_patience: int = 5              # ReduceLROnPlateau patience (epochs)
    dropout: float = 0.5

    model_names: list = field(default_factory=lambda: [
        'resnet18', 
        'resnet34', 
        'resnet50', 
        'resnet101', 
        'resnet152', 
        'densenet121', 
        'densenet169', 
        'densenet201', 
        'densenet161',
        'efficientnet-b0',
        'vgg16',
        'vgg19'
    ])

@dataclass
class ModuleConfig:
    """Hyper-parameters for custom network sub-modules (fields in positional order)."""
    reduction: int = 16          # channel reduction factor
    ratio: int = 8
    kernel_size: int = 3
    dropout_rt: float = 0.1      # dropout rate
    activation: type = torch.nn.ReLU  # activation class (not an instance)

@dataclass
class TransformsConfig:
    """Normalisation pipelines for the train / validation / test splits.

    NOTE(review): BandNormalisation indexes mean/std by channel position, and these lists
    cover all 12 bands in DatasetConfig.all_bands order. They only match images loaded
    with exactly that band selection.
    """
    # Temporary helpers shared by the three pipelines; deleted at the end of the class body.
    _band_mean = [DatasetConfig.band_stats["mean"][band] for band in DatasetConfig.all_bands]
    _band_std = [DatasetConfig.band_stats["std"][band] for band in DatasetConfig.all_bands]

    # Training: normalisation only (no augmentation, no crop).
    train_transforms = transforms.Compose([
        BandNormalisation(mean=_band_mean, std=_band_std),
    ])

    # Evaluation splits: centre-crop to 120 px, then normalise.
    val_transforms = transforms.Compose([
        transforms.CenterCrop(120),
        BandNormalisation(mean=_band_mean, std=_band_std),
    ])

    test_transforms = transforms.Compose([
        transforms.CenterCrop(120),
        BandNormalisation(mean=_band_mean, std=_band_std),
    ])

    del _band_mean, _band_std



In [None]:
class BigEarthNetDatasetTIF(Dataset):
    """Multi-label BigEarthNet dataset reading one multi-band GeoTIFF per patch.

    Args:
        df: metadata rows of this split. Its ``patch_id`` values decide which files are
            served, and in what order.
        root_dir: directory searched recursively for ``<patch_id>.tif`` files.
        transforms: optional callable applied to the (bands, H, W) float32 tensor.
        is_test: stored on the instance; not used by this class.
        selected_bands: band names to keep (default ``DatasetConfig.rgb_bands``). File
            band order is taken to be ``DatasetConfig.all_bands``.
        metadata_csv: metadata used to look labels up by ``patch_id``; defaults to ``df``.

    Items are ``(image, label)``, with ``label`` multi-hot encoded by ``encode_label``.
    """

    def __init__(self, *, df, root_dir, transforms=None, is_test=False, selected_bands=None, metadata_csv=None):
        self.df = df
        self.root_dir = root_dir
        self.transforms = transforms
        self.is_test = is_test
        self.selected_bands = selected_bands if selected_bands is not None else DatasetConfig.rgb_bands
        self.metadata = metadata_csv if metadata_csv is not None else df

        self.patch_to_labels = dict(zip(self.metadata['patch_id'], self.metadata['labels']))

        # Index the directory once, then keep only this split's patches (in df order), so
        # train/val/test each serve their own files and __len__ matches __getitem__.
        path_by_patch = {path.stem: path for path in Path(root_dir).rglob("*.tif")}
        self.image_paths = [path_by_patch[pid] for pid in df['patch_id'] if pid in path_by_patch]
        n_missing = len(df) - len(self.image_paths)
        if n_missing:
            warnings.warn(f"{n_missing} patch(es) listed in df have no .tif under {root_dir}; they are skipped.")

        self.selected_band_indices = get_band_indices(self.selected_bands, DatasetConfig.all_bands)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]

        with rasterio.open(image_path) as src:
            image = src.read()[self.selected_band_indices, :, :]

        # Image is converted to a tensor before applying transforms. Going through a
        # float32 NumPy array avoids torch's lack of uint16 support in older releases.
        image = torch.from_numpy(image.astype(np.float32))

        if self.transforms:
            image = self.transforms(image)

        label = self.get_label(image_path)

        return image, label

    def get_label(self, img_path):
        """Multi-hot label for the patch whose id is the file name without extension.

        Unknown patches get an all-zero vector of length ``DatasetConfig.num_classes``.
        """
        patch_id = Path(img_path).stem
        labels = self.patch_to_labels.get(patch_id, None)

        if labels is None:
            return torch.zeros(DatasetConfig.num_classes)

        if isinstance(labels, str):
            labels = clean_and_parse_labels(labels)

        return encode_label(labels)

In [None]:
# Data module for BigEarthNet dataset
class BigEarthNetTIFDataModule(pl.LightningDataModule):
    """LightningDataModule serving BigEarthNet GeoTIFF patches split by the metadata ``split`` column.

    Args:
        bands: band names to load (``None`` -> the dataset default ``DatasetConfig.rgb_bands``).
        dataset_dir: root directory searched recursively for ``*.tif`` patches.
        metadata_csv: DataFrame with ``patch_id``, ``labels`` and ``split``
            ('train' / 'validation' / 'test') columns.
    """

    def __init__(self, bands=None, dataset_dir=None, metadata_csv=None):
        super().__init__()
        self.bands = bands
        self.dataset_dir = dataset_dir
        self.metadata_csv = metadata_csv

    def _make_transforms(self, center_crop):
        # Normalisation statistics for exactly the bands loaded, in the same order, because
        # BandNormalisation pairs mean/std with image channels by position.
        bands = self.bands if self.bands is not None else DatasetConfig.rgb_bands
        steps = [transforms.CenterCrop(DatasetConfig.img_size)] if center_crop else []
        steps.append(BandNormalisation(
            mean=[DatasetConfig.band_stats["mean"][band] for band in bands],
            std=[DatasetConfig.band_stats["std"][band] for band in bands],
        ))
        return transforms.Compose(steps)

    def _make_dataset(self, split, center_crop):
        # Build the dataset for one value of the metadata `split` column.
        split_df = self.metadata_csv[self.metadata_csv['split'] == split]
        return BigEarthNetDatasetTIF(
            df=split_df,
            root_dir=self.dataset_dir,
            transforms=self._make_transforms(center_crop),
            selected_bands=self.bands,
            metadata_csv=self.metadata_csv,
        )

    def setup(self, stage=None):
        self.train_dataset = self._make_dataset('train', center_crop=False)
        self.val_dataset = self._make_dataset('validation', center_crop=True)
        self.test_dataset = self._make_dataset('test', center_crop=True)

    def _make_loader(self, dataset, shuffle=False):
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        num_workers = ModelConfig.num_workers
        return DataLoader(
            dataset,
            batch_size=ModelConfig.batch_size,
            num_workers=num_workers,
            pin_memory=True,
            shuffle=shuffle,
            persistent_workers=num_workers > 0,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)

In [None]:
# Same device selection as the setup cell, repeated so this cell defines `device` on its own.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class BigEarthNetResNet18ModelTIF(pl.LightningModule):
    """ResNet-18 multi-label classifier for multi-band BigEarthNet patches.

    ``forward`` returns raw logits of shape ``(batch, num_classes)``; apply
    ``torch.sigmoid`` to get per-class probabilities. Training uses
    ``BCEWithLogitsLoss`` directly on the logits, and the torchmetrics objects are fed
    sigmoid probabilities.

    Args:
        class_weights: per-class positive weights (array-like, one per class) used as
            ``pos_weight`` of ``BCEWithLogitsLoss``, or ``None`` for an unweighted loss.
        num_classes: number of output labels.
        in_channels: number of input bands; ``conv1`` is rebuilt for this many channels.
        model_weights: torchvision weights for ``resnet18`` (e.g.
            ``ResNet18_Weights.DEFAULT``) or ``None`` for random initialisation.
    """

    def __init__(self, class_weights, num_classes, in_channels, model_weights):
        super().__init__()
        self.num_classes = num_classes
        self.model = models.resnet18(weights=model_weights)

        # Replace the stem convolution so it accepts `in_channels` bands instead of RGB.
        original_conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            # Conv2d takes a bool; the original layer's `.bias` is a tensor or None.
            bias=original_conv1.bias is not None,
        )
        # The new stem is Kaiming-initialised for every input channel; the pretrained
        # conv1 weights are not reused.
        nn.init.kaiming_normal_(self.model.conv1.weight, mode='fan_out', nonlinearity='relu')

        # Classification head: one logit per label.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

        # Turns logits into probabilities for the metrics (deliberately not used in forward).
        self.sigmoid = nn.Sigmoid()

        # pos_weight must be a float tensor. BCEWithLogitsLoss registers it as a buffer, so
        # it moves with the module when Lightning places it on the training device.
        pos_weight = None if class_weights is None else torch.as_tensor(class_weights, dtype=torch.float32)
        self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        # Metrics per phase, exposed as `{phase}_acc`, `{phase}_recall`,
        # `{phase}_precision` and `{phase}_f1` (looked up by name in `_step`).
        for phase in ('train', 'val', 'test'):
            setattr(self, f'{phase}_acc', MultilabelAccuracy(num_labels=num_classes))
            setattr(self, f'{phase}_recall', MultilabelRecall(num_labels=num_classes))
            setattr(self, f'{phase}_precision', MultilabelPrecision(num_labels=num_classes))
            setattr(self, f'{phase}_f1', MultilabelF1Score(num_labels=num_classes))

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=ModelConfig.learning_rate)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=ModelConfig.lr_factor, patience=ModelConfig.lr_patience)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'frequency': 1
            }
        }

    def cross_entropy_loss(self, logits, labels):
        """Weighted binary cross-entropy on logits (method name kept for compatibility)."""
        return self.criterion(logits, labels)

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, 'test')

    def _step(self, batch, batch_idx, phase):
        """Shared train/val/test step: compute the loss on logits and update and log metrics."""
        x, y = batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)
        probs = self.sigmoid(logits)
        targets = y.int()

        self.log(f'{phase}_loss', loss, on_epoch=True, prog_bar=True)
        # Logging the Metric object (after updating it) lets Lightning compute the true
        # epoch-level value and reset the state at epoch end. This replaces the former
        # `on_epoch_end(phase)` override, which Lightning >= 1.8 rejects at fit time.
        for name in ('acc', 'recall', 'f1', 'precision'):
            metric = getattr(self, f'{phase}_{name}')
            metric(probs, targets)
            self.log(f'{phase}_{name}', metric, on_epoch=True, prog_bar=True)

        return loss

    def print_summary(self, input_size, filename):
        """Write a torchsummary report for ``input_size`` (C, H, W) to
        ``./FYPProjectMultiSpectral/models/Architecture/<filename>/<filename>_summary.txt``."""
        save_dir = os.path.join(os.getcwd(), 'FYPProjectMultiSpectral', 'models', 'Architecture', filename)
        save_path = os.path.join(save_dir, f'{filename}_summary.txt')
        os.makedirs(save_dir, exist_ok=True)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

        # torchsummary builds its own dummy input on the device named here (its default is "cuda").
        with open(save_path, 'w') as f, redirect_stdout(f):
            summary(self.model, input_size, device=device.type)

    def visualize_model(self, input_size, model_name):
        """Render the autograd graph of one forward pass with torchviz to
        ``./FYPProjectMultiSpectral/models/Architecture/<model_name>/<model_name>``."""
        save_path = os.path.join(os.getcwd(), 'FYPProjectMultiSpectral', 'models', 'Architecture', model_name)
        os.makedirs(save_path, exist_ok=True)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

        # Random input of the requested size, passed through the network to build the graph.
        x = torch.randn(1, *input_size, device=device)
        y = self.model(x)

        file_path = os.path.join(save_path, f'{model_name}')
        make_dot(y, params=dict(self.model.named_parameters())).render(file_path)


In [None]:
# Work on the 0.5% subset of BigEarthNet.
metadata_path = DatasetConfig.metadata_paths["0.5"]
metadata_csv = pd.read_csv(metadata_path)

dataset_dir = DatasetConfig.dataset_paths["0.5"]

# calculate_class_weights returns (labels, weight_by_label, weight_array, parsed_metadata).
# pos_weight is indexed exactly like DatasetConfig.class_labels_dict (the encoding used by
# encode_label), so each logit gets its own class's weight. Classes absent from this subset
# get a neutral 1.0, and the vector always has num_classes entries.
_, class_weight_by_label, _, _ = calculate_class_weights(metadata_path)
class_weights = np.ones(DatasetConfig.num_classes, dtype=np.float32)
for label, idx in DatasetConfig.class_labels_dict.items():
    class_weights[idx] = class_weight_by_label.get(label, 1.0)
class_weights_array = class_weights

bands = DatasetConfig.all_bands

# Initialize the data module
data_module = BigEarthNetTIFDataModule(bands=bands, dataset_dir=dataset_dir, metadata_csv=metadata_csv)
data_module.setup(stage=None)

In [None]:
train_dataloader = data_module.train_dataloader()
print(f"Number of training batches: {len(train_dataloader)}")

# Inspect one batch: shapes plus a small preview instead of dumping whole tensors.
inputs, labels = next(iter(train_dataloader))
print(f"Batch Shape: {inputs.shape}")
print(f"Labels Shape: {labels.shape}")
print(f"Inputs (sample 0, band 0, top-left 3x3): {inputs[0, 0, :3, :3]}")
print(f"Labels (first 4 samples): {labels[:4]}")

# Per-class positive counts. Labels are multi-hot, so sum over the batch axis;
# np.unique on the multi-hot matrix would only count the values 0.0 and 1.0.
label_totals = torch.zeros(labels.shape[1])
for _, batch_labels in train_dataloader:
    label_totals += batch_labels.sum(dim=0)

class_distribution = {
    DatasetConfig.reversed_class_labels_dict.get(idx, idx): int(count)
    for idx, count in enumerate(label_totals.tolist())
}
print(f"Class distribution: {class_distribution}")

In [None]:
# in_channels follows the band list given to the data module so the two cannot drift apart.
model = BigEarthNetResNet18ModelTIF(class_weights=class_weights, num_classes=DatasetConfig.num_classes, in_channels=len(bands), model_weights=ResNet18_Weights.DEFAULT)

In [None]:
print(model)

# Dummy input matching the loaded bands and patch size, created on the device that holds
# the model's parameters so the forward pass works whether they sit on CPU or GPU.
model_device = next(model.parameters()).device
sample_input = torch.randn(1, len(bands), DatasetConfig.img_size, DatasetConfig.img_size, device=model_device)

# Perform a forward pass
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    sample_output = model(sample_input)

# Print the output shape
print(f"Output shape: {sample_output.shape}")

In [None]:
# Register forward hooks on every Conv2d to capture its output in `activations`.
# Handles are kept so re-running this cell first removes the previous hooks instead of
# stacking duplicates on each layer.
for handle in globals().get("activation_hook_handles", []):
    handle.remove()
activation_hook_handles = []
layer_names = []
for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        activation_hook_handles.append(layer.register_forward_hook(get_activation(name)))
        layer_names.append(name)

In [None]:
# Identifiers that name the TensorBoard run and the checkpoint folder.
selected_dataset = '0.5'                 # subset label, same key as used for the data paths above
selected_bands = 'all_bands'             # name of the DatasetConfig band list in use
weights = 'ResNet18_Weights.DEFAULT'     # label for the pretrained weights
model_name = 'resnet18'

In [None]:
log_dir = r'C:\Users\isaac\Desktop\experiments\logs'
logger = TensorBoardLogger(log_dir, name=f"{model_name}_{weights}_{selected_bands}_experiment_{selected_dataset}")


checkpoint_dir = fr'C:\Users\isaac\Desktop\experiments\checkpoints\{model_name}_{weights}_{selected_bands}_{selected_dataset}'

# Checkpoint callback for val_loss
checkpoint_callback_loss = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename=f'{{epoch:02d}}-{{val_loss:.2f}}',
    save_top_k=1,
    verbose=False,
    monitor='val_loss',
    mode='min'
)

# Checkpoint callback for val_acc
checkpoint_callback_acc = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename=f'{{epoch:02d}}-{{val_acc:.2f}}',
    save_top_k=1,
    verbose=False,
    monitor='val_acc',
    mode='max'
)

# Unmonitored checkpoint named 'final', plus 'last.ckpt'.
final_checkpoint = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename='final',
    save_last=True
)

# Early stopping callback
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=ModelConfig.patience,
    verbose=True,
    mode='min'
)

# fp16 mixed precision is GPU-only, and `devices=None` is not a valid device count, so the
# CPU fallback uses a single device with full fp32 precision.
use_gpu = torch.cuda.is_available()

# Model Training with custom callbacks
trainer = pl.Trainer(
    default_root_dir=checkpoint_dir,
    max_epochs=ModelConfig.num_epochs,
    logger=logger,
    accelerator='gpu' if use_gpu else 'cpu',
    devices=1,
    precision='16-mixed' if use_gpu else '32-true',
    log_every_n_steps=1,
    accumulate_grad_batches=2,
    callbacks=[checkpoint_callback_loss, checkpoint_callback_acc, final_checkpoint, early_stopping]
)

trainer.fit(model, data_module)

In [None]:
# Visualise the conv activations captured by the forward hooks. Each hook overwrites its
# entry in `activations`, so this shows the outputs of the most recent forward pass
# (whatever batch ran last), not a per-epoch history.
visualize_activations(layer_names, activations)

In [None]:
# Reload the checkpoint with the lowest validation loss and evaluate it on the test split.
checkpoint_path = checkpoint_callback_loss.best_model_path
model = BigEarthNetResNet18ModelTIF.load_from_checkpoint(
    checkpoint_path,
    class_weights=class_weights,
    num_classes=DatasetConfig.num_classes,
    in_channels=len(bands),
    # Every weight comes from the checkpoint, so skip fetching pretrained ImageNet weights.
    model_weights=None,
)
model.eval()

# Same device/precision rules as the training cell: fp16 mixed precision only on GPU.
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    accelerator='gpu' if use_gpu else 'cpu',
    devices=1,
    precision='16-mixed' if use_gpu else '32-true',
)
# Run test
trainer.test(model, datamodule=data_module)

In [None]:
# Collect predictions and true labels
# Runs the model over the whole test split and accumulates one NumPy row per sample of
# thresholded predictions (bool) and ground-truth multi-hot labels.
all_preds = []
all_labels = []

for batch in tqdm(data_module.test_dataloader(), desc="Processing Batches"):
    inputs, labels = batch
    inputs = inputs.to(model.device)
    labels = labels.to(model.device)

    with torch.no_grad():
        # NOTE(review): a sigmoid is applied below, so `model(...)` must return raw logits.
        # If the model's forward already applies a sigmoid, every value lands in (0.5, 0.73)
        # after the second sigmoid and every class is predicted positive.
        logits = model(inputs)  
        #print(f"Raw logits: {logits}")  
        preds = torch.sigmoid(logits) > 0.5
        #print(f"Sigmoid outputs: {torch.sigmoid(logits)}")  

    all_preds.extend(preds.cpu().numpy())
    all_labels.extend(labels.cpu().numpy())