In [1]:
####

In [2]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
from torchvision import transforms
from torchvision.utils import make_grid
import torchvision.transforms.functional as F
from tqdm.notebook import tqdm
from torchvision.ops import roi_align

In [3]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
def iou_width_height(boxes1, boxes2):
    """IoU of boxes described only by (width, height), i.e. as if aligned at one point.

    Args:
        boxes1, boxes2: tensors whose last axis starts with (width, height);
            leading axes broadcast against each other.

    Returns:
        Tensor of IoU values with the broadcast leading shape.
    """
    w1, h1 = boxes1[..., 0], boxes1[..., 1]
    w2, h2 = boxes2[..., 0], boxes2[..., 1]
    overlap = torch.min(w1, w2) * torch.min(h1, h2)
    return overlap / (w1 * h1 + w2 * h2 - overlap)

In [5]:
def show_tensor_images(image_tensor, num_images=2, size=(3 , 800 , 800)):
    """Display the first `num_images` images of a batch as a grid, 5 per row.

    Args:
        image_tensor: tensor reshapeable to (-1, *size).
        num_images: how many images to draw.
        size: per-image (channels, height, width).
    """
    images = image_tensor.detach().cpu().view(-1, *size)
    grid = make_grid(images[:num_images], nrow=5)
    # make_grid returns (C, H, W); matplotlib wants (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [6]:

def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Element-wise IoU between two sets of boxes.

    Args:
        boxes_preds, boxes_labels: tensors whose last axis holds at least four
            box values (extra trailing values are ignored); leading axes broadcast.
        box_format: "midpoint" for (x_center, y_center, width, height) or
            "corners" for (x1, y1, x2, y2).

    Returns:
        Tensor with the broadcast leading shape and a trailing axis of size 1.

    Raises:
        ValueError: if `box_format` is neither "midpoint" nor "corners"
            (previously this surfaced as an UnboundLocalError).
    """
    if box_format == "midpoint":
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) makes disjoint boxes contribute zero overlap instead of a negative area.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = ((box1_x2 - box1_x1) * (box1_y2 - box1_y1)).abs()
    box2_area = ((box2_x2 - box2_x1) * (box2_y2 - box2_y1)).abs()

    # The epsilon guards against division by zero for degenerate (zero-area) boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)

In [7]:

class Conv(nn.Module):
    """Conv2d followed by optional BatchNorm2d, ReLU and 2x2 max-pool, in that order.

    Args:
        in_channels, out_channels: channel counts for the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        use_norm: append ``BatchNorm2d(out_channels)``.
        use_activation: append a ReLU.
        use_pool: append a 2x2 / stride-2 max-pool.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=1,
                 use_norm=True,
                 use_activation=True,
                 use_pool=False):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size, stride=stride, padding=padding)
        self.use_norm = use_norm
        self.use_activation = use_activation
        self.use_pool = use_pool

        # Optional stages are registered in the same order forward() applies them.
        if use_norm:
            self.norm = nn.BatchNorm2d(out_channels)
        if use_activation:
            self.activation = nn.ReLU()
        if use_pool:
            self.maxpool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    def forward(self, x):
        out = self.conv1(x)
        optional_stages = (
            (self.use_norm, 'norm'),
            (self.use_activation, 'activation'),
            (self.use_pool, 'maxpool'),
        )
        for enabled, attr in optional_stages:
            if enabled:
                out = getattr(self, attr)(out)
        return out

In [8]:

class ConvT(nn.Module):
    """ConvTranspose2d followed by optional InstanceNorm2d and LeakyReLU(0.2).

    With the default 2x2 kernel, stride 2 and no padding, the spatial size doubles.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(2, 2),
                 stride=(2, 2),
                 padding=0,
                 use_norm=True,
                 use_activation=True):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation

        self.convT = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding)
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        y = self.convT(x)
        if self.use_norm:
            y = self.norm(y)
        return self.activation(y) if self.use_activation else y

In [9]:
class Linear(nn.Module):
    """Fully connected layer with optional BatchNorm1d and ReLU, applied in that order."""
    def __init__(self,
                 in_channels,
                 out_channels,
                 use_norm=False,
                 use_activation=True):
        super().__init__()

        self.linear1 = nn.Linear(in_features=in_channels, out_features=out_channels)
        self.use_norm, self.use_activation = use_norm, use_activation

        if use_norm:
            self.norm = nn.BatchNorm1d(out_channels)
        if use_activation:
            self.activation = nn.ReLU()

    def forward(self, x):
        y = self.linear1(x)
        y = self.norm(y) if self.use_norm else y
        return self.activation(y) if self.use_activation else y

In [10]:
# VGG-style convolutional trunk without a fully connected head.
# Each entry is either [out_channels, kernel_size, stride, padding] for a Conv
# block, or "M" for a 2x2 / stride-2 max-pool. The 3x3 / padding-1 convs keep the
# spatial size, so the four pools give an overall stride of 16.
config = [
    [64, 3, 1, 1], [128, 3, 1, 1], "M",
    [128, 3, 1, 1], [256, 3, 1, 1], "M",
    [256, 3, 1, 1], [512, 3, 1, 1], "M",
    [512, 3, 1, 1], [512, 3, 1, 1], [512, 3, 1, 1], "M",
    [512, 3, 1, 1], [512, 3, 1, 1], [512, 3, 1, 1],
    # Head intentionally disabled; VGG also understands "M", 4096 (Flatten +
    # Linear 25088->4096) and 1000 (Linear + Softmax) entries.
]

In [11]:
class VGG(nn.Module):
    """Sequential VGG-style network built from a layer spec list.

    Args:
        in_channels: channels of the input image.
        config: list of entries, each one of
            * [out_channels, kernel_size, stride, padding] -> Conv block
            * any string (e.g. "M")                       -> 2x2 / stride-2 max-pool
            * 4096 -> Flatten + Linear(25088, 4096)  (25088 = 512 * 7 * 7)
            * 1000 -> Linear(4096, 1000) + Softmax over the class axis

    Raises:
        ValueError: for any other entry (these were previously ignored silently).
    """
    def __init__(self , 
                 in_channels = 3 , 
                 config = config):
        super(VGG , self).__init__()

        self.layers = nn.ModuleList()

        for layer in config:
            if isinstance(layer , list):
                out_channels , kernel_size , stride , padding = layer
                self.layers.append(Conv(
                    in_channels , 
                    out_channels , 
                    kernel_size , 
                    stride , 
                    padding
                ))
                in_channels = out_channels
            elif isinstance(layer , str):
                self.layers.append(nn.MaxPool2d(kernel_size = (2 , 2) , stride = (2 , 2)))
            elif layer == 4096:
                self.layers.append(nn.Flatten())
                self.layers.append(Linear(25088 , 4096))
            elif layer == 1000:
                self.layers.append(Linear(4096 , 1000 , use_activation = False))
                # Explicit dim: the implicit-dim form is deprecated. For the
                # (batch, 1000) input here it picked dim=1 anyway.
                self.layers.append(nn.Softmax(dim = 1))
            else:
                raise ValueError(f"Unsupported VGG config entry: {layer!r}")

    def forward(self , x):
        for layer in self.layers:
            x = layer(x)
        return x

In [12]:
class RPN(nn.Module):
    """Region-proposal head: four 2x2/stride-2 convs, then an MLP with dense grid outputs.

    Args:
        in_channels: channels of the incoming feature map.
        num_anchors: anchors per grid cell.
        feature_map_size: side S of the output prediction grid.
        num_classes: scores per anchor in the cls output.
        input_size: spatial side of the incoming feature map. It is used to size
            the first Linear; the default 50 reproduces the former hard-coded 4608
            for in_channels=512.

    forward returns:
        cls:  (N, num_classes, S, S, num_anchors)
        bbox: (N, num_anchors, S, S, 5)
    """
    def __init__(self , 
                 in_channels ,
                 num_anchors = 5 , 
                 feature_map_size = 50 , 
                 num_classes = 2 ,
                 input_size = 50):
        super(RPN , self).__init__()

        self.num_anchors = num_anchors
        self.num_classes = num_classes
        self.feature_map_size = feature_map_size

        out_channels_cls = num_classes * feature_map_size ** 2 * num_anchors
        out_channels_bbox = 5 * num_anchors * feature_map_size ** 2
        hidden_dim = in_channels // 2
        self.conv1 = Conv(in_channels , hidden_dim , kernel_size=(2 , 2) , stride=(2 , 2) , padding=0)
        self.conv2 = Conv(hidden_dim , hidden_dim // 2 , kernel_size=(2 , 2) , stride=(2 , 2) , padding=0)
        self.conv3 = Conv(hidden_dim // 2 , hidden_dim , kernel_size=(2 , 2) , stride=(2 , 2) , padding=0)
        self.conv4 = Conv(hidden_dim , in_channels , kernel_size=(2 , 2) , stride=(2 , 2) , padding=0)

        # Each unpadded 2x2/stride-2 conv floors the side by half:
        # 50 -> 25 -> 12 -> 6 -> 3, i.e. 512 * 3 * 3 = 4608 by default.
        spatial = input_size
        for _ in range(4):
            spatial //= 2

        self.flatten = nn.Flatten()

        self.linear1 = Linear(in_channels * spatial * spatial , 2048)
        self.linear2 = Linear(2048 , 1024)

        # NOTE(review): these heads use the Linear default ReLU, so outputs are >= 0.
        self.linear_cls = Linear(1024 , out_channels_cls)
        self.linear_bbox = Linear(1024 , out_channels_bbox)

    def forward(self , x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        cls = self.linear_cls(x)
        bbox = self.linear_bbox(x)
        batch = x.shape[0]
        S = self.feature_map_size
        # Use the configured sizes; the former literal 2/5 broke for other values.
        return (cls.view(batch , self.num_classes , S , S , self.num_anchors) ,
                bbox.view(batch , self.num_anchors , S , S , 5))

In [13]:

class Classifier(nn.Module):
    """Per-ROI head predicting dense class scores and boxes over an anchor grid.

    Args:
        in_channels: channels of each ROI crop.
        out_channels_cls: class scores per anchor/cell (20 by default).
        out_channels_bbox: box values per anchor/cell (4 by default).
        num_anchors: anchors per cell.
        grid_size: side of the output grid (was hard-coded to 50).
        input_size: spatial side of the ROI crop (5 by default). It is used to size
            the first Linear; the default reproduces the former hard-coded 32 for
            in_channels=512.

    forward returns:
        cls:  (N, num_anchors, grid_size, grid_size, out_channels_cls)
        bbox: (N, num_anchors, grid_size, grid_size, out_channels_bbox)
    """
    def __init__(self , 
                 in_channels , 
                 out_channels_cls = 20 , 
                 out_channels_bbox = 4 , 
                 num_anchors = 5 ,
                 grid_size = 50 ,
                 input_size = 5):
        super(Classifier , self).__init__()

        self.num_classes = out_channels_cls
        self.box_dim = out_channels_bbox
        self.num_anchors = num_anchors
        self.grid_size = grid_size
        cells = num_anchors * grid_size * grid_size

        self.conv1 = Conv(in_channels , in_channels // 2 , use_pool=True)
        self.conv2 = Conv(in_channels // 2 , in_channels // 4 , use_pool = True)
        self.conv3 = Conv(in_channels // 4 , in_channels // 8 , use_pool = False)
        self.conv4 = Conv(in_channels // 8 , in_channels // 16 , use_pool=False)

        # 3x3/pad-1 convs keep the side; the two 2x2 max-pools floor it twice (5 -> 2 -> 1).
        spatial = input_size // 2 // 2

        self.flatten = nn.Flatten()
        self.linear1 = Linear((in_channels // 16) * spatial * spatial , 256)
        self.linear2 = Linear(256 , 128)

        self.linear_cls = Linear(128 , out_channels_cls * cells)
        self.linear_bbox = Linear(128 , out_channels_bbox * cells)

    def forward(self , x):
        batch = x.shape[0]
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        cls = self.linear_cls(x)
        bbox = self.linear_bbox(x)
        G = self.grid_size
        return (cls.view(batch , self.num_anchors , G , G , self.num_classes) ,
                bbox.view(batch , self.num_anchors , G , G , self.box_dim))

In [None]:
# Shape check: Classifier on a batch of two 512x5x5 ROI crops.
x = torch.randn(2, 512, 5, 5).to(device)
cls = Classifier(512).to(device)
cls_, bbox = cls(x)
(cls_.shape, bbox.shape)

In [15]:
class Dataset_(torch.utils.data.Dataset):
    """YOLO-style dataset returning (image, targets, target_mask).

    NOTE(review): a class with the same name is defined again in the next cell;
    that later definition shadows this one, so this version is effectively dead code.

    Args:
        img_dir / label_dir: directories joined with the CSV's column 0 / column 1.
        csv_file: CSV listing (image filename, label filename) per row.
        anchors: list of (width, height) priors, compared with box sizes via
            iou_width_height.
        transforms: optional callable applied to the (C, H, W) image tensor.
        S, B, C: grid side, number of anchors, number of classes.

    __getitem__ returns:
        image, targets of shape (B, S, S, 5), and target_mask of shape
        (B, mask_size, mask_size, 2), i.e. one patch per anchor slot, not per grid cell.
    """
    def __init__(self ,
                 img_dir , 
                 label_dir , 
                 csv_file , 
                 anchors , 
                 transforms = None , 
                 S = 50 , 
                 B = 5 , 
                 C = 20):
        super(Dataset_ , self).__init__()

        self.img_dir = img_dir
        self.label_dir = label_dir
        self.df = pd.read_csv(csv_file)
        self.anchors = torch.from_numpy(np.array(anchors))
        #print(self.anchors)
        self.transforms = transforms
        self.number_of_anchors_per_cell = 5
        self.ignore_iou_thresh = 0.5
        self.C = C
        self.S = S
        self.B = B
        self.mask_size = 5

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self , idx):
        img_size = 800
        label_path = os.path.join(self.label_dir , self.df.iloc[idx , 1])
        boxes = []
        binary_mask = []
        label_mask = []

        img_path = os.path.join(self.img_dir , self.df.iloc[idx , 0])
        image = np.asarray(plt.imread(img_path))
        # (H, W, C) -> (C, H, W); assumes a 3-D image array — TODO confirm for grayscale files.
        image = torch.from_numpy(image).permute(2 , 0 , 1)
        # Turns an image crop into a single-channel 5x5 tensor.
        transform_mask = transforms.Compose([
                                             transforms.ToPILImage() , 
                                             transforms.Resize((5 , 5)) , 
                                             transforms.Grayscale() , 
                                             transforms.ToTensor()
        ])

        with open(label_path) as f:
            for label in f.readlines():
                # Each line: class x y width height.
                # NOTE(review): a token like "1.0" takes the int(x) branch on the raw
                # string and raises ValueError; int(float(x)) would be safe.
                class_label , x , y , width , height = [
                    float(x) if float(x) != int(float(x)) else int(x)
                    for x in label.replace("\n", "").split()
                ]
                boxes.append([ x , y , width , height , class_label])
                # The string literal below is disabled code; it is evaluated and discarded.
                '''
                i_ , j_ = int(img_size * y) , int(img_size * x)
                x = img_size * x - j_
                y = img_size * y - i_
                height = img_size * height 
                width = img_size * width 
                img_ = F.crop(image , int(x) , int(y) , int(width) , int(height))
                img_ = transform_mask(img_)
                binary_mask.append(img_)
                label_mask.append(class_label)
                '''
        #label_mask = torch.tensor(label_mask)
        boxes = torch.tensor(boxes) 
        #binary_mask = torch.stack(binary_mask)

        if self.transforms:
            image = self.transforms(image)

        targets = torch.zeros((self.B , self.S , self.S , 5))
        target_mask = torch.zeros((self.B , self.mask_size , self.mask_size , 2))
        for box in boxes:
            # Rank anchors by shape similarity to this box (width/height only).
            iou_anchors = iou_width_height(box[2:4] , self.anchors)
            anchors_indices = iou_anchors.argsort(descending=True, dim=0)        
            x , y , width , height , class_label = box
            # NOTE(review): indexed per anchor, and each anchor index is visited once per
            # box (when len(anchors) == B), so has_anchor[...] is always False when checked.
            # As a result every free anchor receives the box and the elif below is unreachable.
            has_anchor = [False for _ in range(self.B)]
            for anchor_idx in anchors_indices:
                anchor_on_scale = anchor_idx % self.B
                S = self.S
                # NOTE(review): x or y equal to 1.0 gives index S (out of range).
                i , j = int(S * y) , int(S * x)
                anchor_taken = targets[anchor_on_scale , i , j , 0]
                if not anchor_taken and not has_anchor[anchor_on_scale]:
                    # NOTE(review): this flag is overwritten by x_cell two statements
                    # later (the :4 slice includes index 0), so it cannot mark occupancy.
                    targets[anchor_on_scale , i , j , 0] = 1
                    x_cell , y_cell = S * x - j , S * y - i
                    width_cell , height_cell = (
                        width * S , 
                        height * S
                    )
                    box_coordinate = torch.tensor([x_cell , y_cell , width_cell , height_cell])
                    targets[anchor_on_scale , i , j , :4] = box_coordinate
                    targets[anchor_on_scale , i , j , 4] = int(class_label)
                    # NOTE(review): F.crop expects pixel (top, left, height, width); these are
                    # cell-relative offsets and cell-unit sizes in (x, y) order — TODO confirm intent.
                    target_mask_ = F.crop(image , int(x_cell) , int(y_cell) , int(width_cell) , int(height_cell))
                    target_mask_ = transform_mask(target_mask_)
                    #print(target_mask_.shape)
                    # (1, 5, 5) -> (5, 5, 1), broadcast into both channels of this anchor's patch.
                    target_mask[anchor_on_scale] = target_mask_.permute(1 , 2 , 0)
                    #target_mask[anchor_on_scale , i , j , 1] = int(class_label)
                    has_anchor[anchor_on_scale] = True
                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    targets[anchor_on_scale , i , j , 0] = -1
        return image , targets , target_mask

In [16]:
class Dataset_(torch.utils.data.Dataset):
    """YOLO-style detection dataset that also builds coarse per-object mask targets.

    Each CSV row names an image file (column 0, under ``img_dir``) and a label file
    (column 1, under ``label_dir``). Label files hold one object per line as
    ``class x y width height``. x, y, width and height are used below as fractions
    of the image size (box centre and size).

    Args:
        anchors: list of (width, height) priors, compared with box sizes via
            iou_width_height.
        transforms: optional callable applied to the (C, H, W) image tensor.
        S, B, C: grid side, number of anchors, number of classes.

    __getitem__ returns ``(image, targets, target_mask)``:
        targets:     (B, S, S, 5); last axis = (x_cell, y_cell, width_cell,
                     height_cell, class), the layout train() slices as [:4] / [4].
        target_mask: (B, S, S, 2); channel 0 holds a grayscale mask_size x mask_size
                     crop of the object, pasted starting at its cell; channel 1 holds
                     the class id over that patch.
    """
    def __init__(self ,
                 img_dir , 
                 label_dir , 
                 csv_file , 
                 anchors , 
                 transforms = None , 
                 S = 50 , 
                 B = 5 , 
                 C = 20):
        super(Dataset_ , self).__init__()

        self.img_dir = img_dir
        self.label_dir = label_dir
        self.df = pd.read_csv(csv_file)
        self.anchors = torch.from_numpy(np.array(anchors))
        self.transforms = transforms
        self.number_of_anchors_per_cell = 5
        # NOTE(review): kept for compatibility; targets has no objectness channel in
        # which an "ignore" marker could be stored, so this is not used.
        self.ignore_iou_thresh = 0.5
        self.C = C
        self.S = S
        self.B = B
        self.mask_size = 5

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _parse_number(token):
        """Parse a label token as int when it is integral, otherwise as float."""
        # float() first so tokens such as "1.0" parse; the old int(token) raised ValueError.
        value = float(token)
        return int(value) if value == int(value) else value

    def _read_boxes(self , label_path):
        """Read ``[x, y, width, height, class_label]`` rows from a label file."""
        boxes = []
        with open(label_path) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                class_label , x , y , width , height = (self._parse_number(t) for t in fields)
                boxes.append([x , y , width , height , class_label])
        return boxes

    def _crop_mask(self , image , x , y , width , height , mask_transform):
        """Crop the object's box (centre/size as image fractions) and shrink it to a mask."""
        img_h , img_w = image.shape[-2] , image.shape[-1]
        # F.crop takes pixel (top, left, height, width); a size of at least 1 keeps the
        # crop non-empty for tiny boxes.
        top = int((y - height / 2) * img_h)
        left = int((x - width / 2) * img_w)
        crop_h = max(1 , int(height * img_h))
        crop_w = max(1 , int(width * img_w))
        crop = F.crop(image , top , left , crop_h , crop_w)
        return mask_transform(crop)  # (1, mask_size, mask_size)

    def __getitem__(self , idx):
        label_path = os.path.join(self.label_dir , self.df.iloc[idx , 1])
        img_path = os.path.join(self.img_dir , self.df.iloc[idx , 0])

        image = np.asarray(plt.imread(img_path))
        # (H, W, C) -> (C, H, W); assumes a 3-D image array — TODO confirm for grayscale files.
        image = torch.from_numpy(image).permute(2 , 0 , 1)
        mask_transform = transforms.Compose([
            transforms.ToPILImage() ,
            transforms.Resize((self.mask_size , self.mask_size)) ,
            transforms.Grayscale() ,
            transforms.ToTensor()
        ])

        # float32 even when every value in the file happens to be integral.
        boxes = torch.tensor(self._read_boxes(label_path) , dtype=torch.float32)

        if self.transforms:
            image = self.transforms(image)

        S = self.S
        targets = torch.zeros((self.B , S , S , 5))
        target_mask = torch.zeros((self.B , S , S , 2))
        # Occupancy is tracked separately. targets[..., 0] stores x_cell, so it cannot
        # double as an "anchor taken" flag (the original wrote 1 there and then
        # immediately overwrote it with the box coordinates).
        taken = torch.zeros((self.B , S , S) , dtype=torch.bool)

        for box in boxes:
            x , y , width , height , class_label = box
            # Clamp so that centres lying exactly on the right/bottom edge stay on the grid.
            i = min(int(S * y) , S - 1)
            j = min(int(S * x) , S - 1)
            iou_anchors = iou_width_height(box[2:4] , self.anchors)
            # Single scale: each box goes to its best-matching free anchor only. The old
            # per-anchor has_anchor list made that check vacuous, so every free anchor
            # received the box.
            for anchor_idx in iou_anchors.argsort(descending=True , dim=0):
                a = int(anchor_idx) % self.B
                if taken[a , i , j]:
                    continue
                taken[a , i , j] = True
                targets[a , i , j , :4] = torch.stack([S * x - j , S * y - i , width * S , height * S])
                targets[a , i , j , 4] = int(class_label)

                mask = self._crop_mask(image , x , y , width , height , mask_transform)
                # Clip the patch at the grid border instead of crashing on a shape mismatch.
                rows = min(self.mask_size , S - i)
                cols = min(self.mask_size , S - j)
                target_mask[a , i:i + rows , j:j + cols , 0] = mask[0 , :rows , :cols]
                target_mask[a , i:i + rows , j:j + cols , 1] = int(class_label)
                break
        return image , targets , target_mask

