In [None]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models, datasets
from torch.optim.lr_scheduler import StepLR
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import numpy as np
from PIL import Image

# TODO: add the Dataset/DataLoader (and, if needed, a custom collate function) for the
# fine-tuning dataset. The training cells below expect `train_loader` and `val_loader` to exist.

In [None]:


# Object detection model: Faster R-CNN with a ResNet-50 FPN backbone, loaded
# with torchvision's pretrained weights.
object_detection_model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Freeze every pretrained parameter. Anything attached to the model after this
# point (i.e. the new box predictor below) is created trainable.
object_detection_model.requires_grad_(False)

# Swap the box predictor for one sized to our label set (3 classes + background).
in_features = object_detection_model.roi_heads.box_predictor.cls_score.in_features
new_box_predictor = FastRCNNPredictor(in_features, num_classes=4)
object_detection_model.roi_heads.box_predictor = new_box_predictor

# Image classification model: ResNet-50 loaded with torchvision's pretrained weights.
image_classification_model = torchvision.models.resnet50(pretrained=True)

# Freeze the whole pretrained network.
for param in image_classification_model.parameters():
    param.requires_grad = False

# Replace the pre-trained head with a new one.
# A freshly constructed nn.Linear has trainable parameters by default, so after
# this assignment the new `fc` layer is the only part of the model that will be
# updated. (Setting `.requires_grad` on the module object itself -- rather than
# on its parameters -- is just a plain attribute assignment and unfreezes nothing,
# so no such statement is needed here.)
num_classes = 20  # Placeholder because I forgot how many we were going to use
image_classification_model.fc = nn.Linear(image_classification_model.fc.in_features, num_classes)


In [None]:
# Fine-tune the object detection model.
# Assumes `train_loader` and `val_loader` are defined elsewhere and yield
# (images, targets) batches: a sequence of image tensors and a sequence of
# dicts of tensors -- TODO confirm once the dataloader cell exists.

# `device` is not defined anywhere else in the visible notebook, so pick it here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
object_detection_model.to(device)

# Only hand the optimizer the parameters that are still trainable; frozen
# parameters never receive gradients, so passing them is pointless.
trainable_params = [p for p in object_detection_model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.005, momentum=0.9)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)
num_epochs = 10
best_val_loss = float('inf')

for epoch in range(num_epochs):
    object_detection_model.train()
    train_loss = 0.0

    for images, targets in train_loader:
        images = list(img.to(device) for img in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        loss_dict = object_detection_model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        losses.backward()
        optimizer.step()

        train_loss += losses.item()

    # Validation phase.
    # torchvision detection models only return the loss dict in training mode;
    # in eval() mode they return predictions (a list of dicts), so calling
    # `.values()` on the result would fail. To measure a validation loss we
    # therefore stay in train mode and simply disable gradient tracking.
    # NOTE(review): this relies on the model having no layers whose state
    # changes during a train-mode forward pass (e.g. regular BatchNorm) --
    # confirm for the torchvision version in use.
    object_detection_model.train()
    val_loss = 0.0
    with torch.no_grad():
        for images, targets in val_loader:
            images = list(img.to(device) for img in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = object_detection_model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            val_loss += losses.item()

    # Learning rate scheduler step
    scheduler.step()

    # Checkpoint the best model seen so far (training is not stopped early)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(object_detection_model.state_dict(), 'best_object_detection_model.pth')

    print(f'Epoch {epoch}, Train Loss: {train_loss / len(train_loader)}, Val Loss: {val_loss / len(val_loader)}')


In [None]:
# Fine-tune the image classification model.
# NOTE(review): this cell uses the same `train_loader` / `val_loader` names as
# the detection training cell, but here each batch must be an
# (inputs, labels) pair of tensors -- make sure the loaders are rebound to the
# classification dataset before running this cell.

# `device` and `criterion` are not defined anywhere else in the visible
# notebook, so define them here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_classification_model.to(device)
criterion = nn.CrossEntropyLoss()

# Only optimise the parameters that are still trainable.
trainable_params = [p for p in image_classification_model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.001, momentum=0.9)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)
num_epochs = 10
best_val_loss = float('inf')

for epoch in range(num_epochs):
    image_classification_model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = image_classification_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Validation phase
    image_classification_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = image_classification_model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    # Learning rate scheduler step
    scheduler.step()

    # Checkpoint the best model seen so far (training is not stopped early)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(image_classification_model.state_dict(), 'best_image_classification_model.pth')

    print(f'Epoch {epoch}, Train Loss: {running_loss / len(train_loader)}, Val Loss: {val_loss / len(val_loader)}')


In [None]:
# Load both fine-tuned models and run the detector on a single image.

# `device` is not defined anywhere else in the visible notebook, so pick it here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load fine-tuned object detection model.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
object_detection_model.load_state_dict(
    torch.load('best_object_detection_model.pth', map_location=device)
)
object_detection_model.to(device)
object_detection_model.eval()

# Load fine-tuned image classification model
image_classification_model.load_state_dict(
    torch.load('best_image_classification_model.pth', map_location=device)
)
image_classification_model.to(device)
image_classification_model.eval()

# Define the transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),  #  Resize code, as I will resize the data to this in the data class
    transforms.ToTensor(),
    # Maybe going to add normalization (only if I end up using it for training)
])

