# Transformer Model

## Imports

In [None]:
# Add Utils to path
import sys
sys.path.append('../Utils')  # Adds higher directory to python modules path.

# Utils
from image_enhancement_functions import histogram_equalization, clahe, color_balance_adjustment, min_max_contrast_enhancement
from custom_image_dataset import CustomImageDataset

# Pytorch
import torch
from torch import nn
from torchvision import models
from torchvision import transforms
from torch.utils.data import DataLoader

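## Dataset

TODO: this section is still empty, but the Train cell below expects a `dataset` object (e.g. a `CustomImageDataset`) that has a `classes` attribute — define it here before running training.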

## Model

In [None]:
class LocalizationTransformer(nn.Module):
    """ResNet-50 backbone + one Transformer encoder layer + linear box head.

    For each image the model predicts one 4-value box per class, returned with
    shape ``(batch_size, num_classes, 4)`` (see the final ``view`` in
    ``forward``).

    NOTE(review): the backbone's pooled 2048-d feature vector is passed to the
    encoder as a sequence of length 1, so self-attention has a single token to
    attend to and the encoder effectively acts as a per-image MLP. Feeding the
    spatial feature map (before ResNet's ``avgpool``) as a token sequence would
    let attention mix information across locations; kept as-is so existing
    checkpoints behave the same.

    Args:
        num_classes: Number of classes; one 4-coordinate box is predicted per
            class.
        pretrained: If True (default, matching the original behavior), load
            the ImageNet ``IMAGENET1K_V1`` ResNet-50 weights, which are what
            the deprecated ``pretrained=True`` flag selected. Pass False to
            skip the download, e.g. before loading a saved checkpoint.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.resnet = self._build_backbone(pretrained)
        # Replace the ImageNet classifier so the backbone returns its
        # 2048-d pooled features instead of class logits.
        self.resnet.fc = nn.Identity()
        # batch_first=True: encoder input is (batch, seq, d_model). This is
        # mathematically the same as the (seq, batch, d_model) layout and
        # avoids PyTorch's nested-tensor warning for batch_first=False encoders.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=2048, nhead=8, batch_first=True),
            num_layers=1,
        )
        self.fc = nn.Linear(2048, num_classes * 4)  # 4 coordinates for each class

    @staticmethod
    def _build_backbone(pretrained):
        """Return a torchvision ResNet-50, with ImageNet weights if requested."""
        try:
            weights_enum = models.ResNet50_Weights  # torchvision >= 0.13
        except AttributeError:
            # Older torchvision only understands the legacy boolean flag.
            return models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is what the deprecated pretrained=True loaded;
        # ResNet50_Weights.DEFAULT would silently switch to IMAGENET1K_V2.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet50(weights=weights)

    def forward(self, x):
        """Predict per-class boxes for a batch of images.

        Args:
            x: Image batch accepted by ResNet-50 — presumably
                (batch, 3, H, W) float tensors; TODO confirm normalization
                against the dataset transforms.

        Returns:
            Tensor of shape (batch_size, num_classes, 4).
        """
        features = self.resnet(x)  # (B, 2048) since resnet.fc is Identity
        # Treat each image's feature vector as a length-1 token sequence.
        encoded = self.transformer(features.unsqueeze(1))  # (B, 1, 2048)
        boxes = self.fc(encoded.squeeze(1))  # (B, num_classes * 4)
        return boxes.view(boxes.size(0), -1, 4)  # (B, num_classes, 4)

## Train

In [None]:
# Create the model.
# NOTE(review): `dataset` must be defined by the Dataset section (currently
# empty). This cell assumes it exposes a `classes` sequence and yields
# (image_tensor, box_tensor) pairs whose boxes match the model's
# (num_classes, 4) output — TODO confirm against CustomImageDataset.
NUM_EPOCHS = 10
BATCH_SIZE = 32

# Train on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LocalizationTransformer(num_classes=len(dataset.classes)).to(device)

# Create data loaders
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Define a loss function and an optimizer.
# Smooth L1 loss is commonly used for regression tasks such as box coordinates.
criterion = nn.SmoothL1Loss()
optimizer = torch.optim.Adam(model.parameters())

# Training loop
for epoch in range(NUM_EPOCHS):
    model.train()  # make sure dropout / BatchNorm run in training mode
    running_loss = 0.0
    num_samples = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch doesn't skew the mean.
        n = imgs.size(0)
        running_loss += loss.item() * n
        num_samples += n

    if num_samples == 0:
        # Without this check, an empty loader would surface as an unrelated
        # NameError on `loss`.
        raise RuntimeError("train_loader yielded no batches; is `dataset` empty?")
    # Report the mean per-sample loss over the whole epoch, not the last batch.
    print(f"Epoch {epoch+1}, Loss: {running_loss / num_samples:.4f}")
