# Environment setup

In [16]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from typing import Optional
from torchvision.transforms import transforms

from dataset.dataset_info import ClassifierDatasetInfo
from dataset.dataset_type import DatasetType
from dataset.STARCOP_dataset import STARCOPDataset

from models.Tools.FilesHandler.model_files_handler import ModelFilesHandler
from models.Tools.Measures.measure_tool_factory import MeasureToolFactory
from models.Tools.Measures.model_type import ModelType

import os
os.chdir(r"D:\Projects\studia\polsl_ssi_1\MethaneDetection\Thesis")

## Setup datasets
STARCOPDataset is custom class that derives torch.utils.data.Dataset class. It's defined in dataset module.

In [17]:
def setup_dataloaders(data_path: str = r"data", batch_size: int = 32, train_type = DatasetType.EASY_TRAIN):
    train_dataset = STARCOPDataset(
        data_path=data_path,
        data_type=train_type,
        image_info_class=ClassifierDatasetInfo,
    )
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = STARCOPDataset(
        data_path=data_path,
        data_type=DatasetType.TEST,
        image_info_class=ClassifierDatasetInfo,
    )
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    return train_dataloader, test_dataloader

## Setup models

### Model class

In [18]:
import torch
import torch.nn as nn


class MethaNetClassifier(nn.Module):
    def __init__(self, in_channels:int = 9, num_classes:int = 2):
        super(MethaNetClassifier, self).__init__()
        self.pre_conv = nn.Sequential(
            nn.Conv2d(in_channels//1, in_channels//2, kernel_size=1),
            nn.Conv2d(in_channels//2, in_channels//4, kernel_size=1),
            nn.Conv2d(in_channels//4, in_channels//8, kernel_size=1),
        )

        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=2, stride=1, padding=0),  # Input channels = 8
            nn.ReLU(),

            nn.Conv2d(6, 12, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(12, 16, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Dropout(p=0.2),

            nn.Conv2d(16, 16, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 61 * 61, 64),  # Adjust dimensions for 512x512 input
            nn.ReLU(),

            nn.Linear(64, 32),
            nn.ReLU(),

            nn.Linear(32, num_classes),
            nn.Softmax(dim=1)  # Softmax for class probabilities
        )

    def forward(self, x):
        x = self.pre_conv(x)
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x


### Prepare models

In [19]:
def setup_model(model: nn.Module, lr: float, device: str):
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()  # Binary classification
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    return model, criterion, optimizer

## Prepare training function

In [20]:
def train_cnn(criterion, device, epochs, model, optimizer, dataloader, transform: Optional[transforms] = None, log_batches: bool = False):
    model.train()
    for epoch in range(epochs):  # Adjust the number of epochs
        running_loss = 0.0
        for batch_id, (images, mag1c, labels) in enumerate(dataloader):  # Assume a PyTorch DataLoader is set up
            optimizer.zero_grad()

            input_image = torch.cat((images, mag1c), dim=1)
            labels = labels.long().to(device)

            outputs = model((transform(input_image) if transform else  input_image).to(device))

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")

## Prepare evaluate function

In [21]:
def evaluate_cnn(criterion, device, model, dataloader, measurer, transform: Optional[transforms] = None):
    model.eval()
    all_predictions = []
    all_labels = []
    running_loss = 0.0

    for batch_id, (images, mag1c, labels) in enumerate(dataloader):
        input_image = torch.cat((images, mag1c), dim=1)
        labels = labels.long().to(device)

        outputs = model((transform(input_image) if transform else  input_image).to(device))
        predictions = torch.argmax(outputs, dim=1)
        loss = criterion(outputs, labels)

        running_loss += loss.item()
        all_predictions.append(predictions.cpu().detach())
        all_labels.append(labels.cpu().detach())

    measures = measurer.compute_measures(torch.cat(all_predictions), torch.cat(all_labels))
    print(f"Validation loss: {running_loss / len(dataloader)}.\nMeasures:\n{measures}")
    return measures

# Train model

In [22]:
epochs = 10
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 1e-5

train_dataloader, test_dataloader = setup_dataloaders(batch_size=16)
model = MethaNetClassifier()
model, criterion, optimizer = setup_model(model, lr, device)

model_handler = ModelFilesHandler()
measurer = MeasureToolFactory.get_measure_tool(ModelType.CNN)

train_cnn(criterion, device, epochs, model, optimizer, train_dataloader, log_batches=True)
measures = evaluate_cnn(criterion, device, model, test_dataloader, measurer)

Epoch 1, Loss: 0.6540219289915902
Epoch 2, Loss: 0.555686570916857
Epoch 3, Loss: 0.48661099672317504
Epoch 4, Loss: 0.4395524263381958
Epoch 5, Loss: 0.408256835596902
Epoch 6, Loss: 0.3862720421382359
Epoch 7, Loss: 0.36930905069623676
Epoch 8, Loss: 0.3594409908567156
Epoch 9, Loss: 0.34937363607542854
Epoch 10, Loss: 0.3438535213470459
Validation loss: 0.5237618982791901.
Measures:
         TP        FP        FN        TN  Precision  Sensitivity  \
0  0.263158  0.002924  0.225146  0.508772   0.989011     0.538922   

   Specificity       NPV       FPR        Accuracy   F-Score             IoU  \
0     0.994286  0.693227  0.005714  tensor(0.7719)  0.697674  tensor(0.5357)   

        MCC       AUC          CI  
0  0.603137  0.766604  (1.0, 1.0)  


# Save model

In [23]:
model_handler.save_model(
    model=model,
    metrics = measures,
    model_type=ModelType.CNN,
    epoch=epochs,
)

'trained_models\\model_cnn_2024_12_12_13_47_08.pickle'