In [17]:
# Dataset root. Defaults to the Colab Drive location; set YOLO_DATASET_DIR to run elsewhere.
DATA_DIR = os.environ.get('YOLO_DATASET_DIR', '/content/drive/MyDrive/Yolo_Dataset')

# (width, height) anchor priors, matched against label box sizes via iou_width_height.
anchors = [[0.28, 0.22], [0.38, 0.48], [0.9, 0.78], [0.07, 0.15], [0.15, 0.11]]
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((800, 800)),
    transforms.ToTensor()
])
dataset = Dataset_(
    img_dir = os.path.join(DATA_DIR, 'images/'),
    label_dir = os.path.join(DATA_DIR, 'labels'),
    csv_file = os.path.join(DATA_DIR, 'train.csv'),
    anchors = anchors,
    transforms = transform
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size = 1, shuffle=True)

In [None]:
# Preview the first (shuffled) batch: draw the image and report target shapes.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    x, y, z = first_batch
    show_tensor_images(x)
    print(y.shape)
    print(z.shape)

In [19]:
class ROI(nn.Module):
    """Keep the `output_len` proposals per image that best overlap the targets,
    and return their ROI-aligned crops.

    Args:
        output_len: number of proposals kept per image (k).
    """
    def __init__(self , 
                 output_len = 10):
        super(ROI , self).__init__()

        self.crop_size = 5
        # NOTE(review): not used by forward(); kept so existing state_dict keys still match.
        self.classifier = Classifier(512)
        self.output_len = output_len

    def forward(self , img , bbox_input , target_boxes):
        '''
        Args:
            img: feature map, (N, C, H, W).
            bbox_input: proposals, e.g. [N x 5 x 50 x 50 x 5]. The first four values
                of the last axis are read as a midpoint box (x, y, w, h), the same
                convention the IoU ranking below uses.
            target_boxes: ground truth in the same layout as `bbox_input`.

        Returns:
            (N * k, C, crop_size, crop_size) crops, batch-major and in descending-IoU
            order within each image, where k = min(output_len, proposals per image).
        '''
        batch_size = bbox_input.shape[0]
        proposals = bbox_input.reshape(batch_size , -1 , bbox_input.shape[-1])     # (N, P, D)
        targets = target_boxes.reshape(batch_size , -1 , target_boxes.shape[-1])   # (N, P, D)

        # Rank every proposal by IoU with its target at the same anchor/cell.
        ious = intersection_over_union(proposals , targets).squeeze(-1)           # (N, P)
        k = min(self.output_len , proposals.shape[1])
        best = ious.argsort(dim=1 , descending=True)[: , :k]                       # (N, k)

        # Only the selected boxes are cropped, in one vectorised call. The original
        # cropped every proposal in a Python loop and then discarded all but k.
        chosen = torch.gather(proposals[... , :4] , 1 , best.unsqueeze(-1).expand(-1 , -1 , 4))
        centers = chosen[... , 0:2]
        half_sizes = chosen[... , 2:4] / 2
        corners = torch.cat([centers - half_sizes , centers + half_sizes] , dim=-1)  # (N, k, 4)

        # A per-image list of (k, 4) boxes tells roi_align which image each box belongs
        # to. Passing the raw 5-column tensor (as before) made roi_align read a predicted
        # coordinate as the batch index.
        return roi_align(img , list(corners) , output_size=(self.crop_size , self.crop_size))

