# Environment setup

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from typing import Optional, Type

from transformers import DetrConfig, DetrForObjectDetection

from dataset.dataset_info import ClassifierDatasetInfo, DatasetInfo
from dataset.dataset_type import DatasetType
from dataset.STARCOP_dataset import STARCOPDataset

from models.Tools.FilesHandler.model_files_handler import ModelFilesHandler
from models.Tools.Measures.measure_tool_factory import MeasureToolFactory
from models.Tools.Measures.model_type import ModelType

import os
import sys
os.chdir(r"C:\Users\mpilc\Desktop\Studia\Thesis\Repozytoria\Thesis")

  from .autonotebook import tqdm as notebook_tqdm


## Setup datasets
STARCOPDataset is custom class that derives torch.utils.data.Dataset class. It's defined in dataset module.  

In [2]:
def setup_dataloaders(
        data_path: str = r"data",
        batch_size: int = 32,
        train_type = DatasetType.EASY_TRAIN,
        image_info_class: Type[DatasetInfo] = ClassifierDatasetInfo,
        crop_size: int = 1
):
    train_dataset = STARCOPDataset(
        data_path=data_path,
        data_type=train_type,
        image_info_class=image_info_class,
        crop_size=crop_size
    )
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = STARCOPDataset(
        data_path=data_path,
        data_type=DatasetType.TEST,
        image_info_class=image_info_class,
        crop_size=crop_size
    )
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    return train_dataloader, test_dataloader

## Setup models

### Model class

In [3]:
class CustomDetrForClassification(nn.Module):
    def __init__(self, detr_model_name="facebook/detr-resnet-50", num_channels=9, num_classes=2):
        super().__init__()

        # Load pre-trained DETR model
        config = DetrConfig.from_pretrained(detr_model_name)
        config.num_labels = num_classes  # Number of classification labels
        config.use_decoder = True  # Ensure the decoder is retained for processing queries
        config.output_hidden_states = True  # Ensure hidden states are returned
        self.detr = DetrForObjectDetection(config=config)

        # Modify the first convolutional layer of the backbone to accept 9 channels
        backbone = self.detr.model.backbone
        conv1 = backbone.conv_encoder.model.conv1
        new_conv1 = nn.Conv2d(
            in_channels=num_channels,
            out_channels=conv1.out_channels,
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias,
        )

        # Replace the original conv1 with the new one
        backbone.conv_encoder.model.conv1 = new_conv1

        # Freeze backbone layers except the first conv layer
        for name, param in backbone.named_parameters():
            if "conv1" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

        # Add a classification head to process the outputs of the decoder
        hidden_size = config.d_model
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, pixel_values):
        # Pass inputs through DETR backbone and transformer
        outputs = self.detr.model(pixel_values)

        # Extract decoder output (shape: batch_size, num_queries, d_model)
        decoder_output = outputs.decoder_hidden_states[-1]

        # Apply classification head (average over all queries)
        logits = self.classifier(decoder_output.mean(dim=1))  # (batch_size, num_classes)

        return logits

### Prepare models 

In [4]:
def setup_model(model: nn.Module, lr: float, device: str):
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()  # Binary classification
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    return model, criterion, optimizer

## Prepare training function

In [5]:
def print_progress_bar(percentage, loss):
    bar_length = 50  # Length of the progress bar
    filled_length = int(bar_length * percentage // 100)
    bar = '=' * filled_length + '-' * (bar_length - filled_length)
    sys.stdout.write(f"\r[{bar}] {percentage:.2f}% [Loss: {loss:.6f}]")
    sys.stdout.flush()

In [6]:
def train(criterion, device, epochs, model, optimizer, dataloader, model_handler, log_batches: bool = False):
    model.train()
    for epoch in range(epochs):
        print(f"Epoch: {epoch}")
        running_loss = 0.0
        for batch_id, (images, mag1c, labels) in enumerate(dataloader):
            optimizer.zero_grad()

            input_image = torch.cat((images, mag1c), dim=1).to(device)
            labels = labels.long().to(device)

            outputs = model(input_image)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if log_batches and (batch_id + 1) % 10 == 0:
                print_progress_bar(batch_id / len(dataloader) * 100, running_loss / (batch_id + 1))

        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")
        model_handler.save_raw_model(model)

## Prepare evaluate function

In [7]:
def evaluate(criterion, device, model, dataloader, measurer):
    model.eval()
    all_predictions = []
    all_labels = []
    running_loss = 0.0

    for batch_id, (images, mag1c, labels) in enumerate(dataloader):
        input_image = torch.cat((images, mag1c), dim=1).to(device)
        labels = labels.long().to(device)

        outputs = model(input_image)
        predictions = torch.argmax(outputs, dim=1)
        loss = criterion(outputs, labels)

        running_loss += loss.item()
        all_predictions.append(predictions.cpu().detach())
        all_labels.append(labels.cpu().detach())

    measures = measurer.compute_measures(torch.cat(all_predictions), torch.cat(all_labels))
    print(f"Validation loss: {running_loss / len(dataloader)}.\nMeasures:\n{measures}")
    return measures

# Train model 

In [8]:
epochs = 15
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 1e-4

# Override for file handler
from models.Transformer.DETR.model import CustomDetrForClassification


train_dataloader, test_dataloader = setup_dataloaders(batch_size=16, train_type=DatasetType.TRAIN)
model = CustomDetrForClassification()
model, criterion, optimizer = setup_model(model, lr, device)
model_handler = ModelFilesHandler()
measurer = MeasureToolFactory.get_measure_tool(ModelType.TRANSFORMER)

train(criterion, device, epochs, model, optimizer, train_dataloader, model_handler, log_batches = True)
measures = evaluate(criterion, device, model, test_dataloader, measurer)


Epoch: 0
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9
Epoch: 10
Epoch: 11
Epoch: 12
Epoch: 13
Epoch: 14
Validation loss: 0.2711253989149224.
Measures:
         TP        FP        FN        TN  Precision  Sensitivity  \
0  0.375367  0.014663  0.111437  0.498534   0.962404     0.771083   

   Specificity       NPV       FPR  Accuracy   F-Score       IoU       MCC  \
0     0.971427  0.817306  0.028571  0.873899  0.856185  0.748538  0.760883   

        AUC        CI  
0  0.871256  0.051847  


# Save model

In [3]:
model_handler.save_model(
    model=model,
    metrics = measures,
    model_type=ModelType.DETR,
    epoch=epochs,
)

'trained_models\\model_detr_2024_12_15_11_35_17.pickle'
