In [15]:
import torch
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
from torchvision.datasets import CocoDetection
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.models.detection.faster_rcnn import FasterRCNN_MobileNet_V3_Large_FPN_Weights
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Load the pre-trained model
model = fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT).to(device)

# Freeze the parameters of the base network
for param in model.backbone.parameters():
    param.requires_grad = False

# Define the transform
transform = T.Compose([
    T.Resize((640, 512)),  # Resize all images to have size 800x800
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Load the training data
train_data = CocoDetection(root='./FLIR_ADAS_v2/images_thermal_train', annFile='./FLIR_ADAS_v2/images_thermal_train/coco.json', transform=transform)
# Load the validation and test data
print(len(train_data))
val_data = CocoDetection(root='./FLIR_ADAS_v2/images_thermal_val', annFile='./FLIR_ADAS_v2/images_thermal_val/coco.json', transform=transform)
test_data = CocoDetection(root='./FLIR_ADAS_v2/video_thermal_test', annFile='./FLIR_ADAS_v2/video_thermal_test/coco.json', transform=transform)

# Create data loaders
train_loader = DataLoader(train_data, batch_size=4, shuffle=True,num_workers=4, collate_fn=collate_fn)

val_loader = DataLoader(val_data, batch_size=4, shuffle=False,num_workers=4,collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=4, shuffle=False,num_workers=4,collate_fn=collate_fn)

# Define the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

writer = SummaryWriter()
scheduler = ReduceLROnPlateau(optimizer, 'min')
import matplotlib.pyplot as plt

epoch = 100
best_val_loss = float('inf')
epochs_no_improve = 0
n_epochs_stop = 10
print('Start Training')
# Train the model
for epoch in range(epoch):  # loop over the dataset multiple times
    # Training
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, targets = data
        inputs = inputs.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        optimizer.zero_grad()
        outputs = model(inputs, targets)
        loss = sum(loss for loss in outputs.values())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    writer.add_scalar('Loss/train', running_loss/len(train_loader), epoch)
    print(f'Train Loss: {running_loss/len(train_loader)}')

    # Validation
    model.eval()
    running_loss = 0.0
    accuracy = 0
    with torch.no_grad():
        for i, data in enumerate(val_loader, 0):
            inputs, targets = data
            inputs = inputs.to(device)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = model(inputs)
            for j, output in enumerate(outputs):
                for pred_box, pred_label in zip(output['boxes'], output['labels']):
                    for true_box, true_label in zip(targets[j]['boxes'], targets[j]['labels']):
                        if pred_label == true_label and calculate_iou(pred_box, true_box) > 0.5:
                            accuracy += 1

                                                        


                            
    accuracy=accuracy / len(val_data)

    print(f'Accuracy: {accuracy}')

    print(loss, best_val_loss)
    if loss < best_val_loss:
        print("Improvement.")
        best_val_loss = loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            break
# Testing
print('Start Testing')
model.eval()
running_loss = 0.0
accuracy = 0
with torch.no_grad():
    for i, data in enumerate(test_loader, 0):
            inputs, targets = data
            inputs = inputs.to(device)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = model(inputs)

            for j, output in enumerate(outputs):
                for pred_box, pred_label in zip(output['boxes'], output['labels']):
                    for true_box, true_label in zip(targets[j]['boxes'], targets[j]['labels']):
                        if pred_label == true_label and calculate_iou(pred_box, true_box) > 0.5:
                            accuracy += 1
                        

accuracy /= len(test_data)
print(f'Accuracy: {accuracy/len(test_loader)}')

loading annotations into memory...
Done (t=1.41s)
creating index...
index created!
10742
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!
loading annotations into memory...
Done (t=0.14s)
creating index...
index created!
Start Training
Train Loss: 0.8508176141025322
Accuracy: 2.2744755244755246
tensor(1.0523, device='cuda:0', grad_fn=<AddBackward0>) inf
Improvement.
Train Loss: 0.8208928845839454
Accuracy: 2.5786713286713288
tensor(0.7991, device='cuda:0', grad_fn=<AddBackward0>) tensor(1.0523, device='cuda:0', grad_fn=<AddBackward0>)
Improvement.
Train Loss: 0.8100025210243859
Accuracy: 2.7263986013986012
tensor(0.8439, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.7991, device='cuda:0', grad_fn=<AddBackward0>)
Train Loss: 0.7977807238720197
Accuracy: 2.88986013986014
tensor(0.4266, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.7991, device='cuda:0', grad_fn=<AddBackward0>)
Improvement.
Train Loss: 0.7883491154013955
Accuracy: 3.05069930069930

In [2]:
import os
print('Saving the model')
# Create the model save folder
save_folder = 'model_save_folder'
os.makedirs(save_folder, exist_ok=True)

# Save the model
model_save_path = os.path.join(save_folder, 'model.pth')
torch.save(model.state_dict(), model_save_path)

print('Model saved successfully at:', model_save_path)


Saving the model
Model saved successfully at: model_save_folder/model.pth


In [3]:
import torch.quantization as quant
print('Quantizing the model')
# Quantize the model
quantized_model = quant.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# Save the quantized model
quantized_model_save_path = os.path.join(save_folder, 'quantized_model.pth')
torch.save(quantized_model.state_dict(), quantized_model_save_path)

print('Quantized model saved successfully at:', quantized_model_save_path)


Quantizing the model
Quantized model saved successfully at: model_save_folder/quantized_model.pth


In [13]:
# Set the model to evaluation mode
quantized_model.eval()

import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image

# Load an image
image = Image.open('FLIR0004.jpg')
transform = transforms.Compose([
    transforms.ToTensor()  # Convert the image to a PyTorch Tensor
])

# Move the image to the CPU
image = image

# Add the code block below the existing code
# FILEPATH: /home/sjhjrol/Documents/Capstone Modifed/AFV/Flir/flir_training.ipynb
image = transform(image).unsqueeze(0)  # Add an extra dimension for the batch size

# Input the image into the model
with torch.no_grad():
    output = model(image.to(device))

# Display the output
print(output)

import cv2
import numpy as np

# Convert the image back to OpenCV format
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

# Get the bounding boxes from the model's output
boxes = output[0]['boxes'].cpu().numpy().astype(np.int32)

# Draw the bounding boxes on the image
for box in boxes:
    cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)

# Display the image
cv2.imshow('Image', image)
cv2.waitKey(0)
cv2.destroyAllWindows()

[{'boxes': tensor([[615.7820,  34.4937, 640.0000,  93.9286],
        [231.7594,  86.6053, 587.2290, 373.6215],
        [620.0148,  74.4984, 640.0000, 145.3932],
        [615.9513,  34.1696, 639.8505,  94.1090],
        [620.1577,  74.2401, 639.8654, 146.0039],
        [599.5904, 383.7763, 640.0000, 438.9997],
        [222.3247,  46.8173, 547.0507, 369.3022],
        [603.6516, 391.9545, 638.0000, 417.4005],
        [599.5909, 262.1013, 640.0000, 327.4353],
        [593.4579, 208.0368, 621.3203, 257.0588],
        [603.6488, 176.8571, 640.0000, 231.3294],
        [576.1081,  24.7783, 640.0000,  83.5713],
        [607.1165, 162.2090, 640.0000, 212.1892],
        [595.8990, 133.9606, 638.1158, 194.7103],
        [609.6840, 392.6681, 618.6058, 411.4550],
        [612.4613, 286.5816, 618.5161, 297.9756],
        [593.6612,  56.3283, 640.0000, 152.6759],
        [612.4352, 267.2441, 618.4233, 278.5969],
        [612.4224, 247.9309, 618.3334, 259.1940],
        [612.4194, 228.6357, 618.2480, 