# Concealed Pistol Bounding Boxes


### Imports

In [None]:
#imports for neural network
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchsummary import summary

#imports for vision
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from torchvision.ops import distance_box_iou_loss

#imports for preparing dataset
import os
import zipfile
#from google.colab import files
#from google.colab import drive

#imports for visualizations
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches

import numpy as np
import cv2

from custom_gun_dataset import CustomDataset

### Preparing Dataset

In [None]:
#applying a transformation to the entire dataset, standardizing it:
#resize every image to 384 x 384, convert it to one grayscale channel, then
#map pixel values from [0, 1] to [-1, 1] (mean 0.5, std 0.5)
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

#relative directory path
dataset_dir = '../Data/CombinedData'
#NOTE(review): only the 'with gun' category is loaded although num_classes is 2
#below -- confirm how CustomDataset assigns labels when given a single category
dataset = CustomDataset(root_directory = dataset_dir, transform = transform, categories = ['with gun'])

#initialize train, validation, and test sets
train_size = int(0.8 * len(dataset))              #80% for training
val_size = int(0.1 * len(dataset))                #10% for validation
test_size = len(dataset) - train_size - val_size  #10% (remainder) for test

#seeded generator: as long as the dataset itself is unchanged, every run of this
#cell produces the same split, so a checkpoint saved by an earlier run was never
#trained on this run's validation/test images
split_seed = 42
split_generator = torch.Generator().manual_seed(split_seed)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator = split_generator
)

