In [1]:
####

In [2]:
import torch
from torch import nn
import torchvision.transforms.functional as F
from torchvision.ops import nms , batched_nms

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
from torchvision import transforms
from torchvision.utils import make_grid
import torchvision.transforms.functional as F
from tqdm.notebook import tqdm
from torchvision.ops import roi_align

In [4]:
def iou_width_height(boxes1, boxes2):
    """Intersection-over-union computed from box sizes only.

    Both inputs hold ``(width, height)`` in their last dimension. The overlap
    is taken as ``min(widths) * min(heights)``, i.e. the boxes are treated as
    if they were aligned at a common corner.

    Args:
        boxes1: Tensor whose last dimension is ``(width, height)``.
        boxes2: Tensor whose last dimension is ``(width, height)``;
            broadcast against ``boxes1``.

    Returns:
        Tensor of IoU values with the broadcast leading shape.
    """
    w1, h1 = boxes1[..., 0], boxes1[..., 1]
    w2, h2 = boxes2[..., 0], boxes2[..., 1]

    overlap = torch.min(w1, w2) * torch.min(h1, h2)
    combined = w1 * h1 + w2 * h2 - overlap
    return overlap / combined

In [5]:
def show_tensor_images(image_tensor, num_images=2, size=(3 , 550 , 550)):
    """Display the first ``num_images`` images of a batch as a grid.

    Args:
        image_tensor: Tensor that can be viewed as ``(-1, *size)``.
        num_images: Number of leading images to include in the grid.
        size: ``(channels, height, width)`` of a single image.
    """
    # Detach and move to CPU so plotting never touches the autograd graph / GPU.
    images = image_tensor.detach().cpu().view(-1, *size)
    grid = make_grid(images[:num_images], nrow=5)
    # make_grid returns channels-first; imshow wants channels-last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [6]:
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Element-wise IoU between predicted and target boxes.

    Args:
        boxes_preds: Tensor ``(..., 4)`` of predicted boxes.
        boxes_labels: Tensor ``(..., 4)`` of target boxes (broadcast against
            ``boxes_preds``).
        box_format: ``"midpoint"`` for ``(x_center, y_center, width, height)``
            or ``"corners"`` for ``(x1, y1, x2, y2)``.

    Returns:
        Tensor ``(..., 1)`` with the IoU of each box pair.

    Raises:
        ValueError: If ``box_format`` is neither ``"midpoint"`` nor
            ``"corners"`` (previously this surfaced as an UnboundLocalError).
    """
    if box_format not in ("midpoint", "corners"):
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    def _to_corners(boxes):
        # Slices (0:1 etc.) keep the trailing dimension so the result is (..., 1).
        if box_format == "midpoint":
            half_w = boxes[..., 2:3] / 2
            half_h = boxes[..., 3:4] / 2
            return (
                boxes[..., 0:1] - half_w,
                boxes[..., 1:2] - half_h,
                boxes[..., 0:1] + half_w,
                boxes[..., 1:2] + half_h,
            )
        return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]

    box1_x1, box1_y1, box1_x2, box1_y2 = _to_corners(boxes_preds)
    box2_x1, box2_y1, box2_x2, box2_y2 = _to_corners(boxes_labels)

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) turns negative extents (non-overlapping boxes) into zero area.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # Small epsilon avoids division by zero for degenerate boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)

In [7]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:

class Conv(nn.Module):
    """2-D convolution with optional BatchNorm, ReLU and Dropout stages.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        use_norm: Append ``nn.BatchNorm2d(out_channels)`` when true.
        use_activation: Append a non-inplace ``nn.ReLU`` when true.
        use_dropout: Append ``nn.Dropout()`` (default probability) when true.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=1,
                 use_norm=True,
                 use_activation=True,
                 use_dropout=False):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation
        self.use_dropout = use_dropout

        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding)
        # Optional stages only exist as attributes when they are enabled.
        if use_norm:
            self.norm = nn.BatchNorm2d(out_channels)
        if use_activation:
            self.activation = nn.ReLU(inplace=False)
        if use_dropout:
            self.dropout = nn.Dropout()

    def forward(self, x):
        """Apply conv -> [norm] -> [activation] -> [dropout]."""
        out = self.conv1(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        out = self.dropout(out) if self.use_dropout else out
        return out

In [9]:

class ConvT(nn.Module):
    """Transposed 2-D convolution with optional InstanceNorm and LeakyReLU.

    With the default 2x2 kernel and stride 2 the spatial size is doubled.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Transposed-convolution kernel size.
        stride: Transposed-convolution stride.
        padding: Transposed-convolution padding.
        use_norm: Append ``nn.InstanceNorm2d(out_channels)`` when true.
        use_activation: Append ``nn.LeakyReLU(0.2)`` when true.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(2, 2),
                 stride=(2, 2),
                 padding=0,
                 use_norm=True,
                 use_activation=True):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation

        self.convT = nn.ConvTranspose2d(in_channels,
                                        out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding)
        # Optional stages only exist as attributes when they are enabled.
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply convT -> [norm] -> [activation]."""
        out = self.convT(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        return out

In [10]:
class Resnet_Block(nn.Module):
    """Residual block: three stacked ``Conv`` layers plus a projected skip path.

    Main path: ``conv1`` (in -> in), ``conv2`` (3x3, in -> in),
    ``conv3`` (1x1, in -> out). The skip path is a single ``Conv`` that
    projects the input to ``out_channels`` so both paths can be summed.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        downsample: When true, ``conv1`` and ``conv_skip`` use a 2x2 kernel
            with stride 2 (halving the spatial size); otherwise both are 1x1
            with stride 1.
    """

    def __init__(self ,
                 in_channels ,
                 out_channels ,
                 downsample = False):
        super(Resnet_Block , self).__init__()

        self.downsample = downsample

        # The first main-path conv and the skip conv share kernel/stride so the
        # two paths always end up with the same spatial size.
        if self.downsample:
            entry_kernel , entry_stride = (2 , 2) , (2 , 2)
        else:
            entry_kernel , entry_stride = (1 , 1) , (1 , 1)

        self.conv1 = Conv(in_channels ,
                          in_channels ,
                          kernel_size = entry_kernel ,
                          stride = entry_stride ,
                          padding = 0)

        self.conv_skip = Conv(in_channels ,
                              out_channels ,
                              kernel_size = entry_kernel ,
                              stride = entry_stride ,
                              padding = 0)

        self.conv2 = Conv(in_channels ,
                          in_channels)

        self.conv3 = Conv(in_channels ,
                          out_channels ,
                          kernel_size = (1 , 1) ,
                          stride = (1 , 1) ,
                          padding = 0)

    def forward(self , x):
        """Return ``main_path(x) + conv_skip(x)``."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        skip = self.conv_skip(x)
        # Out-of-place add: `out` is the output of a ReLU, which autograd keeps
        # for the backward pass; an in-place `+=` on it makes backward() fail
        # with "modified by an inplace operation".
        return out + skip

In [11]:

class Linear(nn.Module):
    """Fully connected layer followed by a softmax over dimension 1.

    Args:
        in_channels: Size of each input sample.
        out_channels: Size of each output sample.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear1 = nn.Linear(in_channels, out_channels)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``softmax(linear1(x), dim=1)``."""
        return self.softmax(self.linear1(x))

In [12]:
class Resnet(nn.Module):
    """ResNet-style backbone that returns three intermediate feature maps.

    Stages ``conv2``..``conv5`` hold 3, 4, 6 and 3 ``Resnet_Block``s with
    64->256, 256->512, 512->1024 and 1024->2048 channels. ``forward`` returns
    the outputs of ``conv3``, ``conv4`` and ``conv5``.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Output size of the ``linear`` head. The head is
            constructed here but is not applied in ``forward``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = Conv(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=3)

        # Only the first stage keeps the resolution; later stages downsample
        # in their first block (the helper's default).
        self.conv2 = self._make_repeated_blocks(64, 256, 3, downsample=False)
        self.conv3 = self._make_repeated_blocks(256, 512, 4)
        self.conv4 = self._make_repeated_blocks(512, 1024, 6)
        self.conv5 = self._make_repeated_blocks(1024, 2048, 3)
        self.linear = Linear(2048, out_channels)

    def _make_repeated_blocks(self, in_channels, out_channels, repeats, downsample=True):
        """Build ``repeats`` blocks; only the first changes channels/resolution."""
        blocks = []
        for position in range(repeats):
            if position > 0:
                blocks.append(Resnet_Block(out_channels, out_channels))
            elif downsample == True:
                blocks.append(Resnet_Block(in_channels, out_channels, downsample=downsample))
            else:
                blocks.append(Resnet_Block(in_channels, out_channels))
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``[conv3_out, conv4_out, conv5_out]`` for input ``x``."""
        stem = torch.max_pool2d(self.conv1(x), kernel_size=(2, 2), stride=(2, 2))
        stage3 = self.conv3(self.conv2(stem))
        stage4 = self.conv4(stage3)
        stage5 = self.conv5(stage4)
        return [stage3, stage4, stage5]

In [None]:
# Disabled smoke test, kept as a string literal so it does not execute:
# builds Resnet(3, 1000), feeds a random (2, 3, 550, 550) batch and prints the
# shape of every returned feature map.
'''def test():
    resnet = Resnet(3 , 1000).to(device)
    x = torch.randn(2 , 3 , 550 , 550).to(device)
    z = resnet(x)
    for z_ in z:
        print(z_.shape)
test()'''

In [14]:
class Concat_Block(nn.Module):
    """Parameter-free marker layer whose ``forward`` is the identity.

    ``FPN.forward`` checks ``isinstance(layer, Concat_Block)`` to decide where
    a backbone feature map is concatenated; the block itself does no work.
    """

    def __init__(self):
        super(Concat_Block , self).__init__()

    def forward(self , x):
        """Return ``x`` unchanged."""
        return x

In [15]:
# Output channels of the successive 3x3 Conv layers in FPN.top_layers.
config_top = [
              256 , 
              256 , 
              256
]

# Layout of FPN.down_layers: a string entry ('C') becomes a Concat_Block,
# an integer entry becomes a ConvT with that many output channels.
config_down = [
               'C' , 
               1024 , 
               'C' , 
               256 , 
]

In [16]:
class FPN(nn.Module):
    """Two-branch neck on top of the backbone feature maps.

    * ``top_layers``: a chain of 3x3 ``Conv`` layers applied to the last
      (deepest) feature map; resolution is unchanged.
    * ``down_layers``: alternating ``Concat_Block`` / ``ConvT`` layers. At
      every ``Concat_Block`` a backbone feature map (deepest first) is resized
      to the current resolution and concatenated on the channel axis; every
      ``ConvT`` then upsamples.

    Args:
        config_top: Output channels of each layer in ``top_layers``.
        config_down: Layer spec for ``down_layers``; strings become
            ``Concat_Block``, integers become ``ConvT`` output channels.
        in_channels_: Channel count of the deepest backbone feature map.
    """

    def __init__(self ,
                 config_top=config_top ,
                 config_down=config_down ,
                 in_channels_ = 2048):
        super(FPN , self).__init__()

        self.top_layers = nn.ModuleList()
        self.down_layers = nn.ModuleList()

        in_channels = in_channels_
        for out_channels in config_top:
            self.top_layers.append(
                Conv(in_channels , out_channels , kernel_size=(3 , 3) , stride=(1 , 1) , padding = 1)
            )
            in_channels = out_channels

        in_channels = in_channels_
        for i , layer in enumerate(config_down):
            if isinstance(layer , str):
                self.down_layers.append(Concat_Block())
                continue

            out_channels = layer
            # A ConvT that is not the very first entry is sized for twice the
            # running channel count, i.e. it assumes the concatenated backbone
            # map has as many channels as the current tensor.
            conv_in = in_channels if i == 0 else in_channels * 2
            self.down_layers.append(ConvT(conv_in , out_channels))
            in_channels = out_channels

    def forward(self , x_list):
        """Run both branches.

        Args:
            x_list: Backbone feature maps ordered shallow -> deep; the last
                entry feeds both branches.

        Returns:
            Tuple ``(x_top, x_down)``.
        """
        x_top , x_down = x_list[-1] , x_list[-1]
        i = len(x_list)-1
        for layer in self.top_layers:
            x_top = layer(x_top)

        for layer in self.down_layers:
            x_down = layer(x_down)
            if isinstance(layer , Concat_Block):
                # Resize to the full (H, W) of the current tensor. Using the
                # width for both sides only works for square inputs and makes
                # torch.cat fail otherwise.
                target_hw = tuple(x_down.shape[-2:])
                x_ = F.resize(x_list[i] , target_hw)
                x_down = torch.cat([x_down , x_] , dim=1)
                i -= 1
        return x_top , x_down

In [None]:
# Disabled smoke test, kept as a string literal so it does not execute:
# runs a random (2, 3, 550, 550) batch through Resnet and FPN and prints the
# shapes of the two FPN outputs.
'''def test():
    resnet = Resnet(3 , 1000).to(device)
    fpn = FPN().to(device)
    x = torch.randn(2 , 3 , 550 , 550).to(device)
    z = resnet(x)
    x_top , x_down = fpn(z)
    print(x_top.shape , x_down.shape)
test()'''

In [18]:
class Protnet(nn.Module):
    """Prototype branch: conv stack -> 2x upsample -> conv to ``a * k_`` maps.

    Args:
        in_channels: Channels of the incoming feature map.
        k_: Number of mask coefficients per anchor.
        a: Number of anchors per cell.

    The output has shape ``(N, H, W, a, k_)`` where ``H``/``W`` are twice the
    input resolution (one stride-2 ``ConvT``).
    """

    def __init__(self ,
                 in_channels = 512 ,
                 k_ = 4 ,
                 a = 5):
        super(Protnet , self).__init__()

        k = a * k_
        self.layers = nn.ModuleList()
        config_mask = [
                    # [out_channels , repeats]
                    [256 , 3] ,
                    256 ,
                    [k , 1]
        ]

        for layer in config_mask:
            if isinstance(layer , list):
                out_channels , repeats = layer
                # Honour the configured repeat count (it used to be unpacked
                # and then ignored, so only one Conv was ever created).
                for _ in range(repeats):
                    self.layers.append(Conv(in_channels , out_channels))
                    in_channels = out_channels

            elif isinstance(layer , int):
                out_channels = layer
                self.layers.append(
                    ConvT(in_channels , out_channels)
                )
                # Keep the running channel count in sync for the next layer.
                in_channels = out_channels
        self.k = k_
        self.a = a

    def forward(self , x):
        """Return prototypes shaped ``(N, H, W, a, k_)``."""
        for layer in self.layers:
            x = layer(x)
        n , _ , h , w = x.shape
        # Move channels last BEFORE splitting them into (a, k). A plain
        # .view() on the NCHW tensor would reinterpret memory and mix channel
        # values with spatial positions.
        return x.permute(0 , 2 , 3 , 1).reshape(n , h , w , self.a , self.k)

In [None]:
# Disabled smoke test, kept as a string literal so it does not execute:
# feeds a random (2, 512, 68, 68) tensor through a default Protnet and shows
# the output shape.
'''x = torch.randn(2 , 512 , 68 , 68).to(device)
protnet = Protnet().to(device)
z = protnet(x)
z.shape'''

In [20]:
class Head(nn.Module):
    """Prediction head producing class, box and mask-coefficient maps.

    Two stride-2 ``ConvT`` layers upsample the input 4x, then one ``Conv``
    emits ``c*a + 4*a + k*a`` channels that are split into the three outputs.
    The mask-coefficient map is additionally upsampled 2x and passed through
    ``tanh``.

    Args:
        in_channels: Channels of the incoming feature map.
        k: Number of mask coefficients per anchor.
        a: Number of anchors per cell.
        c: Number of classes.
    """

    def __init__(self ,
                 in_channels = 512 ,
                 k=4 ,
                 a = 5 ,
                 c = 20):
        super(Head , self).__init__()
        hidden_dim = 256
        self.conv1 = ConvT(in_channels , hidden_dim)
        self.conv2 = ConvT(hidden_dim , hidden_dim)

        self.conv3 = Conv(hidden_dim , c * a + 4 * a + k * a)

        self.c = c
        self.a = a
        self.k = k

        self.upsample = nn.Upsample(scale_factor=2)
        self.tanh = nn.Tanh()

    def _anchors_last(self , t , per_anchor):
        """Reshape ``(N, a*per_anchor, H, W)`` to ``(N, H, W, a, per_anchor)``."""
        n , _ , h , w = t.shape
        # permute first: a bare .view() on NCHW memory would scramble channel
        # values across spatial positions.
        return t.permute(0 , 2 , 3 , 1).reshape(n , h , w , self.a , per_anchor)

    def forward(self , x):
        """Return ``(classes, bbox, mask)``.

        Shapes: ``(N, H, W, a, c)``, ``(N, H, W, a, 4)`` and
        ``(N, 2H, 2W, a, k)`` where ``H``/``W`` are 4x the input resolution.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        class_channels = self.c * self.a
        bbox_channels = 4 * self.a
        mask_channels = self.k * self.a
        # Split by explicit sizes. The previous hand-rolled offsets advanced by
        # mask_channels instead of bbox_channels, which only produced the right
        # mask slice when k == 4.
        x_classes , x_bbox , x_mask = torch.split(
            x , [class_channels , bbox_channels , mask_channels] , dim=1
        )
        x_mask = self.tanh(self.upsample(x_mask))

        return (
            self._anchors_last(x_classes , self.c) ,
            self._anchors_last(x_bbox , 4) ,
            self._anchors_last(x_mask , self.k) ,
        )

In [None]:
# Disabled smoke test, kept as a string literal so it does not execute:
# feeds a random (2, 512, 17, 17) tensor through a default Head and prints
# the shapes of the three outputs.
'''x = torch.randn(2 , 512 , 17 , 17).to(device)
head = Head().to(device)
z = head(x)
classes , bbox , mask = z
print(classes.shape , bbox.shape , mask.shape)'''

In [22]:
class NMS(nn.Module):
    """Per-image non-maximum suppression keeping the first ``k * a`` survivors.

    Args:
        k: Together with ``a`` defines how many detections are kept
            (``k * a``) and how the result is reshaped.
        a: See ``k``.
        iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.

    NOTE(review): the arg-max class index is passed to ``nms`` as the score,
    and ``torchvision.ops.nms`` expects boxes as ``(x1, y1, x2, y2)``; confirm
    both match what the caller intends.
    """

    def __init__(self ,
                 k = 4 ,
                 a = 5 ,
                 iou_threshold = 0.5):
        super(NMS , self).__init__()

        self.k = k
        self.iou_threshold = iou_threshold
        self.a = a

    def _get_top_k(self , flatten_tensor , sorted_tensors):
        """Gather the rows of ``flatten_tensor`` at the first ``k * a`` indices.

        Args:
            flatten_tensor: Tensor ``(M, D)``.
            sorted_tensors: 1-D integer index tensor (as returned by ``nms``).

        Returns:
            Tensor ``(min(k * a, len(sorted_tensors)), D)``.
        """
        k_ = self.k * self.a
        # Index with the tensor directly instead of a Python loop + stack.
        return flatten_tensor[sorted_tensors[:k_]]

    def forward(self , x , y):
        '''
        x => classes => [N , 68 , 68 , 5 , 20]
        y => bbox    => [N , 68 , 68 , 5 , 4]

        Returns (bbox, classes) shaped [N , a , k , 4] and [N , a , k , 1].
        The final view requires at least k * a boxes to survive NMS for
        every image.
        '''
        final_classes = []
        final_bbox = []
        x = torch.argmax(x , dim=-1) # [N , 68 , 68 , 5]
        x_flatten = torch.flatten(x.float() , start_dim=1 , end_dim=-1) # [N , x ]
        y_flatten = torch.flatten(y , start_dim=1 , end_dim = -2) # [N , x , 4]

        for scores , boxes in zip(x_flatten , y_flatten):
            # nms already returns an index tensor; wrapping it in
            # torch.tensor() only copied it and raised a UserWarning.
            keep = nms(boxes , scores , self.iou_threshold)
            final_classes.append(self._get_top_k(scores.unsqueeze(-1) , keep))
            final_bbox.append(self._get_top_k(boxes , keep))

        final_bbox = torch.stack(final_bbox)
        final_classes = torch.stack(final_classes)
        return final_bbox.view(final_bbox.shape[0] , self.a , self.k , 4) , final_classes.view(final_classes.shape[0] , self.a , self.k , 1)

In [None]:
# Disabled smoke test, kept as a string literal so it does not execute:
# runs random class scores (2, 68, 68, 5, 20) and boxes (2, 68, 68, 5, 4)
# through a default NMS module and shows the output shapes.
'''x = torch.randn(2 , 68 , 68 , 5 , 20)
y = torch.randn(2 , 68 , 68 , 5 , 4)
nms_ = NMS()
final_bbox , final_classes = nms_(x , y)
final_bbox.shape , final_classes.shape'''

In [24]:
class Yolact(nn.Module):
    """Full model: Resnet backbone -> FPN -> (Head, Protnet).

    Args:
        k: Mask coefficients per anchor (forwarded to Protnet, Head and NMS).
        a: Anchors per cell (forwarded to Protnet, Head and NMS).
        c: Number of classes (forwarded to Head).

    The ``nms`` sub-module is constructed but not applied in ``forward``.
    """

    def __init__(self, k=4, a=5, c=1):
        super().__init__()

        self.backbone = Resnet(3, 1000)
        self.fpn = FPN()
        self.protnet = Protnet(in_channels=256, k_=k, a=a)
        self.head = Head(in_channels=256, k=k, a=a, c=c)
        self.nms = NMS(k=k, a=a)

    def forward(self, x):
        """Return ``(class_ids, bbox, masks)`` for an image batch ``x``.

        ``class_ids`` is the arg-max class index per anchor as a float tensor
        with a trailing singleton dimension; being an arg-max it carries no
        gradient. ``masks`` is the element-wise sum of the Protnet output and
        the Head's mask coefficients.
        """
        features = self.backbone(x)
        x_top, x_down = self.fpn(features)

        class_scores, bbox, mask_coeffs = self.head(x_top)
        prototypes = self.protnet(x_down)
        masks = prototypes + mask_coeffs

        class_ids = torch.argmax(class_scores, dim=-1).float().unsqueeze(-1)
        return class_ids, bbox, masks

In [None]:
# Smoke test: one forward pass on random input, printing the output shapes.
yolact = Yolact().to(device)
# Create the input on the same device as the model; a CPU tensor fed to a
# CUDA model raises a device-mismatch error.
x = torch.randn(2 , 3 , 550 , 550 , device=device)
classes , bbox , masks = yolact(x)
print(classes.shape , bbox.shape , masks.shape)

In [26]:
class Dataset_(torch.utils.data.Dataset):
    """Detection dataset that builds per-cell, per-anchor targets.

    The CSV's first column is joined with ``img_dir`` to locate the image and
    its second column is joined with ``label_dir`` to locate the label file.
    Every non-empty label line holds ``class x y width height``.
    (Presumably x / y / width / height are normalised to [0, 1], since they
    are multiplied by ``S`` to find a grid cell - TODO confirm.)

    Args:
        img_dir: Directory containing the images.
        label_dir: Directory containing the label text files.
        csv_file: CSV listing image / label file names.
        anchors: Sequence of ``(width, height)`` anchor sizes.
        transforms: Optional callable applied to the image tensor.
        S: Grid size of the target tensor.
        B: Number of anchors per cell.
        C: Number of classes (stored, not used when building targets).

    ``__getitem__`` returns ``(image, targets, target_mask)`` with shapes
    ``(S, S, B, 5)`` and ``(mask_size, mask_size, B, 1)`` for the targets.
    """

    def __init__(self ,
                 img_dir ,
                 label_dir ,
                 csv_file ,
                 anchors ,
                 transforms = None ,
                 S = 68 ,
                 B = 5 ,
                 C = 20):
        super(Dataset_ , self).__init__()

        self.img_dir = img_dir
        self.label_dir = label_dir
        self.df = pd.read_csv(csv_file)
        self.anchors = torch.from_numpy(np.array(anchors))
        self.transforms = transforms
        self.number_of_anchors_per_cell = 5
        self.ignore_iou_thresh = 0.5
        self.C = C
        self.S = S
        self.B = B
        self.mask_size = 136

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self , idx):
        """Load sample ``idx`` and build its target tensors."""
        label_path = os.path.join(self.label_dir , self.df.iloc[idx , 1])
        img_path = os.path.join(self.img_dir , self.df.iloc[idx , 0])

        image = np.asarray(plt.imread(img_path))
        image = torch.from_numpy(image).permute(2 , 0 , 1)
        # Turns a crop into a 5x5 single-channel tensor.
        transform_mask = transforms.Compose([
                                             transforms.ToPILImage() ,
                                             transforms.Resize((5 , 5)) ,
                                             transforms.Grayscale() ,
                                             transforms.ToTensor()
        ])

        boxes = []
        with open(label_path) as f:
            for label in f:
                fields = label.split()
                if not fields:
                    # Tolerate blank / trailing empty lines.
                    continue
                # Parse everything as float: int("1.0") would raise, and the
                # values end up in a float tensor anyway.
                class_label , x , y , width , height = [float(v) for v in fields]
                boxes.append([ x , y , width , height , class_label])

        boxes = torch.tensor(boxes)

        if self.transforms:
            image = self.transforms(image)

        S = self.S
        # Built anchor-first for convenient indexing; permuted on return.
        targets = torch.zeros((self.B , S , S , 5))
        target_mask = torch.zeros((self.B , self.mask_size , self.mask_size , 1))
        for box in boxes:
            iou_anchors = iou_width_height(box[2:4] , self.anchors)
            anchors_indices = iou_anchors.argsort(descending=True, dim=0)
            x , y , width , height , class_label = box
            has_anchor = [False for _ in range(self.B)]
            # Grid cell of the box centre; clamp so x == 1.0 or y == 1.0 does
            # not index one past the last cell.
            i = min(int(S * y) , S - 1)
            j = min(int(S * x) , S - 1)
            for anchor_idx in anchors_indices:
                anchor_on_scale = anchor_idx % self.B
                anchor_taken = targets[anchor_on_scale , i , j , 0]
                if not anchor_taken and not has_anchor[anchor_on_scale]:
                    # NOTE(review): slot 0 is used as the "taken" flag here but
                    # is overwritten with x_cell three statements below, and
                    # the ignore branch writes -1 into the same slot. Confirm
                    # the intended target layout.
                    targets[anchor_on_scale , i , j , 0] = 1
                    x_cell , y_cell = S * x - j , S * y - i
                    width_cell , height_cell = (
                        width * S ,
                        height * S
                    )
                    box_coordinate = torch.tensor([x_cell , y_cell , width_cell , height_cell])
                    targets[anchor_on_scale , i , j , :4] = box_coordinate
                    targets[anchor_on_scale , i , j , 4] = int(class_label)
                    # NOTE(review): F.crop takes (img, top, left, height, width)
                    # in pixels; these arguments are cell-relative and in
                    # x/y/width/height order - verify what crop is intended.
                    target_mask_ = F.crop(image , int(x_cell) , int(y_cell) , int(width_cell) , int(height_cell))
                    target_mask_ = transform_mask(target_mask_)
                    target_mask[anchor_on_scale , i:i+5 , j:j+5 , 0:1] = target_mask_.permute(1 , 2 , 0)
                    # (A former write of the class label to channel 1:2 was
                    # dropped: target_mask has a single channel, so that
                    # assignment addressed an empty slice and did nothing.)
                    has_anchor[anchor_on_scale] = True
                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    targets[anchor_on_scale , i , j , 0] = -1

        # Move the anchor axis behind the spatial axes with permute; a .view()
        # would keep the memory order and silently mix anchors with cells.
        targets = targets.permute(1 , 2 , 0 , 3).contiguous()
        target_mask = target_mask.permute(1 , 2 , 0 , 3).contiguous()
        return image , targets , target_mask

In [27]:
# Five (width, height) anchor sizes handed to Dataset_.
anchors = [[ 0.28, 0.22], [  0.38, 0.48], [ 0.9, 0.78], [ 0.07, 0.15], [ 0.15, 0.11]]
# Image pipeline: tensor -> PIL image -> resize to 550x550 -> float tensor.
transform = transforms.Compose([
                                transforms.ToPILImage() , 
                                transforms.Resize((550 , 550)) , 
                                transforms.ToTensor()
])
# NOTE(review): absolute Google Drive paths; these only resolve where that
# drive is mounted at /content/drive.
dataset = Dataset_(
    img_dir = '/content/drive/MyDrive/Yolo_Dataset/images/' , 
    label_dir = '/content/drive/MyDrive/Yolo_Dataset/labels' , 
    csv_file = '/content/drive/MyDrive/Yolo_Dataset/train.csv' , 
    anchors = anchors , 
    transforms = transform
)
# One sample per batch, reshuffled every epoch.
dataloader = torch.utils.data.DataLoader(dataset , batch_size = 1 , shuffle=True)

In [None]:
# Preview the first batch only: show the image and print both target shapes.
for x , y , z in dataloader:
    show_tensor_images(x)
    print(y.shape)
    print(z.shape)
    break

In [29]:
# Loss functions. Only recon_criterion is referenced by train() in this
# notebook; the other two are defined here but not used by it.
adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
ce_criterion = nn.CrossEntropyLoss()
# Not referenced by train() in this notebook.
lambda_recon = 200
# Adam (beta1, beta2), passed to the optimizer.
betas = (0.5 , 0.999)


# Training schedule: epoch count, logging interval (in steps), learning rate.
n_epochs = 200
display_steps = 1
lr = 0.002

In [30]:
# Fresh model on the selected device, and an Adam optimizer over all of its
# parameters using the lr / betas configured above.
yolact = Yolact().to(device)
opt = torch.optim.Adam(yolact.parameters() , lr=lr , betas=betas)

In [32]:
def train():
    """Train the global ``yolact`` model with ``opt`` over ``dataloader``.

    Runs ``n_epochs`` epochs. Each step averages three L1 losses (class index,
    masks, boxes), back-propagates and steps the optimizer. Every
    ``display_steps`` steps the mean loss over that window is printed and the
    accumulator is reset.

    NOTE(review): ``cls_`` is an arg-max class index (see ``Yolact.forward``),
    so ``loss_1`` contributes no gradient to the model.
    """
    mean_yolact_loss = 0
    cur_step = 0
    for epoch in range(n_epochs):
        for img , label ,  mask_label in dataloader:
            img , label , mask_label = img.to(device) , label.to(device)  , mask_label.to(device)

            opt.zero_grad()
            cls_ , bbox , mask  = yolact(img)
            # Shapes with the defaults used in this notebook (a=5, k=4, S=68):
            #   label      => [N , 68 , 68 , 5 , 5]   (x , y , w , h , class)
            #   mask_label => [N , 136 , 136 , 5 , 1]
            #   cls_       => [N , 68 , 68 , 5 , 1]   (arg-max class index)
            #   bbox       => [N , 68 , 68 , 5 , 4]
            #   mask       => [N , 136 , 136 , 5 , 4] (broadcast against mask_label)

            loss_1 = recon_criterion(cls_ , label[... , 4:5])
            loss_2 = recon_criterion(mask , mask_label)
            loss_3 = recon_criterion(bbox , label[... , :4])

            loss = (loss_1 + loss_2 + loss_3) / 3
            loss.backward()
            opt.step()

            mean_yolact_loss += loss.item() / display_steps
            # Report once a full window of display_steps losses is collected,
            # then reset so the next report is a mean over its own window
            # rather than a running sum since the start of the epoch.
            if (cur_step + 1) % display_steps == 0:
                print(f'Epoch {epoch} , Step {cur_step} , Mean Yolact Loss {mean_yolact_loss}')
                mean_yolact_loss = 0
            cur_step +=1

In [None]:
# Start training using the globals defined in the cells above.
train()