## Data Augementation

In [3]:
import os
import cv2
import glob
import random
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
import torchvision.transforms.v2 as v2
from torchvision import tv_tensors
from PIL import Image
import xml.etree.ElementTree as ET

class Data_Augmentor():
    IMAGE_PATH = "data/kaggle-dataset-433/train/images"
    ANNOTATION_PATH = "data/kaggle-dataset-433/train/annotations"
    def __init__ (self, IMAGE_PATH, ANNOTATION_PATH):
        self.IMAGE_PATH = IMAGE_PATH
        self.ANNOTATION_PATH = ANNOTATION_PATH
    
    def augment(self):
        transforms1 = v2.Compose([
            v2.RandomPerspective(0.65, 1),
            v2.ToDtype(torch.float32, scale=True),
        ])
        transforms2 = v2.Compose([
            v2.ColorJitter(brightness=(0.5,1.5),contrast=(0.5,1.5),saturation=(0.5,1.5),hue=(-0.1,0.1)),
            v2.ToDtype(torch.float32, scale=True),
        ])
        transforms3 = v2.Compose([
            v2.GaussianBlur(15),
            v2.ToDtype(torch.float32, scale=True),
        ])

        self.new_data = []
        self.img_list = os.listdir(self.IMAGE_PATH)

        print("Augmenting Data:")
        for img in tqdm(self.img_list):
            img_path = os.path.join(self.IMAGE_PATH, img) #get image file path so we can load it with opencv
            annotation_path = os.path.join(self.ANNOTATION_PATH, img.replace('.png', '.xml')) # get required image annotations
            
            img = Image.open(img_path)
            width, height = img.size
            
            # Parse the XML annotation file to extract bounding box coordinates
            root = ET.parse(annotation_path).getroot()
            
            # Iterate through the XML and extract bounding box coordinates
            for obj in root.findall('.//object'):
                bndbox = obj.find('bndbox')
                xmin = int(bndbox.find('xmin').text)
                ymin = int(bndbox.find('ymin').text)
                xmax = int(bndbox.find('xmax').text)
                ymax = int(bndbox.find('ymax').text)
            boxes = xmin, ymin, xmax, ymax
            boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(height, width))
            self.new_data.append([np.asarray(img), boxes]) # add the existing image without modifications
            
            img1, boxes1 = transforms1(img, boxes) # transform image
            self.new_data.append([np.asarray(img1), boxes1])
            
            img2, boxes2 = transforms2(img, boxes) # transform image
            self.new_data.append([np.asarray(img2), boxes2])

            img3, boxes3 = transforms3(img, boxes) # transform image
            self.new_data.append([np.asarray(img3), boxes3])

        print(f"New Training Data Size: {len(self.new_data)}")        

data_aug = Data_Augmentor(IMAGE_PATH="data/kaggle-dataset-433/train/images", ANNOTATION_PATH="data/kaggle-dataset-433/train/annotations")      
data_aug.augment()  

Augmenting Data:


100%|██████████| 433/433 [00:22<00:00, 19.44it/s]

New Training Data Size: 1732





In [21]:
# preview first few images so we can make sure our data was processed correctly
for i in range(0, 12):
    img = data_aug.new_data[i][0]
    x,y,x1,y1 = data_aug.new_data[i][1][0]
    cv2.rectangle(img, (int(x), int(y)), (int(x1), int(y1)), (255, 255, 255), 2)
    cv2.imshow(f"{i}", img)
cv2.waitKey(0)
cv2.destroyAllWindows()

## Training Dataset Preprocessing

In [21]:
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from PIL import Image
from matplotlib import cm
import xml.etree.ElementTree as ET