#print sizes of datasets
print(f"Train size: {len(train_dataset)}")
print(f"Valid size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

#set the dataloaders to use the datasets; only the training data is shuffled so
#validation/test metrics and the per-image test visualisation are reproducible
batch_size = 4
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

#variable to define number of classes
num_classes = 2


### Selecting the compute device (GPU if available)

In [None]:
#pick the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#confirm device
print("Device:", device)


### Visualizing the transformed dataset

In [None]:
#plotting a grid of a single batch
def _grid_origin(index, num_images, nrow, img_height, img_width, padding):
    """Return the (x, y) pixel offset of image `index` inside a make_grid mosaic.

    Follows torchvision.utils.make_grid's layout: min(nrow, num_images) images
    per row, each cell (img size + padding) wide/tall, plus a `padding` border
    before the first cell on each axis.
    """
    columns = min(nrow, num_images)
    row, col = divmod(index, columns)
    return (col * (img_width + padding) + padding,
            row * (img_height + padding) + padding)


def show_batch(dataLoader, nrow=4, padding=2):
    """Plot the first batch of `dataLoader` as a grid with its boxes drawn in red.

    Args:
        dataLoader: iterable yielding (images, class_label, boxes) batches.
            `images` is (B, C, H, W) and assumed normalized to [-1, 1] (as by
            Normalize(mean=0.5, std=0.5)); `boxes` is indexed as boxes[k][i]
            for coordinate k of image i, presumably (x1, y1, x2, y2) in the
            image's pixel coordinates -- confirm against CustomDataset.
        nrow: number of images per grid row.
        padding: pixels between grid cells (forwarded to make_grid).

    Does nothing if the loader yields no batches.
    """
    batch = next(iter(dataLoader), None)
    if batch is None:
        return  #empty loader: nothing to show
    images, _class_label, boxes = batch

    #value_range=(-1, 1) undoes the [-1, 1] normalisation for display; without
    #it imshow clips every negative (darker-than-mid-grey) pixel to black
    grid_img = make_grid(images, nrow=nrow, padding=padding,
                         normalize=True, value_range=(-1, 1)).permute(1, 2, 0)
    _, ax = plt.subplots(figsize=(15, 10))
    ax.imshow(grid_img)
    ax.set_xticks([])
    ax.set_yticks([])

    num_images, _, img_height, img_width = images.shape
    #iterate over the images actually present (the last batch may be short)
    for i in range(num_images):
        #shift each box by its cell's origin, including make_grid's padding
        x_off, y_off = _grid_origin(i, num_images, nrow, img_height, img_width, padding)
        x1, y1, x2, y2 = (float(boxes[k][i]) for k in range(4))
        rect = patches.Rectangle((x1 + x_off, y1 + y_off), x2 - x1, y2 - y1,
                                 linewidth=2, edgecolor='red', facecolor='none')
        ax.add_patch(rect)

    plt.show()

#draw the first training batch with its ground-truth boxes overlaid
show_batch(train_dl)

### Defining the model

In [None]:
#definition for a CNN
class Network(nn.Module):
    """CNN with a classification head and a 4-value box-regression head.

    Expects a (B, 1, 384, 384) grayscale batch. Four 2x2 max-pools shrink the
    feature map 384 -> 192 -> 96 -> 48 -> 24, so the flattened feature vector
    has 24 * 24 * 128 entries; any other input size now fails at fc1 instead
    of being silently reshaped into a different batch size. Two skip
    connections add earlier (downsampled, channel zero-padded) features back in.

    forward() returns (class_logits, bbox): class_logits has `num_classes`
    columns (read from the module-level setting) and bbox has 4.
    """

    def __init__(self):
        super().__init__()

        #model takes input of 384 x 384 x 1
        #make sure in_channels aligns with out_channels from the previous layer
        #(attribute names form the saved state_dict keys -- keep them stable)
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 3, stride = 1, padding = 1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 3, stride = 1, padding = 1)
        self.bn2 = nn.BatchNorm2d(32)

        self.conv3 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1, padding = 1)
        self.bn3 = nn.BatchNorm2d(64)

        self.conv4 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride = 1, padding = 1)
        self.bn4 = nn.BatchNorm2d(64)

        self.conv5 = nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)
        self.bn5 = nn.BatchNorm2d(128)

        self.conv6 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)
        self.bn6 = nn.BatchNorm2d(128)

        self.conv7 = nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, stride = 1, padding = 1)
        self.bn7 = nn.BatchNorm2d(256)

        self.conv8 = nn.Conv2d(in_channels = 256, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)
        self.bn8 = nn.BatchNorm2d(128)

        self.conv9 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)
        self.bn9 = nn.BatchNorm2d(128)

        #four max-pools halve 384 four times -> 24 x 24, and conv9 emits 128 channels
        self.fc1 = nn.Linear(in_features = (24 * 24 * 128), out_features = 64)
        self.fc2 = nn.Linear(in_features = 64, out_features = 32)
        self.fc_class = nn.Linear(in_features = 32, out_features = num_classes)
        self.fc_bbox = nn.Linear(in_features = 32, out_features = 4)

        self.pool = nn.MaxPool2d(2,2)
        self.leaky_relu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.5)

    @staticmethod
    def _channel_pad(skip, target):
        """Zero-pad (or crop, if negative) `skip` on dim 1 to target's channel count."""
        missing = target.shape[1] - skip.shape[1]
        if missing != 0:
            skip = nn.functional.pad(skip, (0, 0, 0, 0, 0, missing))
        return skip

    def _conv_block(self, x, conv, bn):
        """Apply conv -> batch norm -> LeakyReLU."""
        return self.leaky_relu(bn(conv(x)))

    def forward(self, input):
        """Return (class_logits, bbox) for a (B, 1, 384, 384) batch."""
        #blocks 1-2: 32 channels, pool 384 -> 192
        output = self._conv_block(input, self.conv1, self.bn1)
        output_pool_2 = self.pool(self._conv_block(output, self.conv2, self.bn2))

        #blocks 3-4: 64 channels, pool 192 -> 96; block 3's output feeds skip 1
        output_skip_1 = self._conv_block(output_pool_2, self.conv3, self.bn3)
        output_pool_4 = self.pool(self._conv_block(output_skip_1, self.conv4, self.bn4))

        #blocks 5-6: 128 channels, pool 96 -> 48
        output = self._conv_block(output_pool_4, self.conv5, self.bn5)
        output_pool_6 = self.pool(self._conv_block(output, self.conv6, self.bn6))

        #skip connection 1 (block 3 -> block 6): pool twice to 48 x 48, pad 64 -> 128 channels
        skip2 = self._channel_pad(self.pool(self.pool(output_skip_1)), output_pool_6)
        output_skip_2 = output_pool_6 + skip2

        #blocks 7-8: 256 then 128 channels, pool 48 -> 24
        output = self._conv_block(output_skip_2, self.conv7, self.bn7)
        output_pool_8 = self.pool(self._conv_block(output, self.conv8, self.bn8))

        #block 9: 128 channels at 24 x 24
        output = self._conv_block(output_pool_8, self.conv9, self.bn9)

        #skip connection 2 (after skip 1 -> block 9): pool once to 24 x 24
        output = output + self._channel_pad(self.pool(output_skip_2), output)

        #flatten all but the batch dimension; a wrong input size now raises at fc1
        output = torch.flatten(output, 1)

        #shared fully connected layers
        output = self.dropout(self.leaky_relu(self.fc1(output)))
        output = self.dropout(self.leaky_relu(self.fc2(output)))

        output_class = self.fc_class(output)

        #NOTE(review): the sigmoid is applied to the shared features *before*
        #fc_bbox, so the box output itself is unbounded. Box targets appear to be
        #pixel coordinates (this notebook draws them without rescaling), so moving
        #the sigmoid after fc_bbox would cap predictions at 1 -- confirm first.
        output_bbox = self.fc_bbox(self.sigmoid(output))

        return output_class, output_bbox

