In [1]:
####

In [None]:
!nvidia-smi

In [3]:
import torch
from torch import nn
from torchsummary import summary
from torchvision.ops import roi_pool

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision import transforms
from tqdm.notebook import tqdm

In [4]:
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
def iou_width_height(boxes1, boxes2):
    """Intersection-over-union of boxes given only as (width, height).

    Both inputs are indexed on their last axis: element 0 is read as the
    width and element 1 as the height. The overlap is the product of the
    element-wise minimum width and minimum height, i.e. the boxes are
    compared as if they shared a corner. The two inputs are combined with
    ordinary element-wise tensor operations.

    Returns:
        Tensor of ``overlap / (area1 + area2 - overlap)``.
    """
    width1, height1 = boxes1[..., 0], boxes1[..., 1]
    width2, height2 = boxes2[..., 0], boxes2[..., 1]

    overlap = torch.min(width1, width2) * torch.min(height1, height2)
    combined = width1 * height1 + width2 * height2 - overlap
    return overlap / combined

In [6]:
def show_tensor_images(image_tensor, num_images=2, size=(3 , 224 , 224)):
    """Plot the first ``num_images`` images of a tensor as a single grid.

    The tensor is detached, moved to the CPU and viewed as ``(-1, *size)``
    before the leading ``num_images`` entries are tiled with ``make_grid``
    (five images per row) and shown with matplotlib.
    """
    images = image_tensor.detach().cpu().view(-1, *size)
    grid = make_grid(images[:num_images], nrow=5)
    # make_grid returns channels-first; imshow wants channels-last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [7]:

def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Compute the IoU between predicted and target boxes.

    Only the first four entries of the last axis are read, so tensors with
    extra trailing entries are accepted.

    Args:
        boxes_preds: Tensor whose last axis starts with the box coordinates.
        boxes_labels: Tensor with the same layout as ``boxes_preds``.
        box_format: ``"midpoint"`` for ``(x_center, y_center, width, height)``
            or ``"corners"`` for ``(x1, y1, x2, y2)``.

    Returns:
        Tensor of IoU values with a trailing axis of size 1 (the slices used
        below keep the last dimension).

    Raises:
        ValueError: If ``box_format`` is neither ``"midpoint"`` nor
            ``"corners"``.
    """
    if box_format == "midpoint":
        # Convert centre/size to corner coordinates.
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        # Previously an unknown format fell through both branches and failed
        # later with an unrelated "referenced before assignment" error.
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) turns a negative extent (no overlap) into zero area.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # The small epsilon avoids 0/0 for degenerate boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)

In [8]:
class Conv(nn.Module):
    """A ``Conv2d`` with optional BatchNorm, ReLU and 2x2 max-pooling.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        use_norm: Append ``BatchNorm2d(out_channels)`` when true.
        use_activation: Append ``ReLU`` when true.
        use_pool: Append ``MaxPool2d`` (kernel 2, stride 2) when true.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3),
                 stride=(1, 1), padding=1, use_norm=True,
                 use_activation=True, use_pool=False):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation
        self.use_pool = use_pool

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size,
                               stride, padding)

        # Optional stages are only registered when enabled, so the module's
        # children depend on the flags.
        if use_norm:
            self.norm = nn.BatchNorm2d(out_channels)
        if use_activation:
            self.activation = nn.ReLU()
        if use_pool:
            self.maxpool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    def forward(self, x):
        """Apply convolution, then norm / activation / pooling as enabled."""
        out = self.conv1(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        out = self.maxpool(out) if self.use_pool else out
        return out

In [None]:
# Disabled smoke test, kept as a bare string literal so it does not execute:
# it would push a random (2, 3, 512, 512) batch through Conv(3, 32) and show
# the output shape.
'''x = torch.randn(2 , 3 , 512 , 512).to(device)
conv = Conv(3 , 32).to(device)
z = conv(x)
z.shape'''

In [10]:
class Linear(nn.Module):
    """A fully connected layer with optional BatchNorm1d and ReLU.

    Args:
        in_channels: Number of input features.
        out_channels: Number of output features.
        use_norm: Append ``BatchNorm1d(out_channels)`` when true.
        use_activation: Append ``ReLU`` when true.
    """

    def __init__(self, in_channels, out_channels, use_norm=False,
                 use_activation=False):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation

        self.linear1 = nn.Linear(in_channels, out_channels)
        # Optional stages exist as attributes only when they are enabled.
        if use_norm:
            self.norm = nn.BatchNorm1d(out_channels)
        if use_activation:
            self.activation = nn.ReLU()

    def forward(self, x):
        """Apply the linear map, then norm / activation as enabled."""
        out = self.linear1(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        return out

In [11]:
class Resnet_Block(nn.Module):
    """Three-convolution residual block with a convolutional skip path.

    Main path: ``conv1`` (1x1 stride 1, or 2x2 stride 2 when downsampling),
    ``conv2`` (3x3, ``in_channels`` to ``in_channels``) and ``conv3`` (1x1,
    ``in_channels`` to ``out_channels``). The skip path ``conv_skip`` uses the
    same kernel/stride as ``conv1`` so both paths produce matching shapes.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the block output.
        downsample: When true, ``conv1`` and ``conv_skip`` use a 2x2 kernel
            with stride 2, halving the spatial size.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 downsample=False):
        super(Resnet_Block, self).__init__()

        self.downsample = downsample

        # Entry and skip convolutions share their geometry.
        if self.downsample:
            entry_kernel, entry_stride = (2, 2), (2, 2)
        else:
            entry_kernel, entry_stride = (1, 1), (1, 1)

        self.conv1 = Conv(in_channels,
                          in_channels,
                          kernel_size=entry_kernel,
                          stride=entry_stride,
                          padding=0)

        self.conv_skip = Conv(in_channels,
                              out_channels,
                              kernel_size=entry_kernel,
                              stride=entry_stride,
                              padding=0)

        self.conv2 = Conv(in_channels,
                          in_channels)

        self.conv3 = Conv(in_channels,
                          out_channels,
                          kernel_size=(1, 1),
                          stride=(1, 1),
                          padding=0)

    def forward(self, x):
        """Return ``main_path(x) + conv_skip(x)``."""
        # No clone is needed: none of the layers below modify ``x`` in place.
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        shortcut = self.conv_skip(x)
        # Out-of-place add on purpose: ``out`` is the output of an nn.ReLU,
        # which autograd keeps for the backward pass. An in-place ``+=``
        # would invalidate it and make ``backward()`` fail.
        return out + shortcut

In [None]:
# Disabled smoke test, kept as a bare string literal so it does not execute:
# it would run a random (2, 3, 512, 512) batch through a downsampling
# Resnet_Block(3, 64) and show the output shape.
'''x = torch.randn(2 , 3 , 512 , 512).to(device)
resnet_block = Resnet_Block(3 , 64 , downsample = True).to(device)
z = resnet_block(x)
z.shape'''

In [13]:
class FPN(nn.Module):
    """Two stacked ``Conv`` layers applied to one pyramid level.

    Args:
        in_channels: Channels of the incoming feature map.
        hidden_dim: Accepted for interface compatibility; not used.
        out_channels: Channels of the returned feature map.
    """

    def __init__(self, in_channels, hidden_dim=256, out_channels=128):
        super().__init__()

        self.conv1 = Conv(in_channels, in_channels)
        self.conv2 = Conv(in_channels, out_channels)

    def forward(self, x):
        """Return ``conv2(conv1(x))``."""
        return self.conv2(self.conv1(x))

In [14]:
class Resnet(nn.Module):
    """Residual backbone that returns three pyramid feature maps.

    A 7x7 stride-2 stem and a 2x2 max-pool are followed by four stages of
    ``Resnet_Block`` (3, 8, 36 and 3 blocks). The outputs of the last three
    stages are each projected to 256 channels, combined with an upsampled
    neighbour and passed through an ``FPN`` head.

    Args:
        in_channels: Channels of the input image tensor.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Stem.
        self.conv1 = Conv(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=3)

        # Residual stages; only the first one keeps the spatial size.
        self.conv2 = self._make_repeated_blocks(64, 256, 3, downsample=False)
        self.conv3 = self._make_repeated_blocks(256, 512, 8)
        self.conv4 = self._make_repeated_blocks(512, 1024, 36)
        self.conv5 = self._make_repeated_blocks(1024, 2048, 3)

        # One FPN head per returned level.
        self.fpn1 = FPN(256)
        self.fpn2 = FPN(256)
        self.fpn3 = FPN(256)

        # Projections of the stage outputs to a common 256 channels.
        self.conv1_3 = Conv(2048, 256)
        self.conv1_2 = Conv(1024, 256)
        self.conv1_1 = Conv(512, 256)

        self.upsample = nn.Upsample(scale_factor=2)

    def _make_repeated_blocks(self, in_channels, out_channels, repeats, downsample=True):
        """Build a stage of ``repeats`` residual blocks.

        Only the first block changes the channel count (and downsamples when
        requested); the remaining blocks map ``out_channels`` to
        ``out_channels`` without downsampling.
        """
        blocks = [
            Resnet_Block(
                in_channels if idx == 0 else out_channels,
                out_channels,
                downsample=(idx == 0 and downsample == True),
            )
            for idx in range(repeats)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``(c3_out, c2_out, c1_out)``, deepest stage first."""
        stem = self.conv1(x)
        stem = torch.max_pool2d(stem, kernel_size=(2, 2), stride=(2, 2))
        stage2 = self.conv2(stem)
        stage3 = self.conv3(stage2)
        stage4 = self.conv4(stage3)
        stage5 = self.conv5(stage4)

        lateral3 = self.conv1_3(stage5)
        lateral2 = self.conv1_2(stage4)
        lateral1 = self.conv1_1(stage3)

        c3_out = self.fpn1(lateral3)
        c2_out = self.fpn2(lateral2 + self.upsample(lateral3))
        # NOTE(review): this level adds the upsampled projection of stage 4,
        # not the merged map used for c2_out; confirm that is intended.
        c1_out = self.fpn3(lateral1 + self.upsample(lateral2))
        return c3_out, c2_out, c1_out

In [None]:
# Disabled smoke test, kept as a bare string literal so it does not execute:
# it would run a random (2, 3, 224, 224) batch through Resnet(3) and show the
# shape of the third returned feature map.
'''x = torch.randn(2 , 3 , 224 , 224).to(device)
resnet = Resnet(3).to(device)
c3 , c2 , c1 = resnet(x)
c1.shape'''

In [16]:
class RPN(nn.Module):
    """Proposal head that predicts class scores and boxes for every grid cell.

    Two stride-2 convolutions shrink the input by a factor of four, two 3x3
    convolutions restore ``in_channels`` channels, and the flattened result
    feeds two fully connected layers followed by separate class and box
    output layers.

    Args:
        in_channels: Channels of the input feature map.
        num_anchors: Anchors predicted per grid cell.
        feature_map_size: Side length ``S`` of the (square) input feature map;
            it also sets the ``S x S`` grid of the outputs.
        num_classes: Scores predicted per anchor.
    """

    def __init__(self,
                 in_channels,
                 num_anchors=3,
                 feature_map_size=28,
                 num_classes=2):
        super(RPN, self).__init__()

        self.num_anchors = num_anchors
        self.feature_map_size = feature_map_size
        # Kept so forward() can reshape with the configured class count.
        self.num_classes = num_classes

        out_channels_cls = num_classes * self.feature_map_size ** 2 * num_anchors
        out_channels_bbox = 5 * self.num_anchors * self.feature_map_size ** 2
        hidden_dim = in_channels // 2
        self.conv1 = Conv(in_channels, hidden_dim, kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.conv2 = Conv(hidden_dim, hidden_dim // 2, kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.conv3 = Conv(hidden_dim // 2, hidden_dim, kernel_size=(3, 3), stride=(1, 1), padding=1)
        self.conv4 = Conv(hidden_dim, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)

        self.flatten = nn.Flatten()
        # Spatial positions left after the two stride-2 convolutions.
        mul_dim = (feature_map_size // 4) ** 2
        # conv4 emits ``in_channels`` channels, so the flattened size follows
        # in_channels rather than a fixed 128.
        self.linear1 = Linear(in_channels * mul_dim, 2048)
        self.linear2 = Linear(2048, 1024)

        self.linear_cls = Linear(1024, out_channels_cls)
        self.linear_bbox = Linear(1024, out_channels_bbox)

    def forward(self, x):
        """Return ``(cls, bbox)``.

        ``cls`` is viewed as ``(N, num_anchors, S, S, num_classes)`` and
        ``bbox`` as ``(N, num_anchors, S, S, 5)``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        cls = self.linear_cls(x)
        bbox = self.linear_bbox(x)

        grid = (self.num_anchors, self.feature_map_size, self.feature_map_size)
        return (
            cls.view(cls.shape[0], *grid, self.num_classes),
            bbox.view(bbox.shape[0], *grid, 5),
        )

In [None]:
# Disabled smoke test, kept as a bare string literal so it does not execute:
# it would run a random (2, 128, 7, 7) batch through RPN(128,
# feature_map_size=7) and show the shapes of both outputs.
'''x = torch.randn(2 , 128 , 7 , 7).to(device)
rpn = RPN(128 , feature_map_size=7).to(device)
cls , bbox = rpn(x)
cls.shape , bbox.shape'''

In [18]:
class Classifier(nn.Module):
    """Detection head that predicts class scores and boxes per grid cell.

    Two stride-2 convolutions shrink the input by a factor of four, two 3x3
    convolutions restore ``in_channels`` channels, and the flattened result
    feeds two fully connected layers followed by separate class and box
    output layers.

    Args:
        in_channels: Channels of the input feature map.
        feature_map_size: Side length ``S`` of the (square) input feature map;
            it also sets the ``S x S`` grid of the outputs.
        out_channels_cls: Class scores predicted per anchor.
        out_channels_bbox: Accepted for interface compatibility but ignored;
            five values are always predicted per anchor.
        num_anchors: Anchors predicted per grid cell.
    """

    def __init__(self,
                 in_channels,
                 feature_map_size,
                 out_channels_cls=20,
                 out_channels_bbox=4,
                 num_anchors=3):
        super(Classifier, self).__init__()

        self.num_anchors = num_anchors
        self.feature_map_size = feature_map_size
        # Kept so forward() can reshape with the configured class count.
        self.num_classes = out_channels_cls

        cls_features = out_channels_cls * self.feature_map_size ** 2 * num_anchors
        bbox_features = 5 * self.num_anchors * self.feature_map_size ** 2
        hidden_dim = in_channels // 2
        self.conv1 = Conv(in_channels, hidden_dim, kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.conv2 = Conv(hidden_dim, hidden_dim // 2, kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.conv3 = Conv(hidden_dim // 2, hidden_dim, kernel_size=(3, 3), stride=(1, 1), padding=1)
        self.conv4 = Conv(hidden_dim, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)

        self.flatten = nn.Flatten()
        # Spatial positions left after the two stride-2 convolutions.
        mul_dim = (feature_map_size // 4) ** 2
        self.linear1 = Linear(in_channels * mul_dim, 2048)
        self.linear2 = Linear(2048, 1024)

        self.linear_cls = Linear(1024, cls_features)
        self.linear_bbox = Linear(1024, bbox_features)

    def forward(self, x):
        """Return ``(cls, bbox)``.

        ``cls`` is viewed as ``(N, num_anchors, S, S, out_channels_cls)`` and
        ``bbox`` as ``(N, num_anchors, S, S, 5)``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        cls = self.linear_cls(x)
        bbox = self.linear_bbox(x)

        grid = (self.num_anchors, self.feature_map_size, self.feature_map_size)
        return (
            cls.view(cls.shape[0], *grid, self.num_classes),
            bbox.view(bbox.shape[0], *grid, 5),
        )

In [None]:
# Disabled smoke test, kept as a bare string literal so it does not execute:
# it would run a random (2, 128, 28, 28) batch through Classifier(128, 28)
# and print the shapes of both outputs.
'''x = torch.randn(2 , 128 , 28 , 28).to(device)
classifier = Classifier(128 , 28).to(device)
cls , bbox= classifier(x)
print(cls.shape , bbox.shape)'''

In [20]:
class ROI(nn.Module):
    """Pool a feature map over every predicted box and keep the best matches.

    For each (anchor, row, column) position the feature map is pooled with
    ``roi_pool`` and the predicted box is scored against the target box with
    ``intersection_over_union``; the ``output_len`` highest-scoring crops per
    batch element are returned.
    """
    def __init__(self , 
                 output_len = 2 , 
                 feature_map_size = 50):
        super(ROI , self).__init__()

        # Side length of the square output requested from roi_pool.
        self.crop_size = feature_map_size
        #self.classifier = Classifier(512 , feature_map_size=feature_map_size)
        # Number of top-ranked crops kept per batch element in forward().
        self.output_len = output_len

    def forward(self , img , bbox_input , target_boxes):
        '''
        Select the crops whose predicted boxes best match the targets.

        img: feature map handed to roi_pool.
        bbox_input: predicted boxes, indexed as [batch, anchor, row, col, :];
            the original note gave the shape as [N x 5 x 50 x 50 x 5].
        target_boxes: target boxes indexed the same way as bbox_input.

        Returns the stacked selected crops with dimension 1 squeezed;
        presumably (N * output_len, C, crop_size, crop_size) -- TODO confirm.
        '''
        roi_imgs = []
        ious = []
        roi_final_imgs = []
        # Visit every (anchor, row, column) position of the prediction grid.
        for i in range(bbox_input.shape[1]):
            for j in range(bbox_input.shape[2]):
                for k in range(bbox_input.shape[3]):
                    #print(bbox_input[: , i , j , k , :].shape)
                    # NOTE(review): a 2-D box tensor with 5 columns is, per the
                    # torchvision docs, read as (batch_index, x1, y1, x2, y2);
                    # confirm the predicted 5-vector follows that layout.
                    croped_img = roi_pool(img , bbox_input[: , i , j , k , :] , output_size=(self.crop_size , self.crop_size))
                    roi_imgs.append(croped_img)
                    # Uses the default "midpoint" box format on the first four
                    # components of each 5-vector.
                    iou = intersection_over_union(bbox_input[: , i , j , k , :] , target_boxes[: , i , j , k , :])
                    ious.append(iou)
        # Move the batch axis to the front: [batch, position, ...].
        roi_imgs = torch.stack(roi_imgs).permute(1 , 0 , 2 , 3 , 4)
        ious = torch.stack(ious).permute(1 , 0 , 2)
        # Positions ranked by IoU, best first, separately per batch element.
        anchors_indices = ious.argsort(descending=True, dim=1)
        
        for batch in range(roi_imgs.shape[0]):
            # ``filter`` shadows the builtin of the same name inside this loop.
            for filter in range(roi_imgs.shape[1]):
                # Stop once output_len crops have been taken for this element.
                if filter == self.output_len:
                    break
                anchor = anchors_indices[batch , filter , :]
                # Indexing with the index tensor keeps a leading axis, which
                # the final squeeze(1) removes after stacking.
                roi_img = roi_imgs[batch , anchor , : , : , :]
                roi_final_imgs.append(roi_img)
        roi_final_imgs = torch.stack(roi_final_imgs)
        return roi_final_imgs.squeeze(1)

In [None]:
# Disabled smoke test, kept as a bare string literal so it does not execute:
# it would feed random features, boxes and targets on a 7x7 grid through ROI
# and show the shape of the selected crops.
'''feature_map_size = 7
roi = ROI(feature_map_size=feature_map_size).to(device)
img = torch.randn(2 , 512 , feature_map_size , feature_map_size).to(device)
bbox = torch.randn(2 , 5 , feature_map_size , feature_map_size , 5).to(device)
target_boxes = torch.randn(2 , 5 , feature_map_size , feature_map_size , 5).to(device)
roi_imgs = roi(img , bbox , target_boxes)
roi_imgs.shape'''

In [22]:
class FasterRCNN_FPN(nn.Module):
    """Two-scale detector: backbone, proposal heads, ROI selection, classifiers.

    The ``Resnet`` backbone yields three feature maps. The second and third
    are passed through their own ``RPN``, ``ROI`` and ``Classifier``
    (configured for grid sizes 14 and 28 respectively). ``rpn_1`` is
    constructed but is not used by ``forward``.
    """

    def __init__(self):
        super().__init__()

        self.resnet = Resnet(3)

        # Proposal heads, one per grid size.
        self.rpn_3 = RPN(128, feature_map_size=28)
        self.rpn_2 = RPN(128, feature_map_size=14)
        self.rpn_1 = RPN(128, feature_map_size=7)

        # Final heads for the two scales that forward() actually uses.
        self.classifier_3 = Classifier(128, feature_map_size=28)
        self.classifier_2 = Classifier(128, feature_map_size=14)

        self.roi_3 = ROI(feature_map_size=28)
        self.roi_2 = ROI(feature_map_size=14)

    def forward(self, x, target_boxes):
        """Run detection on ``x``.

        Args:
            x: Input image batch.
            target_boxes: Sequence whose item 0 holds the targets used for the
                size-28 branch and item 1 those for the size-14 branch.

        Returns:
            ``(out_3_cls, out_3_bbox, out_2_cls, out_2_bbox)``.
        """
        _, feat_2, feat_3 = self.resnet(x)

        # Only the box predictions of each proposal head are used below.
        _, boxes_3 = self.rpn_3(feat_3)
        _, boxes_2 = self.rpn_2(feat_2)

        crops_3 = self.roi_3(feat_3, boxes_3, target_boxes[0])
        crops_2 = self.roi_2(feat_2, boxes_2, target_boxes[1])

        out_3_cls, out_3_bbox = self.classifier_3(crops_3)
        out_2_cls, out_2_bbox = self.classifier_2(crops_2)
        return out_3_cls, out_3_bbox, out_2_cls, out_2_bbox

In [None]:
# Disabled smoke test, kept as a bare string literal so it does not execute:
# it would run one random (1, 3, 224, 224) image with random 28x28 and 14x14
# targets through FasterRCNN_FPN and show the four output shapes.
'''x = torch.randn(1 , 3 , 224 , 224).to(device)
faster_rcnn_fpn = FasterRCNN_FPN().to(device)
target_boxes = [torch.randn(1 , 3 , 28 , 28 , 5).to(device) , 
                torch.randn(1 , 3 , 14 , 14 , 5).to(device)]
out_3_cls , out_3_bbox , out_2_cls , out_2_bbox = faster_rcnn_fpn(x , target_boxes)
out_3_cls.shape , out_3_bbox.shape , out_2_cls.shape , out_2_bbox.shape'''

In [24]:
# Anchor box (width, height) pairs: three groups of three anchors each.
anchors = [
    [
        (0.28, 0.22),
        (0.38, 0.48),
        (0.9, 0.78),
    ],
    [
        (0.07, 0.15),
        (0.15, 0.11),
        (0.14, 0.29),
    ],
    [
        (0.02, 0.03),
        (0.04, 0.07),
        (0.08, 0.06),
    ],
]

In [25]:
class Dataset_(torch.utils.data.Dataset):
    """Image / label dataset that builds per-scale anchor targets.

    Each CSV row names an image file (column 0) and a label file (column 1).
    A label file holds one box per line as
    ``class x y width height`` (whitespace separated).

    Args:
        csv_file: CSV listing image and label file names.
        img_dir: Directory containing the images.
        label_dir: Directory containing the label files.
        anchors: Three groups of (width, height) anchors. ``None`` (default)
            uses the notebook-level ``anchors`` table, looked up when the
            dataset is created so re-running the anchors cell takes effect.
        S: Grid sizes, one per anchor group.
        C: Number of classes (stored, not otherwise used here).
        transform: Optional callable applied to the image tensor.
        max_samples: Keep only the first ``max_samples`` CSV rows; ``None``
            keeps all rows. Defaults to 10, the previously hard-coded limit.
    """

    def __init__(
        self,
        csv_file,
        img_dir,
        label_dir,
        anchors=None,
        S=(14, 28, 56),
        C=20,
        transform=None,
        max_samples=10,
    ):
        if anchors is None:
            # Late-bind to the module-level anchor table.
            anchors = globals()["anchors"]

        frame = pd.read_csv(csv_file)
        self.df = frame if max_samples is None else frame[:max_samples]
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform
        # Copy so a caller's sequence is never shared with the dataset.
        self.S = list(S)
        self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2])
        self.num_anchors = self.anchors.shape[0]
        self.num_anchors_per_scale = self.num_anchors // 3
        self.C = C
        self.ignore_iou_thresh = 0.5

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _parse_label_line(line):
        """Parse one label line into ``[x, y, width, height, class]``.

        Returns ``None`` for a blank line. Every field goes through ``float``
        first, so a class id written as ``"1.0"`` is accepted.
        """
        fields = line.split()
        if not fields:
            return None
        class_label, x, y, width, height = (float(tok) for tok in fields)
        return [x, y, width, height, int(class_label)]

    def __getitem__(self, idx):
        """Return ``(image, targets)`` for sample ``idx``.

        ``targets`` is a tuple with one tensor per grid size ``S`` of shape
        ``(anchors_per_scale, S, S, 6)``. The last axis holds the objectness
        flag (1 assigned, -1 ignored, 0 empty), the cell-relative
        ``x, y, width, height`` and the class id.
        """
        label_path = os.path.join(self.label_dir, self.df.iloc[idx, 1])
        boxes = []

        with open(label_path) as f:
            for line in f:
                parsed = self._parse_label_line(line)
                if parsed is not None:
                    boxes.append(parsed)

        boxes = torch.tensor(boxes)

        image_path = os.path.join(self.img_dir, self.df.iloc[idx, 0])
        image = np.asarray(plt.imread(image_path))
        image = torch.from_numpy(image).permute(2, 0, 1)

        if self.transform:
            image = self.transform(image)

        target = [
            torch.zeros((self.num_anchors_per_scale, S, S, 6)) for S in self.S
        ]

        for box in boxes:
            # Rank all anchors by shape similarity to this box.
            iou_anchors = iou_width_height(box[2:4], self.anchors)
            anchors_indices = iou_anchors.argsort(descending=True, dim=0)

            x, y, width, height, class_label = box.tolist()
            has_anchor = [False] * len(self.S)
            for anchor_idx in anchors_indices.tolist():
                scale_idx = anchor_idx // self.num_anchors_per_scale
                anchor_on_scale = anchor_idx % self.num_anchors_per_scale
                S = self.S[scale_idx]
                # Clamp so a coordinate of exactly 1.0 lands in the last cell
                # instead of indexing past the grid.
                i = min(int(S * y), S - 1)
                j = min(int(S * x), S - 1)
                anchor_taken = bool(target[scale_idx][anchor_on_scale, i, j, 0])
                if not anchor_taken and not has_anchor[scale_idx]:
                    # Best free anchor on this scale: assign the box to it.
                    target[scale_idx][anchor_on_scale, i, j, 0] = 1
                    x_cell, y_cell = S * x - j, S * y - i
                    width_cell, height_cell = width * S, height * S
                    box_coordinates = torch.tensor(
                        [x_cell, y_cell, width_cell, height_cell]
                    )
                    target[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
                    target[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
                    has_anchor[scale_idx] = True

                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    # Good-but-not-best match: flag the cell as ignored.
                    target[scale_idx][anchor_on_scale, i, j, 0] = -1
        return image, tuple(target)

In [26]:
# Image pipeline: tensor -> PIL image -> 224x224 resize -> tensor.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Dataset rooted in the mounted Google Drive folder.
dataset = Dataset_(
    csv_file='/content/drive/MyDrive/Yolo_Dataset/train.csv',
    img_dir='/content/drive/MyDrive/Yolo_Dataset/images/',
    label_dir='/content/drive/MyDrive/Yolo_Dataset/labels',
    transform=transform,
)

# One shuffled sample per batch.
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1)

In [None]:
# Sanity check: display the first batch and print its three target shapes.
for x, y in dataloader:
    show_tensor_images(x)
    print(*[y[scale].shape for scale in range(3)])
    break

In [28]:
# Loss functions.
l1_criterion = nn.L1Loss()
adv_criterion = nn.BCEWithLogitsLoss()
lambda_recon = 200

# Optimiser settings.
lr = 0.0002
betas = (0.5, 0.999)

# Training schedule and reporting.
n_epochs = 200
display_step = 1
target_shape = 512

In [29]:
# Build the detector on the selected device and attach an Adam optimiser.
faster_rcnn_fpn = FasterRCNN_FPN()
faster_rcnn_fpn = faster_rcnn_fpn.to(device)
opt = torch.optim.Adam(params=faster_rcnn_fpn.parameters(), lr=lr, betas=betas)

In [30]:
def train():
    """Train ``faster_rcnn_fpn`` for ``n_epochs`` passes over ``dataloader``.

    Reads the notebook-level ``dataloader``, ``device``, ``faster_rcnn_fpn``,
    ``opt``, ``l1_criterion``, ``n_epochs`` and ``display_step``. Every
    ``display_step`` steps the loss averaged over that window is printed.
    """
    mean_rcnn_loss = 0
    cur_step = 0

    for epoch in range(n_epochs):
        for img, label in dataloader:
            img = img.to(device)
            # Only the first two target scales are used; the third is dropped.
            label_1, label_2, _ = label
            label_1 = label_1.to(device)
            label_2 = label_2.to(device)
            # The model expects the targets in the order [label_2, label_1].
            target = [label_2[..., 0:5],
                      label_1[..., 0:5]]

            opt.zero_grad()
            out_1_cls, out_1_bbox, out_2_cls, out_2_bbox = faster_rcnn_fpn(img, target)

            # Collect one loss per output sample. Previously each iteration
            # overwrote ``loss``, so only the last sample was back-propagated.
            sample_losses = []
            for cls_1, bbox_1, cls_2, bbox_2 in zip(out_1_cls, out_1_bbox, out_2_cls, out_2_bbox):
                loss_1 = l1_criterion(bbox_1.unsqueeze(0), label_2[..., :5])
                loss_2 = l1_criterion(bbox_2.unsqueeze(0), label_1[..., :5])
                class_loss_1 = l1_criterion(cls_1, label_2[..., 5:6])
                class_loss_2 = l1_criterion(cls_2, label_1[..., 5:6])
                sample_losses.append(loss_1 + loss_2 + class_loss_1 + class_loss_2)

            if not sample_losses:
                # Nothing to optimise for an empty model output.
                continue

            # Average so the scale matches a single-sample loss.
            loss = torch.stack(sample_losses).mean()
            loss.backward()
            opt.step()

            # ``display_step`` is the name defined in the configuration cell.
            mean_rcnn_loss += loss.item() / display_step
            if cur_step % display_step == 0:
                print(f'Epoch {epoch} , Step {cur_step} , Mean YOLO Loss {mean_rcnn_loss}')
                # Start a fresh averaging window after every report.
                mean_rcnn_loss = 0
            cur_step += 1

In [None]:
# Start training; train() reads the model, optimiser, data loader and
# hyperparameters from the notebook's global variables.
train()