In [None]:
# Shape check: ROI selection + cropping on random features and boxes (batch of 2).
roi = ROI().to(device)
img = torch.randn(2, 512, 50, 50).to(device)
bbox = torch.randn(2, 5, 50, 50, 5).to(device)
target_boxes = torch.randn(2, 5, 50, 50, 5).to(device)
roi_imgs = roi(img, bbox, target_boxes)
roi_imgs.shape

In [21]:
class FCN(nn.Module):
    """Mask branch: two pooled convs down, then six transposed convs back up.

    With the defaults, a 512x5x5 crop becomes 512x50x50: the pools give
    5 -> 2 -> 1, the doublings give 2, 4, 8, 16, 32, and the padding-7 final stage
    lands on 50.
    """
    def __init__(self,
                 in_channels=512,
                 out_channels=512,
                 hidden_dim=32):
        super().__init__()

        self.conv1 = Conv(in_channels, hidden_dim, use_pool=True)
        self.conv2 = Conv(hidden_dim, hidden_dim * 2, use_pool=True)
        self.convT1 = ConvT(hidden_dim * 2, hidden_dim)
        # convT2 .. convT5 are identical hidden_dim -> hidden_dim upsamplers.
        for idx in range(2, 6):
            setattr(self, f'convT{idx}', ConvT(hidden_dim, hidden_dim))
        self.convT6 = ConvT(hidden_dim, out_channels, padding=7)

    def forward(self, x):
        stages = [self.conv1, self.conv2] + [getattr(self, f'convT{idx}') for idx in range(1, 7)]
        for stage in stages:
            x = stage(x)
        return x