model = Network().to(device)

#channels, height, width
#torchsummary builds its dummy input on "cuda" unless told otherwise, which
#crashes on CPU-only machines -- pass the device type actually in use
summary(model, (1, 384, 384), device = device.type)



### Defining Loss and Optimizer


In [None]:

def bbox_ciou(box1, box2, complete=False):
    """Return the row-wise IoU (default) or Complete-IoU of two sets of boxes.

    Despite the name, plain IoU is returned unless ``complete=True``, in which
    case the CIoU centre-distance and aspect-ratio penalties are subtracted.

    Args:
        box1: (N, 4) predicted boxes ``[x1, y1, x2, y2]``. Corners are re-ordered
            per row (min/max) so either corner order is accepted.
        box2: (N, 4) target boxes ``[x1, y1, x2, y2]``; assumed already ordered
            (not re-ordered here).
        complete: if True, return CIoU instead of IoU.

    Returns:
        (N,) tensor with one score per row (IoU is in [0, 1]).
    """
    import math

    eps = 1e-6

    #re-order predicted corners so (px1, py1) is top-left and (px2, py2) bottom-right
    px1 = torch.min(box1[:, 0], box1[:, 2])
    py1 = torch.min(box1[:, 1], box1[:, 3])
    px2 = torch.max(box1[:, 0], box1[:, 2])
    py2 = torch.max(box1[:, 1], box1[:, 3])
    tx1, ty1, tx2, ty2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

    #intersection (zero when the boxes do not overlap)
    inter_w = (torch.min(px2, tx2) - torch.max(px1, tx1)).clamp(0)
    inter_h = (torch.min(py2, ty2) - torch.max(py1, ty1)).clamp(0)
    inter_area = inter_w * inter_h

    #areas and union; eps keeps the division finite for degenerate boxes
    area1 = (px2 - px1).clamp(0) * (py2 - py1).clamp(0)
    area2 = (tx2 - tx1).clamp(0) * (ty2 - ty1).clamp(0)
    union_area = area1 + area2 - inter_area + eps
    iou = inter_area / union_area

    if not complete:
        return iou

    #CIoU penalty 1: squared centre distance over the squared diagonal of the
    #smallest box enclosing both
    center_dist_sq = (((px1 + px2) / 2 - (tx1 + tx2) / 2) ** 2
                      + ((py1 + py2) / 2 - (ty1 + ty2) / 2) ** 2)
    enclosing_diag_sq = ((torch.max(px2, tx2) - torch.min(px1, tx1)) ** 2
                         + (torch.max(py2, ty2) - torch.min(py1, ty1)) ** 2 + eps)

    #CIoU penalty 2: aspect-ratio consistency v, weighted by alpha (held constant)
    w1 = (px2 - px1).clamp(eps)
    h1 = (py2 - py1).clamp(eps)
    w2 = (tx2 - tx1).clamp(eps)
    h2 = (ty2 - ty1).clamp(eps)
    v = (4 / (math.pi ** 2)) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
    with torch.no_grad():
        alpha = v / ((1 - iou) + v + eps)

    return iou - (center_dist_sq / enclosing_diag_sq) - alpha * v