class LPR_Training_Dataset_Processed():
    IMAGE_PATH = "data/kaggle-dataset-433/train/images-processed"
    ANNOTATION_PATH = "data/kaggle-dataset-433/train/annotations"
    TARGET_IMAGE_SIZE = 224

    training_data = []

    def create_training_data(self):
        #self.img_list = os.listdir(self.IMAGE_PATH)
        self.img_list = data_aug.new_data
        for img_data in tqdm(self.img_list):
           # img_path = os.path.join(self.IMAGE_PATH, img) #get image file path so we can load it with opencv
           # annotation_path = os.path.join(self.ANNOTATION_PATH, img.replace('.png', '.xml')) # get required image annotations

           # img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) # read image as grayscale
           
            #img = data_aug.new_data[i][0]
            img = Image.fromarray(np.uint8(img_data[0])).convert('L')
            og_img_width, og_img_height = img.size # store original shape of image so we can resize boudning box later
            img = cv2.resize(np.asarray(img), (self.TARGET_IMAGE_SIZE, self.TARGET_IMAGE_SIZE)) # resize image so they're all the same width and height

            # Parse the XML annotation file to extract bounding box coordinates
        #    root = ET.parse(annotation_path).getroot()
            
            # Iterate through the XML and extract bounding box coordinates
        #    for obj in root.findall('.//object'):
        #        bndbox = obj.find('bndbox')
        #        xmin = int(bndbox.find('xmin').text)
        #        ymin = int(bndbox.find('ymin').text)
        #        xmax = int(bndbox.find('xmax').text)
        #        ymax = int(bndbox.find('ymax').text)
            
            # calculate new scale ratio
            x_scale = self.TARGET_IMAGE_SIZE / og_img_width 
            y_scale = self.TARGET_IMAGE_SIZE / og_img_height
            x,y,x1,y1 = img_data[1][0]
            bounding_box_coordinates = (x * x_scale, y * y_scale, x1 * x_scale, y1 * y_scale) # resize bounding box to fit resized image
            
            self.training_data.append([np.array(img), bounding_box_coordinates])

        np.random.shuffle(self.training_data)

training_dataset = LPR_Training_Dataset_Processed()
training_dataset.create_training_data()

  4%|▍         | 65/1732 [00:00<00:02, 644.46it/s]

