In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as plt
import os
import numpy as np
import nbimporter

# Execute the helper notebooks inside this kernel's namespace. Names used
# further down but not defined in this notebook (e.g. cal_rpn, height, width,
# train_im, train_txt) are presumably defined by these two notebooks -- TODO confirm.
%run Utils.ipynb
%run Config.ipynb

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class Data(Dataset):
    """Dataset of images paired with ground-truth text-box files.

    Each item is an image from ``datadir`` resized to the configured
    ``width`` x ``height`` together with the training targets that
    ``cal_rpn`` derives from the matching ``gt_<image stem>.txt`` file in
    ``labelsdir``.

    ``height``, ``width`` and ``cal_rpn`` are not defined in this class; they
    are expected to already be in scope (presumably provided by the Config
    and Utils notebooks -- confirm against those notebooks).
    """

    def __init__(self, datadir, labelsdir, test=False):
        """Index the images found in ``datadir``.

        Args:
            datadir: Directory containing the images.
            labelsdir: Directory containing the ``gt_<image stem>.txt`` files.
            test: When true, label lines are whitespace-separated with four
                values per line; otherwise they are comma-separated with
                eight coordinate values per line.
        """
        self.datadir = datadir
        # os.listdir returns entries in arbitrary order; sorting makes an
        # index refer to the same image on every run and every machine.
        self.img_names = sorted(os.listdir(self.datadir))
        self.labelsdir = labelsdir
        self.test = test

    def __len__(self):
        """Return the number of directory entries indexed in ``datadir``."""
        return len(self.img_names)

    def box_transfer_v2(self, coor_lists, rescale_h, rescale_w):
        """Convert polygon coordinates into boxes split at 16-pixel columns.

        Args:
            coor_lists: Iterable of coordinate lists laid out as
                ``[x0, y0, x1, y1, ...]``; values may be strings accepted by
                ``int()``.
            rescale_h: Factor the y coordinates are divided by.
            rescale_w: Factor the x coordinates are divided by.

        Returns:
            ``np.ndarray`` of shape ``(n, 4)`` with rows
            ``(xmin, ymin, xmax, ymax)``; shape ``(0, 4)`` when there are no
            boxes.
        """
        gtboxes = []

        for coor_list in coor_lists:
            # Only complete (x, y) pairs are used; a trailing unpaired value
            # is ignored.
            n_points = len(coor_list) // 2
            coors_x = [int(v) for v in coor_list[0:2 * n_points:2]]
            coors_y = [int(v) for v in coor_list[1:2 * n_points:2]]

            xmin = int(min(coors_x) / rescale_w)
            xmax = int(max(coors_x) / rescale_w)
            ymin = int(min(coors_y) / rescale_h)
            ymax = int(max(coors_y) / rescale_h)

            # Cut the axis-aligned bounding box at every multiple of 16 in x
            # (the same stride that __getitem__ passes to cal_rpn).
            prev = xmin
            for i in range(xmin // 16 + 1, xmax // 16 + 1):
                nxt = 16 * i - 0.5
                gtboxes.append((prev, ymin, nxt, ymax))
                prev = nxt
            gtboxes.append((prev, ymin, xmax, ymax))

        # reshape keeps the result two-dimensional even with no boxes.
        return np.array(gtboxes).reshape(-1, 4)

    def parse_gtfile(self, gt_path, rescale_h, rescale_w):
        """Read a ground-truth file and return its boxes.

        Args:
            gt_path: Path of the label text file.
            rescale_h: Vertical rescale factor forwarded to
                ``box_transfer_v2``.
            rescale_w: Horizontal rescale factor forwarded to
                ``box_transfer_v2``.

        Returns:
            The array produced by ``box_transfer_v2`` for all non-blank lines.
        """
        coor_lists = []

        # utf-8-sig drops a leading byte-order mark if present; with plain
        # utf-8 the BOM stays glued to the first number and int() rejects it.
        with open(gt_path, encoding="utf-8-sig") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines carry no coordinates and would otherwise
                    # make min()/max() fail on an empty sequence.
                    continue

                if self.test:
                    # Split on any run of whitespace.
                    coor_list = line.split()[:4]
                else:
                    coor_list = line.split(',')[:8]

                coor_lists.append(coor_list)

        return self.box_transfer_v2(coor_lists, rescale_h, rescale_w)

    def draw_boxes(self, img, cls, base_anchors, gt_box):
        """Draw every row of ``gt_box`` as a rectangle on ``img``.

        ``cls`` and ``base_anchors`` are accepted for interface compatibility
        but are not used.

        Returns:
            The image with the rectangles drawn on it.
        """
        for i in range(gt_box.shape[0]):
            pt1 = (int(gt_box[i][0]), int(gt_box[i][1]))
            pt2 = (int(gt_box[i][2]), int(gt_box[i][3]))
            img = cv2.rectangle(img, pt1, pt2, (100, 200, 100), 3)

        return img

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(debug_img, img, cls, regr, refine)`` where ``debug_img``
            is a NumPy image with the ground-truth boxes drawn on it and the
            remaining items are float tensors.

        Raises:
            FileNotFoundError: If neither the requested image nor the
                fallback ``img_1.png`` can be read.
        """
        # reading the image according to idx
        img_name = self.img_names[idx]
        img_path = os.path.join(self.datadir, img_name)
        img = cv2.imread(img_path)

        # if the image cannot be read, log its path and use the default
        # image (img_1.png) together with that image's labels
        if img is None:
            with open(r'error_imgs.txt', 'a', encoding='utf-8') as f:
                f.write('{}\n'.format(img_path))

            failed_path = img_path
            img_name = 'img_1.png'
            img_path = os.path.join(self.datadir, img_name)
            img = cv2.imread(img_path)

            if img is None:
                raise FileNotFoundError(
                    'could not read {} or the fallback image {}'.format(failed_path, img_path)
                )

        orig_h, orig_w = img.shape[:2]

        # factors that map original pixel coordinates onto the resized image
        rescale_h = float(orig_h) / height
        rescale_w = float(orig_w) / width

        # Use the configured size directly: recomputing it as
        # orig / (orig / target) can land just below the target through
        # floating-point rounding and then truncate to target - 1.
        h, w = int(height), int(width)
        img = cv2.resize(img, (w, h))  # cv2 expects dsize as (width, height)

        # parsing the txt file associated with each image; splitext keeps
        # the full stem for names that contain additional dots
        stem = os.path.splitext(img_name)[0]
        gt_path = os.path.join(self.labelsdir, 'gt_' + stem + '.txt')
        gtbox = self.parse_gtfile(gt_path, rescale_h, rescale_w)

        [cls, regr, refine], base_anchors = cal_rpn((h, w), (h // 16, w // 16), 16, gtbox)

        # The copy is already at the target size, so no further resize is
        # needed (the former resize passed (height, width) as dsize, which
        # transposed the dimensions for non-square targets).
        debug_img = self.draw_boxes(img.copy(), cls, base_anchors, gtbox)

        # prepend the class label as the first column of both target arrays
        cls_col = cls.reshape(cls.shape[0], 1)
        regr = np.hstack([cls_col, regr])
        refine = np.hstack([cls_col, refine.reshape(refine.shape[0], 1)])

        cls = np.expand_dims(cls, axis=0)

        # transforming to torch tensor for feeding into the model
        img = torch.from_numpy(img.transpose([2, 0, 1])).float()
        cls = torch.from_numpy(cls).float()
        regr = torch.from_numpy(regr).float()
        refine = torch.from_numpy(refine).float()

        return debug_img, img, cls, regr, refine

In [3]:
# Build the training dataset and wrap it in a shuffling loader that yields
# one sample per batch, loaded in the main process (num_workers=0).
batch_size = 1
dataset = Data(datadir=train_im, labelsdir=train_txt)
lo = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)