class IoULoss(nn.Module):
    """Box-regression loss: the mean over rows of ``1 - bbox_ciou(pred, target)``.

    A perfect overlap contributes 0 and a non-overlapping pair contributes 1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the scalar mean of ``1 - bbox_ciou(pred, target)``."""
        return torch.mean(1 - bbox_ciou(pred, target))


In [None]:
#loss functions: cross entropy for the class head, plus candidate regression
#losses for the box head (criterion_bbox_mse / weight_mse are defined but not
#used by the training loop)
criterion_class = nn.CrossEntropyLoss()
criterion_bbox_iou = IoULoss()
criterion_bbox_l1 = nn.SmoothL1Loss()
criterion_bbox_mse = nn.MSELoss()

#relative weight of each loss term; 0.0 switches a term off, so only the IoU
#term currently contributes
weight_class, weight_iou, weight_l1, weight_mse = 0.0, 1.0, 0.0, 0.0

#adam optimizer over every model parameter
learning_rate = 0.01
optimizer = optim.Adam(model.parameters(), lr = learning_rate)

### Training loop of the model

In [None]:
#number of epochs and early stopping
epochs = 501
early_stopping_patience = 500
early_stopping_counter = 0

#using validation loss as the best model
best_val_loss = float('inf')
best_epoch = 0

#arrays to save each metric during training
train_loss_values = []
train_acc_values = []

val_loss_values = []
val_acc_values = []


def _batch_to_device(data, label, bbox):
    """Move one (data, label, bbox) batch to `device`, boxes as a (B, 4) float32 tensor.

    `bbox` presumably arrives from the default collate as 4 per-coordinate
    tensors; stacking them on dim 1 gives one [x1, y1, x2, y2] row per image
    without a round-trip through numpy.
    """
    boxes = torch.stack(list(bbox), dim = 1).to(device = device, dtype = torch.float32)
    return data.to(device), label.to(device), boxes


def _masked_bbox_losses(predicted, target, label):
    """Return (iou_loss, smooth_l1_loss) over the positive (label == 1) samples.

    When the batch has no positive samples both losses are a zero tensor.
    """
    bbox_mask = (label == 1)
    if not bbox_mask.any():
        zero = torch.tensor(0.0, device = device)
        return zero, zero
    selected_pred = predicted[bbox_mask]
    selected_true = target[bbox_mask]
    return (criterion_bbox_iou(selected_pred, selected_true),
            criterion_bbox_l1(selected_pred, selected_true))


def _total_loss(loss_class, loss_iou, loss_l1):
    """Weighted sum of the class, IoU and smooth-L1 loss terms."""
    return (weight_class * loss_class) + (weight_iou * loss_iou) + (weight_l1 * loss_l1)