(tensor(101.2480), tensor(104.4776), tensor(187.7120), tensor(144.5970))
(tensor(136.1920), tensor(114.5075), tensor(190.8480), tensor(139.5821))
(tensor(101.2480), tensor(104.4776), tensor(187.7120), tensor(144.5970))
(tensor(101.2480), tensor(104.4776), tensor(187.7120), tensor(144.5970))
(tensor(75.0400), tensor(115.6129), tensor(146.7200), tensor(144.5161))
(tensor(81.2000), tensor(118.3226), tensor(132.1600), tensor(145.4194))
(tensor(75.0400), tensor(115.6129), tensor(146.7200), tensor(144.5161))
(tensor(75.0400), tensor(115.6129), tensor(146.7200), tensor(144.5161))
(tensor(78.4000), tensor(4.9778), tensor(169.6800), tensor(147.3422))
(tensor(90.7200), tensor(7.9644), tensor(166.3200), tensor(118.4711))
(tensor(78.4000), tensor(4.9778), tensor(169.6800), tensor(147.3422))
(tensor(78.4000), tensor(4.9778), tensor(169.6800), tensor(147.3422))
(tensor(98.), tensor(95.6404), tensor(119.8400), tensor(109.9026))
(tensor(105.2800), tensor(93.1236), tensor(120.9600), tensor(105.7079))
(

 12%|█▏        | 206/1732 [00:00<00:02, 691.01it/s]

(tensor(57.6800), tensor(73.9200), tensor(164.0800), tensor(150.0800))
(tensor(201.6000), tensor(139.2659), tensor(214.4800), tensor(151.8502))
(tensor(144.4800), tensor(111.5805), tensor(156.8000), tensor(120.8090))
(tensor(201.6000), tensor(139.2659), tensor(214.4800), tensor(151.8502))
(tensor(201.6000), tensor(139.2659), tensor(214.4800), tensor(151.8502))
(tensor(62.7200), tensor(64.7111), tensor(171.3600), tensor(172.2311))
(tensor(49.8400), tensor(92.5867), tensor(140.5600), tensor(164.2667))
(tensor(62.7200), tensor(64.7111), tensor(171.3600), tensor(172.2311))
(tensor(62.7200), tensor(64.7111), tensor(171.3600), tensor(172.2311))
(tensor(86.2189), tensor(119.8794), tensor(119.1849), tensor(150.2713))
(tensor(78.1887), tensor(123.2563), tensor(103.9698), tensor(144.0804))
(tensor(86.2189), tensor(119.8794), tensor(119.1849), tensor(150.2713))
(tensor(86.2189), tensor(119.8794), tensor(119.1849), tensor(150.2713))
(tensor(106.9600), tensor(122.8657), tensor(135.5200), tensor(141

 25%|██▍       | 430/1732 [00:00<00:01, 734.41it/s]

(tensor(93.6320), tensor(102.2933), tensor(125.8880), tensor(129.9200))
(tensor(85.1200), tensor(103.7867), tensor(107.5200), tensor(123.9467))
(tensor(93.6320), tensor(102.2933), tensor(125.8880), tensor(129.9200))
(tensor(93.6320), tensor(102.2933), tensor(125.8880), tensor(129.9200))
(tensor(137.8462), tensor(163.3814), tensor(179.3767), tensor(179.5464))
(tensor(129.8935), tensor(144.3299), tensor(167.8895), tensor(168.))
(tensor(137.8462), tensor(163.3814), tensor(179.3767), tensor(179.5464))
(tensor(137.8462), tensor(163.3814), tensor(179.3767), tensor(179.5464))
(tensor(80.4000), tensor(119.4667), tensor(107.2000), tensor(151.4667))
(tensor(120.), tensor(108.8000), tensor(146.), tensor(128.7111))
(tensor(80.4000), tensor(119.4667), tensor(107.2000), tensor(151.4667))
(tensor(80.4000), tensor(119.4667), tensor(107.2000), tensor(151.4667))
(tensor(83.5254), tensor(110.7200), tensor(136.6780), tensor(128.))
(tensor(82.3864), tensor(147.2000), tensor(121.4915), tensor(158.7200))
(te

 34%|███▍      | 585/1732 [00:00<00:01, 749.75it/s]

(tensor(172.3077), tensor(150.9124), tensor(201.4086), tensor(165.1239))
(tensor(172.3077), tensor(150.9124), tensor(201.4086), tensor(165.1239))
(tensor(58.8000), tensor(110.7027), tensor(72.8000), tensor(121.0811))
(tensor(54.3200), tensor(136.6487), tensor(64.4000), tensor(143.5676))
(tensor(58.8000), tensor(110.7027), tensor(72.8000), tensor(121.0811))
(tensor(58.8000), tensor(110.7027), tensor(72.8000), tensor(121.0811))
(tensor(84.2847), tensor(112.6400), tensor(136.2983), tensor(128.6400))
(tensor(108.2034), tensor(72.9600), tensor(138.9559), tensor(89.6000))
(tensor(84.2847), tensor(112.6400), tensor(136.2983), tensor(128.6400))
(tensor(84.2847), tensor(112.6400), tensor(136.2983), tensor(128.6400))
(tensor(84.1267), tensor(151.3513), tensor(138.3530), tensor(165.4775))
(tensor(104.3982), tensor(131.8438), tensor(144.9412), tensor(145.2973))
(tensor(84.1267), tensor(151.3513), tensor(138.3530), tensor(165.4775))
(tensor(84.1267), tensor(151.3513), tensor(138.3530), tensor(165.4

 38%|███▊      | 660/1732 [00:00<00:01, 721.49it/s]

(tensor(58.8000), tensor(134.1003), tensor(85.1200), tensor(155.8261))
(tensor(80.0800), tensor(128.8562), tensor(119.2800), tensor(158.0736))
(tensor(80.0800), tensor(128.8562), tensor(119.2800), tensor(158.0736))
(tensor(84.2847), tensor(110.7200), tensor(137.0576), tensor(129.2800))
(tensor(84.2847), tensor(108.1600), tensor(130.2237), tensor(117.7600))
(tensor(84.2847), tensor(110.7200), tensor(137.0576), tensor(129.2800))
(tensor(84.2847), tensor(110.7200), tensor(137.0576), tensor(129.2800))
(tensor(35.8400), tensor(134.2322), tensor(74.4800), tensor(151.0112))
(tensor(59.3600), tensor(123.3258), tensor(90.7200), tensor(139.2659))
(tensor(35.8400), tensor(134.2322), tensor(74.4800), tensor(151.0112))
(tensor(35.8400), tensor(134.2322), tensor(74.4800), tensor(151.0112))
(tensor(98.), tensor(129.9200), tensor(127.6800), tensor(150.0800))
(tensor(80.0800), tensor(129.9200), tensor(102.4800), tensor(147.8400))
(tensor(98.), tensor(129.9200), tensor(127.6800), tensor(150.0800))
(tens

 47%|████▋     | 818/1732 [00:01<00:01, 752.58it/s]

(tensor(0.4667), tensor(46.0444), tensor(216.5333), tensor(95.8222))
(tensor(32.6667), tensor(92.0889), tensor(189.9333), tensor(130.6667))
(tensor(0.4667), tensor(46.0444), tensor(216.5333), tensor(95.8222))
(tensor(0.4667), tensor(46.0444), tensor(216.5333), tensor(95.8222))
(tensor(136.0800), tensor(136.9302), tensor(159.6000), tensor(152.5581))
(tensor(133.2800), tensor(96.), tensor(152.3200), tensor(108.6512))
(tensor(136.0800), tensor(136.9302), tensor(159.6000), tensor(152.5581))
(tensor(136.0800), tensor(136.9302), tensor(159.6000), tensor(152.5581))
(tensor(72.2400), tensor(199.4521), tensor(129.9200), tensor(222.4658))
(tensor(107.5200), tensor(197.1507), tensor(141.1200), tensor(218.6301))
(tensor(72.2400), tensor(199.4521), tensor(129.9200), tensor(222.4658))
(tensor(72.2400), tensor(199.4521), tensor(129.9200), tensor(222.4658))
(tensor(124.3200), tensor(168.), tensor(165.2000), tensor(188.7200))
(tensor(138.3200), tensor(158.4800), tensor(173.0400), tensor(175.8400))
(ten

 56%|█████▋    | 976/1732 [00:01<00:01, 753.53it/s]

(tensor(92.9600), tensor(123.7015), tensor(130.4800), tensor(153.7910))
(tensor(57.6800), tensor(114.5075), tensor(87.3600), tensor(135.4030))
(tensor(92.9600), tensor(123.7015), tensor(130.4800), tensor(153.7910))
(tensor(92.9600), tensor(123.7015), tensor(130.4800), tensor(153.7910))
(tensor(56.5600), tensor(101.5467), tensor(151.2000), tensor(131.4133))
(tensor(89.6000), tensor(67.2000), tensor(145.0400), tensor(99.3067))
(tensor(56.5600), tensor(101.5467), tensor(151.2000), tensor(131.4133))
(tensor(56.5600), tensor(101.5467), tensor(151.2000), tensor(131.4133))
(tensor(79.5200), tensor(127.4311), tensor(146.1600), tensor(156.3022))
(tensor(67.2000), tensor(102.5422), tensor(117.6000), tensor(118.4711))
(tensor(79.5200), tensor(127.4311), tensor(146.1600), tensor(156.3022))
(tensor(79.5200), tensor(127.4311), tensor(146.1600), tensor(156.3022))
(tensor(86.8000), tensor(157.7333), tensor(112.), tensor(169.8667))
(tensor(105.2800), tensor(146.5333), tensor(123.7600), tensor(157.7333)

 66%|██████▌   | 1141/1732 [00:01<00:00, 783.50it/s]

(tensor(80.0800), tensor(125.4400), tensor(145.6000), tensor(158.2933))
(tensor(80.0800), tensor(125.4400), tensor(145.6000), tensor(158.2933))
(tensor(65.5200), tensor(131.7153), tensor(151.2000), tensor(160.2397))
(tensor(85.6800), tensor(104.0300), tensor(143.9200), tensor(135.0712))
(tensor(65.5200), tensor(131.7153), tensor(151.2000), tensor(160.2397))
(tensor(65.5200), tensor(131.7153), tensor(151.2000), tensor(160.2397))
(tensor(86.8000), tensor(160.3137), tensor(118.1600), tensor(177.8824))
(tensor(74.4800), tensor(131.7647), tensor(100.2400), tensor(139.4510))
(tensor(86.8000), tensor(160.3137), tensor(118.1600), tensor(177.8824))
(tensor(86.8000), tensor(160.3137), tensor(118.1600), tensor(177.8824))
(tensor(87.5789), tensor(91.2800), tensor(124.0702), tensor(110.8800))
(tensor(78.5965), tensor(113.1200), tensor(111.7193), tensor(128.2400))
(tensor(87.5789), tensor(91.2800), tensor(124.0702), tensor(110.8800))
(tensor(87.5789), tensor(91.2800), tensor(124.0702), tensor(110.88

 75%|███████▍  | 1297/1732 [00:01<00:00, 748.97it/s]

(tensor(101.2480), tensor(124.3200), tensor(138.8800), tensor(136.6400))
(tensor(78.8480), tensor(141.6800), tensor(145.6000), tensor(156.8000))
(tensor(78.8480), tensor(141.6800), tensor(145.6000), tensor(156.8000))
(tensor(81.7600), tensor(95.2836), tensor(137.7600), tensor(136.2388))
(tensor(103.0400), tensor(87.7612), tensor(141.1200), tensor(120.3582))
(tensor(81.7600), tensor(95.2836), tensor(137.7600), tensor(136.2388))
(tensor(81.7600), tensor(95.2836), tensor(137.7600), tensor(136.2388))
(tensor(146.7586), tensor(192.3310), tensor(218.8506), tensor(213.1862))
(tensor(133.3701), tensor(155.2552), tensor(188.9839), tensor(186.1517))
(tensor(146.7586), tensor(192.3310), tensor(218.8506), tensor(213.1862))
(tensor(146.7586), tensor(192.3310), tensor(218.8506), tensor(213.1862))
(tensor(175.2800), tensor(159.1579), tensor(207.7600), tensor(176.8421))
(tensor(146.1600), tensor(160.3368), tensor(164.6400), tensor(172.1263))
(tensor(175.2800), tensor(159.1579), tensor(207.7600), tenso

 84%|████████▍ | 1451/1732 [00:01<00:00, 754.86it/s]

(tensor(179.2000), tensor(136.8889), tensor(218.0267), tensor(150.1630))
(tensor(180.1956), tensor(153.4815), tensor(208.0711), tensor(166.7556))
(tensor(179.2000), tensor(136.8889), tensor(218.0267), tensor(150.1630))
(tensor(179.2000), tensor(136.8889), tensor(218.0267), tensor(150.1630))
(tensor(12.8800), tensor(121.1000), tensor(30.2400), tensor(134.4000))
(tensor(64.4000), tensor(117.6000), tensor(87.9200), tensor(130.9000))
(tensor(12.8800), tensor(121.1000), tensor(30.2400), tensor(134.4000))
(tensor(12.8800), tensor(121.1000), tensor(30.2400), tensor(134.4000))
(tensor(98.0757), tensor(116.4800), tensor(142.2703), tensor(136.6400))
(tensor(82.3351), tensor(129.9200), tensor(113.2108), tensor(151.2000))
(tensor(98.0757), tensor(116.4800), tensor(142.2703), tensor(136.6400))
(tensor(98.0757), tensor(116.4800), tensor(142.2703), tensor(136.6400))
(tensor(92.1229), tensor(140.0846), tensor(114.2935), tensor(157.0030))
(tensor(115.4403), tensor(134.6707), tensor(131.4949), tensor(15

 93%|█████████▎| 1608/1732 [00:02<00:00, 768.41it/s]

(tensor(53.2000), tensor(116.4800), tensor(160.1600), tensor(154.3111))
(tensor(53.2000), tensor(116.4800), tensor(160.1600), tensor(154.3111))
(tensor(150.0800), tensor(151.3992), tensor(175.8400), tensor(168.2213))
(tensor(130.4800), tensor(160.2530), tensor(144.4800), tensor(174.4190))
(tensor(150.0800), tensor(151.3992), tensor(175.8400), tensor(168.2213))
(tensor(150.0800), tensor(151.3992), tensor(175.8400), tensor(168.2213))
(tensor(126.5600), tensor(165.5322), tensor(135.5200), tensor(182.2373))
(tensor(130.4800), tensor(176.1627), tensor(137.7600), tensor(189.0712))
(tensor(126.5600), tensor(165.5322), tensor(135.5200), tensor(182.2373))
(tensor(126.5600), tensor(165.5322), tensor(135.5200), tensor(182.2373))
(tensor(44.0407), tensor(35.2000), tensor(199.3220), tensor(167.6800))
(tensor(59.2271), tensor(42.2400), tensor(145.0305), tensor(152.9600))
(tensor(44.0407), tensor(35.2000), tensor(199.3220), tensor(167.6800))
(tensor(44.0407), tensor(35.2000), tensor(199.3220), tensor

100%|██████████| 1732/1732 [00:02<00:00, 744.88it/s]

(tensor(193.2000), tensor(129.7655), tensor(204.4000), tensor(136.7172))
(tensor(160.1600), tensor(111.2276), tensor(167.4400), tensor(117.4069))
(tensor(193.2000), tensor(129.7655), tensor(204.4000), tensor(136.7172))
(tensor(193.2000), tensor(129.7655), tensor(204.4000), tensor(136.7172))
(tensor(94.0800), tensor(110.5067), tensor(135.5200), tensor(127.6800))
(tensor(99.1200), tensor(96.3200), tensor(122.6400), tensor(112.7467))
(tensor(94.0800), tensor(110.5067), tensor(135.5200), tensor(127.6800))
(tensor(94.0800), tensor(110.5067), tensor(135.5200), tensor(127.6800))
(tensor(127.1200), tensor(189.8015), tensor(141.6800), tensor(196.6412))
(tensor(136.0800), tensor(159.8779), tensor(143.9200), tensor(164.1527))
(tensor(127.1200), tensor(189.8015), tensor(141.6800), tensor(196.6412))
(tensor(127.1200), tensor(189.8015), tensor(141.6800), tensor(196.6412))
(tensor(83.1131), tensor(150.6787), tensor(137.8462), tensor(166.1502))
(tensor(110.4796), tensor(115.0270), tensor(147.4751), te




In [13]:
# preview first few images so we can make sure our data was processed correctly
for i in range(0, 10):
    img = training_dataset.training_data[i][0]
    x,y,x1,y1 = training_dataset.training_data[i][1]
    cv2.rectangle(img, (int(x), int(y)), (int(x1), int(y1)), (255, 255, 255), 2)
    cv2.imshow(f"{i}", img)
cv2.waitKey(0)
cv2.destroyAllWindows()

## Model Definition

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LPR_Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)

        x = torch.randn(224, 224).view(-1, 1, 224, 224)
        self._to_linear = None
        self.convs(x)

        self.fc1 = nn.Linear(self._to_linear, 512)
        self.fc2 = nn.Linear(512, 4)
    
    def convs(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))

        if self._to_linear is None: # used to flatten it since pytorch doesn't have tensorflow's flatten function
            self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]

        return x
    
    def forward(self, x):
        x = self.convs(x) # pass through all convulutional layers
        x = x.view(-1, self._to_linear) # flatten it
        x = F.relu(self.fc1(x)) # pass through fully connected (dense) layer
        x = self.fc2(x)
        # return F.softmax(x, dim = 1) # renable this when we move to the gpu
        return x
    
net = LPR_Net()

## Get data and split between test and training data

In [62]:
import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr=0.001)
loss_function = nn.MSELoss()
X = torch.Tensor([i[0] for i in training_dataset.training_data]).view(-1, 224, 224) # image values
X = X / 255.0
y = torch.Tensor([i[1] for i in training_dataset.training_data]) # bounding box values

VAL_PCT = 0.1 # percent of data we want to use for testing vs training
val_size = int(len(X) * VAL_PCT)

# create test and training splits
train_X = X[:-val_size]
train_y = y[:-val_size]

test_X = X[-val_size:]
test_y = y[-val_size:]

print(len(train_X))
print(len(test_X))

1473
259


## Train!

In [63]:
BATCH_SIZE = 250 # reduce if memory errors
EPOCHS = 1

for epoch in range(EPOCHS):
    for i in tqdm(range(0, len(train_X), BATCH_SIZE)):
        batch_X = train_X[i:i + BATCH_SIZE].view(-1, 1, 224, 224)
        batch_y = train_y[i:i + BATCH_SIZE]
        optimizer.zero_grad()
        outputs = net(batch_X)

        loss = loss_function(outputs, batch_y)
        loss.backward()
        optimizer.step()

print(loss)


100%|██████████| 6/6 [01:11<00:00, 11.87s/it]

tensor(5431.8027, grad_fn=<MseLossBackward0>)





In [64]:
correct = 0
total = 0

ACCEPTABLE_DISTANCE = 50

def close_enough(num1, num2):
    return (abs(num1 - num2) < ACCEPTABLE_DISTANCE)

with torch.no_grad():
    for i in tqdm(range(len(test_X))):
       # real_bbox = torch.argmax(test_y[i])
        real_bbox = test_y[i]
        net_out = net(test_X[i].view(-1, 1, 224, 224))[0]
        #print(net_out)
        #predicted_bbox = torch.argmax(net_out)
        predicted_bbox = net_out
        #print(predicted_bbox[0], real_bbox[0])
        if close_enough(predicted_bbox[0], real_bbox[0]) and close_enough(predicted_bbox[1], real_bbox[1]) and close_enough(predicted_bbox[2], real_bbox[2]) and close_enough(predicted_bbox[3], real_bbox[3]):
            correct+= 1
        total += 1
        #print(real_bbox, net_out)



print("Accuracy:", round((correct / total) * 100, 3), "%")

100%|██████████| 259/259 [00:06<00:00, 42.14it/s]

Accuracy: 3.475 %



