# Environment setup

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from dataset.dataset_info import ClassifierDatasetInfo, MMClassifierDatasetInfo
from dataset.dataset_type import DatasetType
from dataset.STARCOP_dataset import STARCOPDataset

from models.Tools.FilesHandler.model_files_handler import ModelFilesHandler
from models.Tools.Measures.measure_tool_factory import MeasureToolFactory
from models.Tools.Measures.model_type import ModelType as MeasuresModelType

import os
import sys
# Work from the thesis project root so the relative paths used later in this
# notebook (e.g. the "data" directory and the "trained_models\..." save path)
# resolve. Set the THESIS_ROOT environment variable to point elsewhere on other
# machines; the original absolute path is kept as the fallback.
os.chdir(os.environ.get("THESIS_ROOT", r"D:\Projects\studia\polsl_ssi_1\MethaneDetection\Thesis"))

## Set up datasets
`STARCOPDataset` is a custom class that derives from `torch.utils.data.Dataset`. It is defined in the `dataset` module.

In [2]:
from typing import Type
from dataset.dataset_info import DatasetInfo


def setup_dataloaders(
        data_path: str = r"data",
        batch_size: int = 32,
        train_type: DatasetType = DatasetType.EASY_TRAIN,
        image_info_class: Type[DatasetInfo] = ClassifierDatasetInfo,
        crop_size: int = 1,
        shuffle_test: bool = False,
):
    """
    Build the training and test ``DataLoader``s over ``STARCOPDataset``.

    Args:
        data_path: Dataset root directory (relative to the current working directory
            unless absolute).
        batch_size: Batch size of the training loader. The test loader uses half of
            it, but never less than 1.
        train_type: ``DatasetType`` split used for training. The test loader always
            uses ``DatasetType.TEST``.
        image_info_class: ``DatasetInfo`` subclass forwarded to ``STARCOPDataset``.
        crop_size: Forwarded unchanged to ``STARCOPDataset``.
        shuffle_test: Whether to shuffle the test loader. Defaults to ``False`` so
            evaluation always visits samples in the same order.

    Returns:
        ``(train_dataloader, test_dataloader)``. The training loader is always shuffled.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    train_dataset = STARCOPDataset(
        data_path=data_path,
        data_type=train_type,
        image_info_class=image_info_class,
        crop_size=crop_size
    )
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = STARCOPDataset(
        data_path=data_path,
        data_type=DatasetType.TEST,
        image_info_class=image_info_class,
        crop_size=crop_size
    )
    # Half the training batch size, clamped to 1: DataLoader rejects batch_size=0,
    # which plain `batch_size // 2` would produce for batch_size=1.
    test_batch_size = max(1, batch_size // 2)
    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=shuffle_test)

    return train_dataloader, test_dataloader

## Set up models

### Model class

In [3]:
from models.Transformer.MethaneMapper.Segmentation.bbox_prediction import BBoxPrediction
from models.Transformer.MethaneMapper.Segmentation.segmentation import BoxAndMaskPredictor
from models.Transformer.MethaneMapper.Classification.classification import ClassifierPredictor
from models.Transformer.MethaneMapper.Transformer.hyperspectral_decoder import HyperspectralDecoder
from models.Transformer.MethaneMapper.Transformer.query_refiner import QueryRefiner
from models.Transformer.MethaneMapper.Transformer.encoder import Encoder
from models.Transformer.MethaneMapper.Transformer.position_encoding import PositionalEncodingMM
from models.Transformer.MethaneMapper.SpectralFeatureGenerator.spectral_feature_generator import \
    SpectralFeatureGenerator
from models.Transformer.MethaneMapper.Backbone.backbone import Backbone
from models.Transformer.MethaneMapper.model_type import ModelType

class TransformerModel(nn.Module):
    """
    TODO docs, verification, tests
    """

    def __init__(self,
                 d_model: int = 256,
                 backbone_out_channels: int = 2048,
                 image_height: int = 512,
                 image_width: int = 512,
                 attention_heads: int = 8,
                 n_encoder_layers: int = 6,
                 n_decoder_layers: int = 6,
                 n_queries: int = 100,
                 threshold: float = 0.5,
                 model_type: ModelType = ModelType.CLASSIFICATION,
                 ):
        super(TransformerModel, self).__init__()

        self.d_model = d_model

        self.backbone = Backbone(d_model=d_model, rgb_channels=3, swir_channels=5, out_channels=backbone_out_channels)
        self.spectral_feature_generator = SpectralFeatureGenerator(d_model=d_model)

        self.positional_encoding = PositionalEncodingMM(
            d_model=d_model
        )
        self.encoder = Encoder(d_model=d_model, n_heads=attention_heads, num_layers=n_encoder_layers)

        self.query_refiner = QueryRefiner(d_model=d_model, num_heads=attention_heads, num_queries=n_queries)
        self.decoder = HyperspectralDecoder(d_model=d_model, n_heads=attention_heads, num_layers=n_decoder_layers)


        self.head = None
        match model_type:
            case ModelType.CLASSIFICATION:
                self.head = ClassifierPredictor(
                    num_classes=2,
                    embedding_dim=d_model,
                )
            case ModelType.SEGMENTATION:
                self.head = BoxAndMaskPredictor(
                    result_width=image_width,
                    result_height=image_height,
                    fpn_channels=backbone_out_channels,
                    embedding_dim=d_model,
                )
            case ModelType.ONLY_BBOX:
                self.head = BBoxPrediction(d_model=d_model)

    def forward(self, image: torch.Tensor, filtered_image: torch.Tensor) -> tuple[
        torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        TODO docs, tests
        """
        # get image size
        batch_size, channels, height, width = image.shape

        f_comb_proj, f_comb = self.backbone(image)

        positional_encoding = self.positional_encoding(f_comb)[0].expand(batch_size, -1, -1, -1)

        f_mc = self.spectral_feature_generator(filtered_image)
        f_mc = f_mc.permute(0, 2, 3, 1)

        q_ref = self.query_refiner(f_mc)
        f_e = self.encoder((f_comb_proj + positional_encoding).flatten(2).permute(0, 2, 1))


        e_out = self.decoder(
            (f_e.permute(0, 2, 1).view(batch_size, -1, int(height / 32), int(width / 32)) + positional_encoding).flatten(2).permute(0, 2, 1),
            q_ref
        )

        result = self.head(e_out, f_e)

        return result