#training loop
for epoch in range(epochs):
    #turn on training mode
    model.train()

    #storing loss and accuracy per batch
    train_losses = []
    train_acc = []

    for data, label, bbox in train_dl:
        data, label, bbox = _batch_to_device(data, label, bbox)

        #clear gradients, forward pass, loss, backward and update
        optimizer.zero_grad()
        output_class, output_bbox = model(data)
        loss_class = criterion_class(output_class, label)
        loss_bbox_iou, loss_bbox_l1 = _masked_bbox_losses(output_bbox, bbox, label)

        loss = _total_loss(loss_class, loss_bbox_iou, loss_bbox_l1)
        loss.backward()
        optimizer.step()

        #storing loss and accuracy over the batch
        train_losses.append(loss.item())
        train_acc.append((output_class.argmax(dim = 1) == label).float().mean().item())

    #average loss and acc over the epoch, stored into the overall history
    current_epoch_loss = sum(train_losses) / len(train_losses)
    current_epoch_acc = sum(train_acc) / len(train_acc)
    train_loss_values.append(current_epoch_loss)
    train_acc_values.append(current_epoch_acc)

    #validation: evaluation mode, no gradient tracking
    model.eval()
    val_losses = []
    val_acc = []

    with torch.no_grad():
        for data, label, bbox in val_dl:
            data, label, bbox = _batch_to_device(data, label, bbox)

            #running forward pass and calculating loss
            val_class, val_bbox = model(data)
            val_loss_class = criterion_class(val_class, label)
            val_loss_bbox_iou, val_loss_bbox_l1 = _masked_bbox_losses(val_bbox, bbox, label)
            val_loss = _total_loss(val_loss_class, val_loss_bbox_iou, val_loss_bbox_l1)

            #storing loss and accuracy
            val_losses.append(val_loss.item())
            val_acc.append((val_class.argmax(dim = 1) == label).float().mean().item())

    #averaging loss and accuracy, stored into the overall history
    current_epoch_val_loss = sum(val_losses) / len(val_losses)
    current_epoch_val_acc = sum(val_acc) / len(val_acc)
    val_loss_values.append(current_epoch_val_loss)
    val_acc_values.append(current_epoch_val_acc)

    #checking for model improvement; checkpoint the best epoch by validation loss
    if current_epoch_val_loss < best_val_loss:
        torch.save(model.state_dict(), "bounding_pistol.pth")
        best_val_loss = current_epoch_val_loss
        best_epoch = epoch + 1
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1

    #output progress for the first three epochs and then every 50th
    if epoch < 3 or epoch % 50 == 0:
        print(f"Epoch: {epoch+1}")
        print(f"Training Accuracy: {current_epoch_acc:.3f}, Validation Accuracy: {current_epoch_val_acc:.3f}")
        print(f"Training Loss: {current_epoch_loss:.3f}, Validation Loss: {current_epoch_val_loss:.3f}")
        print(f"Current Best Epoch: {best_epoch}\n")

    #stop if not improving
    if early_stopping_counter >= early_stopping_patience:
        print(f"Stopping early after {early_stopping_patience} epochs with no improvement")
        break


### Plotting Model Training Results 

In [None]:
#x-axis: epochs numbered from 1
epoch_count = list(range(1, len(train_loss_values) + 1))

#setting figure size
plt.figure(figsize = (8, 4))

#one subplot per metric: (title, y-axis label, train curve, val curve, train legend, val legend)
panels = [
    ("Loss", "Loss", train_loss_values, val_loss_values, 'Training Loss', 'Validation Loss'),
    ("Accuracy", "Accuracy", train_acc_values, val_acc_values, 'Training Acc', 'Validation Acc'),
]
for position, (title, ylabel, train_curve, val_curve, train_name, val_name) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(title)
    plt.plot(epoch_count, train_curve, label=train_name)
    plt.plot(epoch_count, val_curve, label=val_name, linestyle='--')
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend()

#showing figures
plt.tight_layout()
plt.show()

### Testing Model


In [None]:
#loading best saved model; map_location places the weights on the current device
#even if the checkpoint was written from a different one
model.load_state_dict(torch.load('bounding_pistol.pth', map_location = device, weights_only = True))

#set to evaluation mode
model.eval()

#saving images, label, bbox for visualization in next code block
all_images, pred_labels, pred_bbox, true_labels, true_bbox = [], [], [], [], []

