# Environment setup

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16
from torchvision.transforms import transforms
from typing import Callable, Optional

from dataset.dataset_info import ClassifierDatasetInfo
from dataset.dataset_type import DatasetType
from dataset.STARCOP_dataset import STARCOPDataset

from models.Tools.FilesHandler.model_files_handler import ModelFilesHandler
from models.Tools.Measures.measure_tool_factory import MeasureToolFactory
from models.Tools.Measures.model_type import ModelType

import os
# Switch the working directory to the repository root so that relative paths
# used later in this notebook (e.g. the default "data" directory) resolve.
# NOTE(review): this is a machine-specific absolute path, so the notebook will
# not run elsewhere without editing it. It also executes after the
# project-local imports above, so those imports must already be resolvable
# before this line runs — confirm the intended launch directory.
os.chdir(r"C:\Users\mpilc\Desktop\Studia\Thesis\Repozytoria\Thesis")

## Setup datasets
`STARCOPDataset` is a custom class that derives from `torch.utils.data.Dataset`. It is defined in the `dataset` module.

In [2]:
def setup_dataloaders(data_path: str = r"data", batch_size: int = 32, train_type = DatasetType.EASY_TRAIN):
    """Build the training and test dataloaders over the STARCOP data.

    Args:
        data_path: Value forwarded to ``STARCOPDataset`` as ``data_path``.
        batch_size: Batch size used by both dataloaders.
        train_type: ``DatasetType`` member selecting the training split.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. The training loader
        shuffles its samples; the test loader iterates in dataset order.
    """
    def build_dataset(data_type):
        # Both splits share every setting except the split selector.
        return STARCOPDataset(
            data_path=data_path,
            data_type=data_type,
            image_info_class=ClassifierDatasetInfo,
            crop_size=2,
        )

    train_dataloader = DataLoader(build_dataset(train_type), batch_size=batch_size, shuffle=True)

    # Shuffling brings nothing to evaluation; a fixed order keeps the
    # evaluation pass reproducible from run to run.
    test_dataloader = DataLoader(build_dataset(DatasetType.TEST), batch_size=batch_size, shuffle=False)

    return train_dataloader, test_dataloader

## Setup models

### Model class

