# Install Environment

In [1]:
!pip install timm



In [2]:
!pip install torch torchvision torchaudio timm opencv-python albumentations numpy matplotlib seaborn wandb



In [3]:
!pip install torchmetrics



In [4]:
!pip install torch --extra-index-url https://download.pytorch.org/whl/cu116

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu116


In [5]:
!pip install torchmetrics



# Imports

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
import PIL
import torchvision

In [7]:
from timm.models.swin_transformer import SwinTransformer #https://arxiv.org/pdf/2103.14030

In [8]:
import cv2
import numpy as np

In [9]:
from torchmetrics import detection #We're using detection.mean_ap

In [10]:
import sys
import xml.etree.ElementTree as ET

# Data Import & Dataloader

In [None]:
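# This code cell is adapted from Lab 6.

# KNOWN ISSUE: images contain different numbers of bounding boxes, so the
# per-image box arrays (N x 5) have different N and the default DataLoader
# collate function cannot stack them into a batch.
# Current workaround: every box array is zero-padded to a fixed N = 20 rows.
# This wastes memory; a custom collate_fn that returns a list of per-image box
# tensors would avoid the padding.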

class Sixray(Dataset):
    """SIXray X-ray baggage-screening dataset.

    Sample names are read from column 0 of a split CSV. For each name the image
    ``<data_root>/JPEGImage/<name>.jpg`` is loaded; names starting with 'P'
    (images containing prohibited items) also get their boxes from
    ``<data_root>/positive-Annotation/<name>.xml``.

    Args:
        fname (str): path to the split CSV (its first row is a header).
        transform (callable, optional): applied to the sample dict before it is returned.
        data_root (str): dataset root folder (defaults to the original local path).
        max_boxes (int): number of rows in every ``bounding_boxes`` array. A fixed
            size lets the default DataLoader collate function stack samples.
    """

    DEFAULT_DATA_ROOT = "/Users/edgarsuritis/Downloads/FinalProjectData"

    def __init__(self, fname, transform=None, data_root=DEFAULT_DATA_ROOT, max_boxes=20):
        # Converts the <name> text in the annotation XMLs to numeric class ids.
        self.class_id = {"Gun": 0, "Knife": 1, "Wrench": 2, "Pliers": 3, "Scissors": 4, "Hammer": 5}
        # Only column 0 (the sample / file name) is used.
        self.fnames = pd.read_csv(fname, delimiter=',', header=0)
        self.transform = transform
        self.data_root = data_root
        self.max_boxes = max_boxes

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'image': ..., 'bounding_boxes': ...}``.

        Before ``transform`` is applied:
          * 'image' is an RGB PIL image.
          * 'bounding_boxes' is a ``(max_boxes, 5)`` float array whose used rows are
            ``[class_id, x_center, y_center, width, height]``. x values are divided
            by the image width and y values by the image height, so they do not
            change when the image is later resized.
          * Unused rows stay all-zero. Class id 0 is also 'Gun', so a padding row
            can only be recognised by its zero width and height.

        Raises:
            ValueError: if an annotation holds more than ``max_boxes`` boxes.
        """
        from PIL import Image  # `import PIL` alone does not guarantee PIL.Image is loaded

        file_name = self.fnames.iat[int(idx), 0]
        image_path = os.path.join(self.data_root, "JPEGImage", file_name + ".jpg")
        # Force RGB so every image tensor has 3 channels (a grayscale JPEG would
        # otherwise break batching); the `with` closes the file once decoded.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        bounding_boxes = np.zeros((self.max_boxes, 5))
        if file_name[0] == 'P':  # positive sample: it has an annotation XML
            xml_path = os.path.join(self.data_root, "positive-Annotation", file_name + ".xml")
            self._fill_boxes(xml_path, image.size, bounding_boxes)

        sample = {'image': image, 'bounding_boxes': bounding_boxes}
        if self.transform:
            sample = self.transform(sample)
        return sample

    def _fill_boxes(self, xml_path, image_size, bounding_boxes):
        """Write the boxes of one annotation XML into ``bounding_boxes`` in place.

        ``image_size`` is the PIL ``(width, height)`` of the original image; the
        XML pixel coordinates are normalised by it.
        """
        img_w, img_h = image_size
        root = ET.parse(xml_path).getroot()
        row = 0  # next free row; only real objects consume a row
        for obj in root.findall('./object'):
            name = obj.find('name')
            if name is None:  # a few XMLs contain an extra empty <object>; ignore it
                continue
            if row >= len(bounding_boxes):
                raise ValueError(
                    f"{xml_path} has more than {len(bounding_boxes)} boxes; increase max_boxes"
                )
            bndbox = obj.find('bndbox')
            xmin, ymin, xmax, ymax = (
                float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )
            bounding_boxes[row] = [
                self.class_id[name.text],
                (xmin + xmax) / 2 / img_w,
                (ymin + ymax) / 2 / img_h,
                (xmax - xmin) / img_w,
                (ymax - ymin) / img_h,
            ]
            row += 1




class Rescale:
    """Resize the ``'image'`` entry of a sample dict; other entries pass through untouched.

    Args:
        output_size (tuple or int): target size, forwarded to
            ``torchvision.transforms.Resize``. A ``(h, w)`` tuple gives an exact
            size; an int resizes the shorter edge to that length.
    """

    def __init__(self, output_size):
        self.output_size = output_size
        self.resize = torchvision.transforms.Resize(self.output_size)

    def __call__(self, sample):
        # Call the transform itself (not .forward) so nn.Module hooks run, and
        # build a new dict instead of mutating the caller's sample.
        return {**sample, 'image': self.resize(sample['image'])}


class ToTensor(object):
    """Convert a sample's PIL image and NumPy box array to tensors.

    Returns a new dict with:
      * 'image': ``pil_to_tensor`` output (C x H x W) cast to float32. Pixel values
        keep their 0-255 range (no division by 255).
      * 'bounding_boxes': a float32 tensor with the same shape as the input array.

    Credit: adapted from https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class
    """

    def __call__(self, sample):
        image, bounding_boxes = sample['image'], sample['bounding_boxes']
        return {
            'image': torchvision.transforms.functional.pil_to_tensor(image).float(),
            # float, not long: the box coordinates are fractions, and casting to an
            # integer dtype truncates them all to 0. Class ids stay exact in float32.
            'bounding_boxes': torch.from_numpy(bounding_boxes).float(),
        }


def get_dataloader(version='train', batch_size=32, num_workers=0, shuffle=True,
                   split_dir="/Users/edgarsuritis/Downloads/FinalProjectData/TestTrainSplits/train_test_easy"):
    """Return a DataLoader over one SIXray split.

    Every image is resized to 416x416 and converted to tensors.

    Parameters:
        version (str): 'train' or 'test'; selects ``<split_dir>/<version>.csv``.
        batch_size (int): samples per batch.
        num_workers (int): loader worker processes (0 = load in the main process).
        shuffle (bool): whether to shuffle the samples each epoch.
        split_dir (str): folder containing train.csv and test.csv.
    Returns:
        torch.utils.data.DataLoader yielding dicts with batched 'image' and
        'bounding_boxes' entries.
    Raises:
        ValueError: if ``version`` is not 'train' or 'test'.
    """
    if version not in ('train', 'test'):
        # Previously an unknown split left `fname` unbound (UnboundLocalError).
        raise ValueError(f"version must be 'train' or 'test', got {version!r}")
    fname = os.path.join(split_dir, version + ".csv")

    dataset = Sixray(fname,
                     transform=torchvision.transforms.Compose([Rescale((416, 416)),
                                                               ToTensor()]))

    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       num_workers=num_workers,
                                       shuffle=shuffle)

In [None]:
# Sanity check: make one full pass over the test split so that every image and
# XML file is opened and parsed, and verify the batch shapes along the way.
# NOTE: with num_workers=0 a full pass took about 37 minutes on a MacBook.
dload = get_dataloader('test')  # test-split DataLoader

n_seen = 0
for item in dload:
    images, boxes = item['image'], item['bounding_boxes']
    assert images.ndim == 4 and images.shape[-2:] == (416, 416), images.shape
    assert boxes.ndim == 3 and boxes.shape[0] == images.shape[0] and boxes.shape[-1] == 5, boxes.shape
    n_seen += images.shape[0]

assert n_seen == len(dload.dataset), (n_seen, len(dload.dataset))
n_seen

# Model