#disabling gradient tracking
with torch.no_grad():
    total_accuracy = 0.0
    total_test_loss = 0.0
    num_batches = len(test_dl)

    for data, label, bbox in test_dl:
        #loading data and labels to proper device; bbox presumably arrives as 4
        #per-coordinate tensors, stacked here into one (B, 4) row per image
        data = data.to(device)
        label = label.to(device)
        bbox = torch.stack(list(bbox), dim = 1).to(device = device, dtype = torch.float32)

        #forward pass, loss, accuracy calculation
        output_class, output_bbox = model(data)
        loss_class = criterion_class(output_class, label)

        #box losses only over positive (label == 1) samples -- the same rule the
        #training and validation losses use, so the test loss is comparable
        bbox_mask = (label == 1)
        if bbox_mask.any():
            selected_pred = output_bbox[bbox_mask]
            selected_true = bbox[bbox_mask]
            loss_bbox = (weight_iou * criterion_bbox_iou(selected_pred, selected_true)
                         + weight_l1 * criterion_bbox_l1(selected_pred, selected_true))
        else:
            loss_bbox = torch.tensor(0.0, device = device)

        loss = (weight_class * loss_class) + loss_bbox
        accuracy = (output_class.argmax(dim = 1) == label).float().mean()

        total_test_loss += loss.item()
        total_accuracy += accuracy.item()

        #storing individual samples on the CPU for the visualisation cell
        for image, logits, box_pred, sample_label, box_true in zip(
                data.cpu(), output_class.cpu(), output_bbox.cpu(), label.cpu(), bbox.cpu()):
            all_images.append(image.clone())
            pred_labels.append(logits.argmax(dim = 0).item())
            pred_bbox.append(box_pred.tolist())
            true_labels.append(sample_label.item())
            true_bbox.append(box_true.tolist())

    #average the loss and accuracy over all batches
    if num_batches == 0:
        print("Test set is empty; no metrics to report")
    else:
        avg_loss = total_test_loss / num_batches
        avg_accuracy = total_accuracy / num_batches

        print(f"Test Accuracy: {avg_accuracy:.3f}")
        print(f"Test Loss : {avg_loss:.3f}")

In [None]:
#print(len(all_images)) #stored as Tensor
#print(len(pred_labels))
#print(len(true_labels))
#print(len(pred_bbox), len(pred_bbox[0]))
#print(len(true_bbox), len(true_bbox[0]))

def _to_uint8_rgb(image, mean=0.5, std=0.5):
    """Convert a normalized (C, H, W) image tensor into a contiguous uint8 RGB array.

    Undoes ``Normalize(mean, std)`` and clips to [0, 255] before the uint8 cast;
    casting negative floats straight to uint8 is undefined in NumPy (it
    typically wraps around) and garbles dark pixels. A single-channel image is
    repeated to 3 channels so coloured boxes can be drawn on it.
    """
    array = image.permute(1, 2, 0).numpy()  #CHW -> HWC
    array = np.clip((array * std + mean) * 255.0, 0, 255).astype(np.uint8)
    if array.shape[2] == 1:
        array = np.repeat(array, 3, axis=2)
    return np.ascontiguousarray(array)


def _clamp_box(box, width, height):
    """Truncate box coordinates to ints, clamping x to [0, width] and y to [0, height]."""
    x1, y1, x2, y2 = (int(v) for v in box)
    return [min(max(x1, 0), width), min(max(y1, 0), height),
            min(max(x2, 0), width), min(max(y2, 0), height)]


def show_image_with_bbox(image, true_label, pred_label, true_bbox, pred_bbox, mean=0.5, std=0.5):
    """Show one image with its ground-truth (green) and predicted (red) boxes.

    Args:
        image: (C, H, W) tensor normalized with ``Normalize(mean, std)``.
        true_label, pred_label: labels printed in the title.
        true_bbox, pred_bbox: [x1, y1, x2, y2] in the image's pixel coordinates.
            The prediction is clamped to the image; the ground truth is drawn as is.
        mean, std: normalisation constants to undo for display (default 0.5/0.5,
            matching this notebook's transform).
    """
    image = _to_uint8_rgb(image, mean, std)
    height, width, _ = image.shape

    true_box = [int(v) for v in true_bbox]
    pred_box = _clamp_box(pred_bbox, width, height)

    #draw actual bounding box (green) then predicted bounding box (red); the
    #array is RGB as displayed by matplotlib
    image = cv2.rectangle(image, (true_box[0], true_box[1]), (true_box[2], true_box[3]), (0, 255, 0), 2)
    image = cv2.rectangle(image, (pred_box[0], pred_box[1]), (pred_box[2], pred_box[3]), (255, 0, 0), 2)

    #display with labels
    plt.imshow(image)
    plt.axis("off")
    plt.title(f"True Label: {true_label} | Pred Label: {pred_label}", fontsize=10)
    plt.show()
