In [40]:
import torch
import numpy as np
import math
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision.transforms as transforms
import torchvision.datasets.voc as VOC


import matplotlib.pyplot as plt
import matplotlib.patches as patches

# xml library for parsing xml files
from xml.etree import ElementTree as et

#image transforms
#import albumentations as A
#from albumentations.pytorch.transforms import ToTensorV2

#standard libraries
#from engine import train_one_epoch, evaluate
#import utils
#import transforms as T


#our dataset file
#from dataset import FruitDataset

Let's check whether CUDA is available on this machine:

In [41]:
# True if PyTorch detects a usable CUDA device on this machine
torch.cuda.is_available()

True

Graphics card specifications:

In [42]:
# Shell out to nvidia-smi: prints GPU model, driver / CUDA version and current memory usage
!nvidia-smi

Tue Apr 13 06:24:07 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce GTX 106...  Off  | 00000000:01:00.0  On |                  N/A |
|  0%   38C    P8     6W / 120W |    402MiB /  6075MiB |      5%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [43]:
# Train / test split directories, relative to the notebook's working directory.
# FruitDataset expects .jpg images with same-named .xml annotations side by side.
train_dir, test_dir = 'data/train_zip/train/', 'data/test_zip/test/'


# Viewing the dataset



Let's define our dataset class before creating the train and test datasets:

In [44]:
import torch
import random
import cv2
import os