In [3]:
class CustomViT(nn.Module):
    """ViT-B/16 classifier adapted to a configurable number of input channels.

    Wraps torchvision's ``vit_b_16`` and replaces two of its parts: the input
    embedding convolution (``conv_proj``), so that it accepts ``num_channels``
    channels, and the classification head, which becomes a small two-layer
    MLP producing ``num_classes`` logits.
    """

    def __init__(self, num_channels=9, num_classes=2):
        super().__init__()

        # weights=None: no pre-trained weights are loaded. Pass a weights
        # enum here instead if a pre-trained backbone is wanted.
        backbone = vit_b_16(weights=None)

        # Rebuild the input embedding convolution with the requested number
        # of input channels, copying every other setting from the stock layer.
        stock_proj = backbone.conv_proj
        backbone.conv_proj = nn.Conv2d(
            num_channels,
            stock_proj.out_channels,
            kernel_size=stock_proj.kernel_size,
            stride=stock_proj.stride,
            padding=stock_proj.padding,
            bias=stock_proj.bias is not None,
        )

        # Swap the stock head for: Linear -> ReLU -> Linear(num_classes).
        head_inputs = backbone.heads.head.in_features
        backbone.heads = nn.Sequential(
            nn.Linear(head_inputs, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

        self.vit = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped ViT and return its output."""
        return self.vit(x)

### Prepare models 

In [4]:
def setup_model(model: nn.Module, lr: float, device: str):
    """Move ``model`` to ``device`` and create its loss and optimizer.

    Args:
        model: Network to train.
        lr: Learning rate handed to the Adam optimizer.
        device: Target device passed to ``model.to``.

    Returns:
        A ``(model, criterion, optimizer)`` tuple, where ``criterion`` is a
        cross-entropy loss and ``optimizer`` is Adam over the model parameters.
    """
    placed_model = model.to(device)

    # The optimizer is built from the parameters after the device move.
    adam = torch.optim.Adam(placed_model.parameters(), lr=lr)

    # Cross-entropy over the class logits (two classes by default here).
    loss_fn = nn.CrossEntropyLoss()

    return placed_model, loss_fn, adam

## Prepare training function

In [5]:
def train(criterion, device, epochs, model, optimizer, dataloader, transform: Optional[Callable] = None, log_batches: bool = False):
    """Train ``model`` for ``epochs`` passes over ``dataloader``.

    Each batch must unpack to ``(images, mag1c, labels)``. ``images`` and
    ``mag1c`` are concatenated along dim 1 (presumably the channel axis —
    confirm against the dataset), optionally transformed, moved to ``device``
    and fed to the model.

    Args:
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the inputs and labels are moved to.
        epochs: Number of passes over ``dataloader``.
        model: Network to optimise; switched to training mode.
        optimizer: Optimizer stepped once per batch.
        dataloader: Iterable of ``(images, mag1c, labels)`` batches.
        transform: Optional callable applied to the concatenated input before
            it is moved to ``device``.
        log_batches: When true, print the running mean loss every 10 batches.

    Returns:
        None. Progress is reported through ``print``.
    """
    model.train()
    for epoch in range(epochs):
        print(f"Epoch: {epoch}")
        running_loss = 0.0
        num_batches = 0
        for batch_id, (images, mag1c, labels) in enumerate(dataloader):
            optimizer.zero_grad()

            input_image = torch.cat((images, mag1c), dim=1)
            # Explicit None check: a container-like transform that happens to
            # be empty would be falsy and silently skipped otherwise.
            if transform is not None:
                input_image = transform(input_image)
            labels = labels.long().to(device)

            outputs = model(input_image.to(device))

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            num_batches = batch_id + 1
            if log_batches and num_batches % 10 == 0:
                print(f"Batch: {num_batches}, Loss: {running_loss / num_batches}")

        # Average over the batches actually processed. This avoids a
        # ZeroDivisionError on an empty dataloader and does not require the
        # dataloader to implement __len__.
        if num_batches == 0:
            print(f"Epoch {epoch + 1}: dataloader produced no batches, nothing to train on.")
            continue
        print(f"Epoch {epoch + 1}, Loss: {running_loss / num_batches}")

## Prepare evaluation function

In [6]:
def evaluate(criterion, device, model, dataloader, measurer, transform: Optional[Callable] = None):
    """Evaluate ``model`` on ``dataloader`` and compute classification measures.

    Each batch must unpack to ``(images, mag1c, labels)``. ``images`` and
    ``mag1c`` are concatenated along dim 1 (presumably the channel axis —
    confirm against the dataset), optionally transformed, moved to ``device``
    and fed to the model. The predicted class is the argmax over dim 1 of the
    model output.

    Args:
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the inputs and labels are moved to.
        model: Network to evaluate; switched to evaluation mode.
        dataloader: Iterable of ``(images, mag1c, labels)`` batches.
        measurer: Object exposing ``compute_measures(predictions, labels)``.
        transform: Optional callable applied to the concatenated input before
            it is moved to ``device``.

    Returns:
        Whatever ``measurer.compute_measures`` returns for the concatenated
        predictions and labels.

    Raises:
        RuntimeError: If ``dataloader`` yields no batches.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    running_loss = 0.0
    num_batches = 0

    # Inference only: disabling autograd avoids building a graph for every
    # forward pass, which would waste memory and time.
    with torch.no_grad():
        for images, mag1c, labels in dataloader:
            input_image = torch.cat((images, mag1c), dim=1)
            # Explicit None check: a container-like transform that happens to
            # be empty would be falsy and silently skipped otherwise.
            if transform is not None:
                input_image = transform(input_image)
            labels = labels.long().to(device)

            outputs = model(input_image.to(device))
            predictions = torch.argmax(outputs, dim=1)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            all_predictions.append(predictions.cpu())
            all_labels.append(labels.cpu())
            num_batches += 1

    if num_batches == 0:
        # torch.cat on an empty list raises an opaque RuntimeError; raise the
        # same exception type with an actionable message instead.
        raise RuntimeError("evaluate() received a dataloader that produced no batches.")

    measures = measurer.compute_measures(torch.cat(all_predictions), torch.cat(all_labels))
    print(f"Validation loss: {running_loss / num_batches}.\nMeasures:\n{measures}")
    return measures

# Train model 

In [7]:
# Run configuration: number of epochs, compute device (GPU when available)
# and the Adam learning rate.
# NOTE(review): no random seed is set in the visible cells, so results will
# vary from run to run.
epochs = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 1e-4

train_dataloader, test_dataloader = setup_dataloaders()
# NOTE(review): this import rebinds `CustomViT`, so the class defined in the
# "Model class" cell above is NOT the one instantiated below. Presumably the
# module-level class is used so that the saved model refers to an importable
# class path — confirm, and consider moving this import to the first cell.
from models.Transformer.VIT.model import CustomViT

model = CustomViT()
model, criterion, optimizer = setup_model(model, lr, device)
model_handler = ModelFilesHandler()
measurer = MeasureToolFactory.get_measure_tool(ModelType.TRANSFORMER)

# Preprocessing applied to every concatenated input batch inside train() and
# evaluate(): resize to 224x224, then normalise 9 channels with mean 0.5 and
# std 0.5 (9 matches the default `num_channels` of the CustomViT defined above).
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.Normalize(mean=[0.5] * 9, std=[0.5] * 9)  # Normalize for 9 channels
])

# Train on the training split with per-10-batch loss logging, then compute the
# measures on the test split.
train(criterion, device, epochs, model, optimizer, train_dataloader, transform, log_batches=True)
measures = evaluate(criterion, device, model, test_dataloader,measurer, transform)


Epoch: 0


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Batch: 10, Loss: 0.6877612948417664
Batch: 20, Loss: 0.6608143121004104
Batch: 30, Loss: 0.6587702651818593
Batch: 40, Loss: 0.6516647815704346
Batch: 50, Loss: 0.6451245200634003
Batch: 60, Loss: 0.6381875624259313
Batch: 70, Loss: 0.6256400623491832
Epoch 1, Loss: 0.6256400623491832
Validation loss: 0.5638946010622867.
Measures:
         TP       FP        FN        TN  Precision  Sensitivity  Specificity  \
0  0.255132  0.09824  0.183284  0.463343    0.72199     0.581938     0.825064   

        NPV       FPR  Accuracy   F-Score      IoU       MCC       AUC  \
0  0.716552  0.174934  0.718474  0.644442  0.47541  0.422479  0.703503   

         CI  
0  0.025378  


# Save model

In [8]:
# Persist the trained model together with its evaluation measures. The call
# is the cell's last expression, so its return value is shown as the output.
model_handler.save_model(
    model=model,
    model_type=ModelType.TRANSFORMER_CLASSIFIER,
    epoch=epochs,
    metrics=measures,
)

'trained_models\\model_transformer_classifier_2024_11_30_10_22_44.pickle'