#set to True to only display images whose true or predicted label is 1; the
#default (False) shows every test image
show_only_positive = False

for image, true_label, pred_label, true_box, pred_box in zip(
        all_images, true_labels, pred_labels, true_bbox, pred_bbox):
    if show_only_positive and not (true_label == 1 or pred_label == 1):
        continue
    show_image_with_bbox(image, true_label, pred_label, true_box, pred_box)
    

### Testing the inference pipeline on a raw thermal frame

In [None]:
from PIL import Image
import numpy as np


def thermal_to_image(image_data):
    """Decode a two-channel thermal frame into display images.

    Channel 0 is treated as the high byte and channel 1 as the low byte of a
    16-bit reading per pixel (``hi * 256 + lo``).

    Returns:
        (colored, normalized): the JET false-colour image produced by
        cv2.applyColorMap, and the min-max stretched 8-bit grayscale image.
    """
    high_byte, low_byte = (image_data[:, :, channel].astype(np.uint16) for channel in (0, 1))
    raw_temp = high_byte * 256 + low_byte

    #stretch the 16-bit readings to the full 0-255 display range
    grayscale = cv2.normalize(raw_temp, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    #false-colour version for easier viewing
    false_colour = cv2.applyColorMap(grayscale, cv2.COLORMAP_JET)
    return np.array(false_colour), np.array(grayscale)

def convert_raw_thermal(thermal_npy_path):
    """Load a raw thermal .npy capture and return its 8-bit grayscale image.

    The array is split into two halves along axis 1. Only the first half
    (presumably the encoded image; the second presumably raw temperature data)
    is decoded via thermal_to_image; the colour output is discarded.
    """
    raw_capture = np.load(thermal_npy_path)
    halves = np.array_split(raw_capture, 2, axis = 1)
    _, grayscale = thermal_to_image(halves[0])
    return grayscale


def transform_image(test_image):
    """Preprocess one PIL image into a (1, 1, 384, 384) model input.

    Applies the same resize / grayscale / to-tensor / Normalize(0.5, 0.5)
    pipeline used for the training dataset, then adds a batch dimension.
    """
    preprocess = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    return preprocess(test_image).unsqueeze(0)

def detect_and_bound_pistol(image):
    """Run the trained model on one frame, print the predicted box and return the outputs.

    Args:
        image: numpy array accepted by PIL.Image.fromarray, e.g. the 2-D uint8
            grayscale frame returned by convert_raw_thermal.

    Returns:
        (detected, pred_bbox): class logits (1, num_classes) and box output
        (1, 4). The box is presumably in the 384 x 384 resized frame the model
        sees, not the original frame size -- rescale if needed (TODO confirm).
    """
    #inference: disable dropout / use batch-norm running stats, and skip autograd
    model.eval()
    batch = transform_image(Image.fromarray(image)).to(device)
    with torch.no_grad():
        detected, pred_bbox = model(batch)

    print(pred_bbox)
    return detected, pred_bbox


#raw capture to run through the inference pipeline (machine-specific absolute path)
thermal_frame_path = r'C:\Users\tungu\OneDrive\Desktop\Capstone\Team-02-Capstone-Project\Embedded System\api\venv\localcache\thermal_frame_appendix.npy'
test_image = convert_raw_thermal(thermal_frame_path)

detect_and_bound_pistol(test_image)