# Load the image and apply transforms
image_path = "flag.jpg"  # Path image should be here
with Image.open(image_path) as pil_image:
    # Force 3 channels so grayscale / RGBA / palette files do not break the models.
    image = transform(pil_image.convert("RGB"))
# Keep `image` as a single 3-D [C, H, W] tensor: the detector takes a *list* of
# such tensors (no batch dimension), and box coordinates index directly into it.
image = image.to(device)

# Detect objects in the image (inference only, so no gradient tracking)
with torch.no_grad():
    object_detection_results = object_detection_model([image])


In [None]:
# Extract flag and map image patches from the detector output.
SCORE_THRESHOLD = 0.5  # Threshold to filter out low-confidence detections
FLAG_LABEL = 1  # placeholder position for flag
MAP_LABEL = 2  # placeholder position for map

flag_image_patches = []
map_image_patches = []

detections = object_detection_results[0]

# Boxes index into a [C, H, W] image; tolerate `image` carrying a leading batch
# dimension of size 1 as well.
source_image = image[0] if image.dim() == 4 else image
image_height, image_width = source_image.shape[-2:]

# boxes / labels / scores are parallel sequences: walk them together, one
# detection at a time.
for box, label, score in zip(
    detections['boxes'].tolist(),
    detections['labels'].tolist(),
    detections['scores'].tolist(),
):
    if score <= SCORE_THRESHOLD:
        continue
    if label not in (FLAG_LABEL, MAP_LABEL):
        continue

    # Boxes are (x1, y1, x2, y2); truncate to pixel indices and clamp to the image.
    x1, y1, x2, y2 = (int(coord) for coord in box)
    x1, x2 = max(x1, 0), min(x2, image_width)
    y1, y2 = max(y1, 0), min(y2, image_height)
    if x2 <= x1 or y2 <= y1:
        # Degenerate box after truncation: nothing to crop.
        continue

    # Crop and add a batch dimension so the patch can be fed to the classifier.
    patch = source_image[:, y1:y2, x1:x2].unsqueeze(0)
    if label == FLAG_LABEL:
        flag_image_patches.append(patch)
    else:
        map_image_patches.append(patch)


In [None]:
def classify_patches(patches, model):
    """Return the predicted class index for each image patch.

    Args:
        patches: iterable of tensors, each already carrying a batch dimension
            of size 1 (the prediction is read back with ``.item()``).
        model: classification model; called once per patch.

    Returns:
        A list of ints, one predicted class index per patch, in input order.
    """
    # Run on whatever device the model's weights live on.
    model_device = next(model.parameters()).device
    predicted_labels = []
    # Inference only: avoid building autograd graphs for every patch.
    with torch.no_grad():
        for patch in patches:
            logits = model(patch.to(model_device))
            predicted_labels.append(torch.argmax(logits, dim=1).item())
    return predicted_labels


# Classify each flag image patch
flag_labels = classify_patches(flag_image_patches, image_classification_model)

# Same for the map image patches
map_labels = classify_patches(map_image_patches, image_classification_model)