### Prepare models 

In [4]:
def setup_model(model: nn.Module, lr: float, device: str):
    """
    Place ``model`` on ``device`` and build its training objects.

    Returns:
        A ``(model, criterion, optimizer)`` tuple: the moved model, an
        ``nn.CrossEntropyLoss`` (applied to the model's class logits) and an
        Adam optimizer over the model's parameters with learning rate ``lr``.
    """
    placed_model = model.to(device)
    return (
        placed_model,
        nn.CrossEntropyLoss(),
        torch.optim.Adam(placed_model.parameters(), lr=lr),
    )

## Prepare training function

In [5]:
def print_progress_bar(percentage, loss):
    """
    Redraw a 50-character text progress bar in place.

    The written line begins with a carriage return and has no trailing newline,
    so consecutive calls overwrite each other.

    Args:
        percentage: Progress in percent.
        loss: Loss value to display (6 decimal places).
    """
    width = 50
    done = int(width * percentage // 100)
    segments = ['=' * done, '-' * (width - done)]
    line = "\r[" + "".join(segments) + f"] {percentage:.2f}% [Loss: {loss:.6f}]"
    out = sys.stdout
    out.write(line)
    out.flush()

In [6]:
def train(criterion, device, epochs, model, optimizer, dataloader, model_handler,  log_batches: bool = False):
    """
    Train ``model`` for ``epochs`` epochs, saving it after every epoch.

    Each batch yields ``(images, mag1c, filtered_image, labels)``. ``images`` and
    ``mag1c`` are concatenated along dim 1 to form the first model input,
    ``filtered_image`` is the second input, and labels are cast to ``long``.

    Args:
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.
        epochs: Number of passes over ``dataloader``.
        model: Module called as ``model(input_image, filtered_image)``.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Sized iterable of training batches.
        model_handler: Object whose ``save_raw_model(model)`` is called after
            every epoch.
        log_batches: If True, redraw a progress bar every 10 batches.

    Raises:
        ValueError: if ``dataloader`` is empty (the mean loss would be undefined).
    """
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError("train() received an empty dataloader")

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for batch_id, (images, mag1c, filtered_image, labels) in enumerate(dataloader):
            optimizer.zero_grad()

            input_image = torch.cat((images, mag1c), dim=1).to(device)
            filtered_image = filtered_image.to(device)
            labels = labels.long().to(device)

            outputs = model(input_image, filtered_image)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Count completed batches so the bar reaches 100% on the last batch.
            completed = batch_id + 1
            if log_batches and completed % 10 == 0:
                print_progress_bar(completed / n_batches * 100, running_loss / completed)

        if log_batches:
            # The progress bar leaves the cursor on its own line (no newline);
            # move to a fresh line before the epoch summary.
            print()
        print(f"Epoch {epoch + 1}, Loss: {running_loss / n_batches}")
        model_handler.save_raw_model(model)

## Prepare evaluate function

In [7]:
def evaluate(criterion, device, model, dataloader, measurer):
    """
    Evaluate ``model`` on ``dataloader`` and compute metrics with ``measurer``.

    Each batch yields ``(images, mag1c, filtered_image, labels)``. ``images`` and
    ``mag1c`` are concatenated along dim 1 to form the model input. Predictions
    are the argmax of the model output over dim 1.

    Args:
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.
        model: Module called as ``model(input_image, filtered_image)``; it is put
            into eval mode and left there.
        dataloader: Sized iterable of evaluation batches.
        measurer: Object whose ``compute_measures(predictions, labels)`` receives
            all predictions and labels concatenated on the CPU.

    Returns:
        The value returned by ``measurer.compute_measures``.

    Raises:
        ValueError: if ``dataloader`` is empty.
    """
    if len(dataloader) == 0:
        raise ValueError("evaluate() received an empty dataloader")

    model.eval()
    all_predictions = []
    all_labels = []
    running_loss = 0.0

    # Gradients are not needed for evaluation; disabling autograd avoids building
    # a backward graph for every batch and saves memory.
    with torch.no_grad():
        for images, mag1c, filtered_image, labels in dataloader:
            input_image = torch.cat((images, mag1c), dim=1).to(device)
            filtered_image = filtered_image.to(device)
            labels = labels.long().to(device)

            outputs = model(input_image, filtered_image)

            predictions = torch.argmax(outputs, dim=1)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            all_predictions.append(predictions.cpu())
            all_labels.append(labels.cpu())

    measures = measurer.compute_measures(torch.cat(all_predictions), torch.cat(all_labels))
    print(f"Validation loss: {running_loss / len(dataloader)}.\nMeasures:\n{measures}")
    return measures

# Train model 

In [8]:
# Run configuration.
epochs = 15
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 1e-4

# Seed torch's RNGs (all devices) so weight initialisation and DataLoader
# shuffling are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): this import shadows the `TransformerModel` class defined in the
# "Model class" cell above, so the model trained below is the packaged version,
# not the notebook copy. Remove this line to train the notebook's class instead
# (confirm the two implementations match first).
from models.Transformer.MethaneMapper.model import TransformerModel

train_dataloader, test_dataloader = setup_dataloaders(
    batch_size=8,
    image_info_class=MMClassifierDatasetInfo,
    crop_size=1,
    train_type=DatasetType.TRAIN,
)
model = TransformerModel(
    n_queries=5,
    n_decoder_layers=5,
    n_encoder_layers=5,
    d_model=256,
)
model, criterion, optimizer = setup_model(model, lr, device)
model_handler = ModelFilesHandler()
measurer = MeasureToolFactory.get_measure_tool(MeasuresModelType.TRANSFORMER)

train(criterion, device, epochs, model, optimizer, train_dataloader, model_handler, log_batches=True)
measures = evaluate(criterion, device, model, test_dataloader, measurer)


Validation loss: 0.6195664371006451.
Measures:
         TP        FP        FN        TN  Precision  Sensitivity  \
0  0.280702  0.020468  0.207602  0.491228   0.932036     0.574849   

   Specificity       NPV   FPR  Accuracy   F-Score       IoU      MCC  \
0     0.959998  0.702928  0.04  0.771929  0.711109  0.551724  0.58276   

        AUC        CI  
0  0.767425  0.048693  


# Save model

In [9]:
# Persist the trained model together with its evaluation measures; the
# returned file path is displayed as the cell output.
saved_model_path = model_handler.save_model(
    model=model,
    metrics=measures,
    model_type=MeasuresModelType.TRANSFORMER_CLASSIFIER,
    epoch=epochs,
)
saved_model_path

'trained_models\\model_transformer_classifier_2024_12_16_07_53_32.pickle'