class FruitDataset(torch.utils.data.Dataset):
    """Pascal-VOC style fruit detection dataset.

    Every ``<name>.jpg`` in ``files_dir`` must have a sibling ``<name>.xml``
    annotation file. Unless ``transforms`` replaces it, the image is returned
    as a float32 RGB array of shape (height, width, 3) scaled to [0, 1], and
    boxes are rescaled from the original image size to (width, height).

    Args:
        files_dir: directory containing the images and annotation files.
        width: output image width in pixels.
        height: output image height in pixels.
        transforms: optional callable invoked as
            ``transforms(image=..., bboxes=..., labels=...)`` that returns a
            dict with ``'image'`` and ``'bboxes'`` keys (and optionally
            ``'labels'``), e.g. an albumentations ``Compose`` -- TODO confirm
            against the pipeline actually used.
    """

    def __init__(self, files_dir, width, height, transforms=None):
        self.transforms = transforms
        self.files_dir = files_dir
        self.height = height
        self.width = width

        # sort images so the index -> file mapping is deterministic
        self.imgs = [image for image in sorted(os.listdir(files_dir)) if image[-3:] == 'jpg']
        # '_' is a placeholder at index 0 so fruit labels start at 1
        # (torchvision detection models treat label 0 as background)
        self.classes = ['_', 'apple', 'banana', 'orange']

    def __getitem__(self, index):
        """Return ``(image, target)`` for the image at ``index``.

        ``target`` contains ``boxes`` (N, 4) float32 as (xmin, ymin, xmax, ymax)
        in resized-image pixels, ``labels`` (N,) int64, ``area`` (N,),
        ``iscrowd`` (N,) int64 zeros and ``image_id`` (1,). N may be 0.

        Raises:
            FileNotFoundError: if the image file cannot be read.
            ValueError: if an annotation names a class not in ``self.classes``.
        """
        img_name = self.imgs[index]
        image_path = os.path.join(self.files_dir, img_name)
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"could not read image: {image_path}")
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        # interpolation must be passed by keyword: the 3rd positional
        # parameter of cv2.resize is `dst`, not the interpolation flag
        img_res = cv2.resize(img_rgb, (self.width, self.height), interpolation=cv2.INTER_AREA)

        # scale pixel values from [0, 255] to [0, 1]
        img_res /= 255.0

        # the annotation file shares the image's base name
        annot_file_path = os.path.join(self.files_dir, os.path.splitext(img_name)[0] + '.xml')

        # boxes are annotated on the original image, so rescale using its size
        orig_height, orig_width = img.shape[:2]
        boxes, labels = self._parse_annotation(annot_file_path, orig_width, orig_height)

        target = {
            "boxes": boxes,
            "labels": labels,
            "area": self._box_area(boxes),
            "iscrowd": torch.zeros((boxes.shape[0],), dtype=torch.int64),
            "image_id": torch.tensor([index]),
        }

        if self.transforms is not None:
            sample = self.transforms(image=img_res, bboxes=target['boxes'], labels=labels)
            img_res = sample['image']
            # transforms may move or drop boxes: keep every per-box field in sync
            target['boxes'] = self._as_box_tensor(sample['bboxes'])
            target['labels'] = torch.as_tensor(sample.get('labels', labels), dtype=torch.int64)
            target['area'] = self._box_area(target['boxes'])
            target['iscrowd'] = torch.zeros((target['boxes'].shape[0],), dtype=torch.int64)

        return img_res, target

    def _parse_annotation(self, annot_file_path, orig_width, orig_height):
        """Read one VOC XML file; return (boxes, labels) rescaled to (self.width, self.height)."""
        boxes = []
        labels = []
        root = et.parse(annot_file_path).getroot()

        for member in root.findall('object'):
            labels.append(self.classes.index(member.find('name').text))
            bndbox = member.find('bndbox')
            # float() also accepts fractional coordinates that some labelling tools emit
            xmin, ymin, xmax, ymax = (
                float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )
            boxes.append([
                (xmin / orig_width) * self.width,
                (ymin / orig_height) * self.height,
                (xmax / orig_width) * self.width,
                (ymax / orig_height) * self.height,
            ])

        return self._as_box_tensor(boxes), torch.as_tensor(labels, dtype=torch.int64)

    @staticmethod
    def _as_box_tensor(boxes):
        """Convert a box sequence to a float32 tensor, keeping shape (0, 4) when empty."""
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        if boxes.numel() == 0:
            # as_tensor([]) is 1-D (0,); column indexing needs a 2-D tensor
            boxes = boxes.reshape(0, 4)
        return boxes

    @staticmethod
    def _box_area(boxes):
        """Area of each (xmin, ymin, xmax, ymax) row of an (N, 4) box tensor."""
        return (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

    def __len__(self):
        return len(self.imgs)

Let's create the training dataset, check its size, and inspect one sample from it:

In [45]:
# Training split, with every image resized to 224 x 224 pixels
train_data = FruitDataset(train_dir, width=224, height=224)
print("here's dataset length: ", len(train_data))

here's dataset length:  240


In [46]:
# Fetch one sample (index 78 is just an example) and check its image shape and parsed annotation.
# Use the names unpacked here; `img` / `target` are not defined at notebook level.
my_img, my_target = train_data[78]
print("shape of the image: ", my_img.shape)
print("target bbox", my_target)

shape of the image:  (224, 224, 3)
target bbox {'boxes': tensor([[ 22.4000,  36.4903, 163.1000,  68.6452],
        [ 24.8500,  39.3806, 163.4500,  94.2968],
        [ 28.0000,  52.3871, 166.9500, 127.8968],
        [ 71.0500,  59.6129, 193.9000, 157.5226]]), 'labels': tensor([2, 2, 2, 2]), 'area': tensor([ 4524.1865,  7611.3750, 10492.0693, 12028.2041]), 'iscrowd': tensor([0, 0, 0, 0]), 'image_id': tensor([78])}


## Visualization

Let's visualize an image together with its bounding boxes:

In [None]:
def plot_img_and_bbox(img, target):
    """Display ``img`` with each ground-truth box drawn as a red outline.

    Args:
        img: image passed unchanged to ``Axes.imshow``.
        target: dict whose ``'boxes'`` entry yields one box per item, each
            indexed as (x1, y1, x2, y2), i.e. top-left and bottom-right corners.
    """
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(5, 5)
    ax.imshow(img)

    for box in target['boxes']:
        left, top = box[0], box[1]
        # matplotlib Rectangles take an anchor corner plus width and height
        width = box[2] - left
        height = box[3] - top
        outline = patches.Rectangle(
            (left, top), width, height,
            linewidth=2, edgecolor='r', facecolor='none',
        )
        ax.add_patch(outline)

    plt.show()
        
    
    

In [None]:
# Draw the sample fetched above; `my_target` holds its boxes (there is no `y_target`)
plot_img_and_bbox(my_img, my_target)