In [None]:
def test():
    """Print the FCN output shape for a random batch of two 512x5x5 crops."""
    sample = torch.randn(2, 512, 5, 5).to(device)
    model = FCN().to(device)
    print(model(sample).shape)


test()

In [23]:
class Classifier_(nn.Module):
    """Mask-class head: five 2x2/stride-2 convs, two Linears, then a sigmoid.

    Designed for 50x50 inputs (50 -> 25 -> 12 -> 6 -> 3 -> 1), so that the conv
    output squeezes to (N, 512) before the Linears. Returns (N, 5, 50, 50, 1).
    """
    def __init__(self,
                 in_channels=512,
                 out_channels=12500):
        super().__init__()

        widths = [in_channels, 32, 64, 128, 256, 512]
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f'conv{idx}',
                    Conv(c_in, c_out, kernel_size=(2, 2), stride=(2, 2), padding=0))
        self.linear1 = Linear(512, 32, use_norm=False)
        self.linear2 = Linear(32, out_channels, use_activation=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = conv(x)
        # Drop the two trailing 1x1 spatial axes.
        x = self.linear1(x.squeeze(2).squeeze(2))
        x = self.sigmoid(self.linear2(x))
        return x.view(x.shape[0], 5, 50, 50, 1)

In [None]:
# Shape check: Classifier_ consumes FCN output, i.e. 512-channel 50x50 maps.
# (The previous 1-channel input could not pass conv1, which expects 512 channels.)
x = torch.randn(2, 512, 50, 50).to(device)
cls = Classifier_().to(device)
z = cls(x)
z.shape

In [25]:
class Mask_RCNN(nn.Module):
    """VGG backbone -> RPN proposals -> ROI crops -> box/class, mask and mask-class heads.

    Args:
        in_channels: image channels passed to the VGG backbone (was previously ignored).

    forward(x, target_boxes) returns:
        cls, bbox: Classifier outputs, one row per ROI crop.
        mask: grayscale FCN output reshaped to (-1, 5, 50, 50, 1).
        predictions: Classifier_ output, (num_crops, 5, 50, 50, 1).
    """
    def __init__(self , 
                 in_channels = 3):
        super(Mask_RCNN , self).__init__()

        in_channels_vgg = 512
        self.vgg = VGG(in_channels = in_channels)
        self.rpn = RPN(in_channels_vgg)
        self.classifier = Classifier(512)
        self.roi = ROI()
        self.fcn = FCN()
        self.to_grayscale = Conv(512 , 1)
        self.classifier_ = Classifier_(in_channels=512)

    def forward(self , x , target_boxes):
        features = self.vgg(x)
        _rpn_cls , proposals = self.rpn(features)  # RPN class scores are not used downstream
        croped_imgs = self.roi(features , proposals , target_boxes)
        mask_src = self.fcn(croped_imgs)
        predictions = self.classifier_(mask_src)
        mask = self.to_grayscale(mask_src)
        cls , bbox = self.classifier(croped_imgs)
        # -1 instead of a hard-coded 2. Identical for a single image, and no longer
        # tied to one specific batch size.
        return cls , bbox , mask.view(-1 , 5 , 50 , 50 , 1) , predictions

In [None]:
# End-to-end shape check on one random 800x800 image with random targets.
x = torch.randn(1, 3, 800, 800).to(device)
target_boxes = torch.randn(1, 5, 50, 50, 5).to(device)
mask_rcnn = Mask_RCNN().to(device)
cls, bbox, mask_src, predictions = mask_rcnn(x, target_boxes)
(cls.shape, bbox.shape, mask_src.shape, predictions.shape)

In [27]:
# Loss functions. adv_criterion is not referenced anywhere else in this notebook.
adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
ce_criterion = nn.CrossEntropyLoss()
lambda_recon = 200  # not referenced elsewhere in this notebook

# Optimisation settings used by the optimizer cell and train().
betas = (0.5, 0.999)
n_epochs, display_steps, lr = 200, 1, 0.002

In [28]:
# Fresh model plus its Adam optimiser (both used by train()).
mask_rcnn = Mask_RCNN().to(device)
opt = torch.optim.Adam(params=mask_rcnn.parameters(), lr=lr, betas=betas)

In [29]:
def train():
    """Train the global `mask_rcnn` with `opt` over `dataloader` for `n_epochs` epochs.

    The loss is the mean of four terms: L1 on boxes, cross-entropy on class logits,
    L1 on the mask, and L1 on the mask class. Every `display_steps` steps it prints
    the mean loss of the steps since the last print.
    """
    mean_rcnn_loss = 0
    cur_step = 0
    for epoch in range(n_epochs):
        for img , label , mask_label in dataloader:
            img , label , mask_label = img.to(device) , label.to(device) , mask_label.to(device)

            opt.zero_grad()
            cls_ , bbox_ , mask , mask_pred = mask_rcnn(img , label)

            bbox_loss = recon_criterion(bbox_ , label[... , :4])
            # The class loss must use the logits. torch.argmax (used before) is not
            # differentiable, so that term never trained anything. CrossEntropyLoss
            # wants (batch, C, d1, ...) logits. The integer target is broadcast over
            # the extra ROI rows the model emits per image.
            class_target = label[... , 4].long().expand(cls_.shape[:-1]).contiguous()
            cls_loss = ce_criterion(cls_.permute(0 , 4 , 1 , 2 , 3) , class_target)
            mask_loss = recon_criterion(mask , mask_label[... , :1])
            mask_class_loss = recon_criterion(mask_pred , mask_label[... , 1:2])

            loss = (bbox_loss + cls_loss + mask_loss + mask_class_loss) / 4
            loss.backward()
            opt.step()

            mean_rcnn_loss += loss.item() / display_steps
            if cur_step % display_steps == 0:
                print(f'Epoch {epoch} , Step {cur_step} , Mean Faster RCNN Loss {mean_rcnn_loss}')
                # Reset after printing so the value is a windowed mean, not a running sum.
                mean_rcnn_loss = 0
            cur_step += 1

In [None]:
# Run the training loop: n_epochs passes over dataloader, logging every display_steps steps.
train()