In [1]:
####

In [2]:
import torch
from torch import nn
import torch.nn.functional as f

In [3]:

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
from torchvision import transforms
from torchvision.utils import make_grid
import torchvision.transforms.functional as F
from tqdm.notebook import tqdm
from torchvision.ops import roi_align

In [4]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
def iou_width_height(boxes1, boxes2):
    """IoU of boxes compared by size only, as if they shared a common corner.

    Args:
        boxes1: tensor whose last dimension holds width at index 0 and height
            at index 1.
        boxes2: tensor with the same last-dimension layout; its leading
            dimensions broadcast against those of ``boxes1``.

    Returns:
        Tensor of ``intersection / union`` values (no epsilon is added, so a
        zero union yields ``nan``/``inf``).
    """
    w1, h1 = boxes1[..., 0], boxes1[..., 1]
    w2, h2 = boxes2[..., 0], boxes2[..., 1]

    # With a shared corner the overlap is simply the smaller extent per axis.
    overlap = torch.min(w1, w2) * torch.min(h1, h2)
    combined = w1 * h1 + w2 * h2 - overlap
    return overlap / combined

In [6]:

def show_tensor_images(image_tensor, num_images=2, size=(3 , 224 , 224)):
    """Display the first ``num_images`` images of a tensor as a matplotlib grid.

    Args:
        image_tensor: tensor that can be viewed as ``(-1, *size)``.
        num_images: how many images (from the front) to draw.
        size: per-image ``(channels, height, width)`` used for the reshape.
    """
    # Detach and move to CPU so the data can be handed to matplotlib.
    images = image_tensor.detach().cpu().view(-1, *size)[:num_images]
    grid = make_grid(images, nrow=5)
    # make_grid returns channels-first; imshow wants channels-last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [7]:
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Element-wise IoU between two sets of boxes.

    Args:
        boxes_preds: tensor of shape ``(..., 4)`` with the predicted boxes.
        boxes_labels: tensor of shape ``(..., 4)`` with the reference boxes.
        box_format: ``"midpoint"`` for ``(x_center, y_center, width, height)``
            or ``"corners"`` for ``(x1, y1, x2, y2)``.

    Returns:
        Tensor of shape ``(..., 1)`` holding the IoU of each box pair.

    Raises:
        ValueError: if ``box_format`` is neither ``"midpoint"`` nor
            ``"corners"``.
    """
    if box_format == "midpoint":
        # Convert centre/size to corner coordinates; the 0:1-style slices keep
        # the trailing dimension so the result stays (..., 1).
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        # Previously an unknown format fell through to an UnboundLocalError.
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) turns the negative extents of non-overlapping boxes into zero.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # The epsilon guards against division by zero for degenerate boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)

In [8]:
def Reverse(lst):
    """Return a new list containing the items of ``lst`` in reverse order."""
    return list(reversed(lst))

In [9]:
class Conv(nn.Module):
    """Conv2d followed by optional InstanceNorm2d, 2x2 max-pool and LeakyReLU(0.2).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        use_norm: append ``nn.InstanceNorm2d(out_channels)``.
        use_activation: append ``nn.LeakyReLU(0.2)``.
        use_pool: append ``nn.MaxPool2d(kernel_size=2, stride=2)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3),
                 stride=(1, 1), padding=1, use_norm=True,
                 use_activation=True, use_pool=False):
        super().__init__()

        self.use_norm, self.use_activation, self.use_pool = (
            use_norm, use_activation, use_pool
        )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # Optional stages are only created (and registered) when enabled.
        if self.use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if self.use_activation:
            self.activation = nn.LeakyReLU(0.2)
        if self.use_pool:
            self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply conv, then norm, pool and activation (in that order) when enabled."""
        out = self.conv1(x)
        stages = (
            (self.use_norm, 'norm'),
            (self.use_pool, 'max_pool'),
            (self.use_activation, 'activation'),
        )
        for enabled, attr in stages:
            if enabled:
                out = getattr(self, attr)(out)
        return out

In [10]:
class ConvT(nn.Module):
    """ConvTranspose2d followed by optional InstanceNorm2d and LeakyReLU(0.2).

    With the default 2x2 kernel, stride 2 and no padding the spatial size is
    doubled.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the transposed convolution.
        kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
        use_norm: append ``nn.InstanceNorm2d(out_channels)``.
        use_activation: append ``nn.LeakyReLU(0.2)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(2, 2),
                 stride=(2, 2), padding=0, use_norm=True, use_activation=True):
        super().__init__()

        self.use_norm, self.use_activation = use_norm, use_activation

        self.convT = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        # Optional stages are only created (and registered) when enabled.
        if self.use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if self.use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply the transposed conv, then norm and activation when enabled."""
        out = self.convT(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        return out

In [11]:
class Resnet_Block(nn.Module):
    """Three-conv residual block with a convolutional skip connection.

    Main path: entry conv -> 3x3 conv -> 1x1 conv (to ``out_channels``).
    Skip path: a single conv mapping ``in_channels`` to ``out_channels``.
    When ``downsample`` is true, both the entry conv and the skip conv use a
    2x2 kernel with stride 2 (halving the spatial size); otherwise both are
    1x1 with stride 1.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the block output.
        downsample: whether the block halves the spatial resolution.
    """

    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()

        self.downsample = downsample

        # Kernel and stride are always equal for the entry/skip convs.
        window = (2, 2) if self.downsample else (1, 1)

        self.conv1 = Conv(in_channels, in_channels,
                          kernel_size=window, stride=window, padding=0)
        self.conv_skip = Conv(in_channels, out_channels,
                              kernel_size=window, stride=window, padding=0)
        self.conv2 = Conv(in_channels, in_channels)
        self.conv3 = Conv(in_channels, out_channels,
                          kernel_size=(1, 1), stride=(1, 1), padding=0)

    def forward(self, x):
        """Return ``main_path(x) + skip_path(x)``."""
        shortcut = self.conv_skip(x.clone())
        residual = self.conv3(self.conv2(self.conv1(x)))
        residual += shortcut
        return residual

In [12]:
class Resnet(nn.Module):
    """ResNet-style backbone that returns the feature maps of its four stages.

    Layout: 7x7 stride-2 stem conv, 2x2 max-pool, then four stages of
    ``Resnet_Block`` with (3, 8, 36, 3) blocks and (256, 512, 1024, 2048)
    output channels. Every stage except the first halves the resolution in
    its first block.

    Args:
        in_channels: channels of the input image.
        out_channels: accepted but not used by this class (no classifier head
            is built).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = Conv(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=3)

        self.conv2 = self._make_repeated_blocks(64, 256, 3, downsample=False)
        self.conv3 = self._make_repeated_blocks(256, 512, 8)
        self.conv4 = self._make_repeated_blocks(512, 1024, 36)
        self.conv5 = self._make_repeated_blocks(1024, 2048, 3)

    def _make_repeated_blocks(self, in_channels, out_channels, repeats, downsample=True):
        """Build one stage of ``repeats`` blocks as an ``nn.Sequential``.

        Only the first block changes the channel count (and downsamples when
        ``downsample`` equals ``True``); the remaining blocks map
        ``out_channels`` to ``out_channels``.
        """
        stage = [
            Resnet_Block(in_channels, out_channels, downsample=downsample)
            if (idx == 0 and downsample == True)
            else Resnet_Block(in_channels if idx == 0 else out_channels, out_channels)
            for idx in range(repeats)
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``[stage4, stage3, stage2, stage1]`` feature maps (deepest first)."""
        stem = torch.max_pool2d(self.conv1(x), kernel_size=(2, 2), stride=(2, 2))
        stage1 = self.conv2(stem)
        stage2 = self.conv3(stage1)
        stage3 = self.conv4(stage2)
        stage4 = self.conv5(stage3)
        return [stage4, stage3, stage2, stage1]

In [13]:

class PAN_Net(nn.Module):
    """Backbone plus up-sampling and down-sampling paths, fused into one map.

    The ``Resnet`` backbone yields four feature maps (deepest first). Each is
    up-sampled by its own ``ConvT`` (halving the channels); the results are
    reversed and each is passed through its own pooled ``Conv`` (doubling the
    channels). The four outputs are average-pooled to 14x14 and concatenated
    along the channel dimension.

    Args:
        in_channels: input channels of the ``ConvT`` layers, one per level.
        out_channels: input channels of the pooled ``Conv`` layers, one per
            level.
    """

    def __init__(self,
                 in_channels = [2048 , 1024 , 512 , 256] ,
                 out_channels = [128 , 256 , 512 , 1024]):
        super().__init__()

        self.top_down = nn.ModuleList()
        self.bottom_up = nn.ModuleList()
        self.resnet = Resnet(3, 1000)

        self.top_down.extend([ConvT(width, width // 2) for width in in_channels])
        self.bottom_up.extend(
            [Conv(width, width * 2, use_pool=True) for width in out_channels]
        )

    def _get_pooled(self, x):
        """Pool every map in the list ``x`` to 14x14 (in place) and concatenate on dim 1.

        ``x`` must hold exactly four tensors.
        """
        x[:] = [f.adaptive_avg_pool2d(level, (14, 14)) for level in x]
        first, second, third, fourth = x
        return torch.cat([first, second, third, fourth], dim=1)

    def forward(self, x):
        """Return the fused feature map for an image batch ``x``."""
        feats = self.resnet(x)
        # The same index bound is used for both loops below.
        last = len(self.top_down) - 1

        upsampled = []
        for idx, layer in enumerate(self.top_down):
            level = layer(feats[idx])
            # Only the interior levels receive the next backbone map as a lateral add.
            if idx not in (0, last):
                level = level + feats[idx + 1]
            upsampled.append(level)

        ordered = upsampled[::-1]
        fused = []
        for idx, layer in enumerate(self.bottom_up):
            level = layer(ordered[idx])
            if idx not in (0, last):
                level = level + ordered[idx + 1]
            fused.append(level)

        return self._get_pooled(fused)

In [14]:
def test():
    """Smoke test: push a random 2x3x224x224 batch through ``PAN_Net`` and print the output shape."""
    model = PAN_Net().to(device)
    batch = torch.randn(2, 3, 224, 224).to(device)
    print(model(batch).shape)

In [15]:
#test()

In [16]:
class Mask_Branch(nn.Module):
    """Mask head combining a convolutional path with a fully connected path.

    The convolutional path up-samples to a single-channel 28x28 map; the
    fully connected path predicts ``784 * 5`` values that are laid out as five
    28x28 maps. The shared conv map is added to each of the five maps.

    The hard-coded ``nn.Linear(12544, 784 * 5)`` equals ``64 * 14 * 14`` inputs,
    so the module expects a ``(batch, in_channels, 14, 14)`` input.

    Args:
        in_channels: channels of the input feature map.
        out_channels_model: accepted but not used by this class.
    """

    def __init__(self ,
                 in_channels ,
                 out_channels_model = 1):
        super(Mask_Branch , self).__init__()
        hidden_dim = 256
        self.conv1 = Conv(in_channels , hidden_dim)
        self.conv2 = Conv(hidden_dim , hidden_dim)
        self.conv3 = Conv(hidden_dim , hidden_dim)
        self.conv4 = Conv(hidden_dim , hidden_dim)

        self.convT1 = ConvT(hidden_dim , 1)

        self.conv4_fc = Conv(hidden_dim , hidden_dim //2)
        hidden_dim = hidden_dim //2
        self.conv5_fc = Conv(hidden_dim , hidden_dim // 2)
        self.flatten = nn.Flatten()

        self.linear1 = nn.Linear(12544 , 784 * 5)

    def forward(self , x):
        """Return mask predictions of shape ``(batch, 5, 1, 28, 28)``."""
        shared = self.conv3(self.conv2(self.conv1(x)))

        # Convolutional path: (batch, 1, 28, 28).
        conv_mask = self.convT1(self.conv4(shared))

        # Fully connected path: (batch, 1, 28, 28, 5).
        fc = self.flatten(self.conv5_fc(self.conv4_fc(shared)))
        fc_mask = self.linear1(fc).view(fc.shape[0] , 1 , 28 , 28 , 5)

        combined = conv_mask.unsqueeze(-1) + fc_mask
        # Move the trailing size-5 axis to dim 1 with permute. A plain
        # view(batch, 5, 1, 28, 28) would reinterpret memory and scramble the
        # spatial layout of the convolutional mask.
        return combined.permute(0 , 4 , 1 , 2 , 3).contiguous()

In [None]:
# Disabled shape check for Mask_Branch: the snippet is kept inside a string
# literal, so none of it is executed.
'''x = torch.randn(2 , 3840 , 14 , 14).to(device)
mask_pred = Mask_Branch(3840).to(device)
z = mask_pred(x)
z.shape'''

In [18]:
class Pred(nn.Module):
    """Box/class head producing ``num_classes + 4`` values per anchor and cell.

    Two convs reduce the channels, a ``ConvT`` doubles the spatial size, and a
    final conv emits ``(num_classes + 4) * B`` channels.

    Args:
        in_channels: channels of the input feature map.
        num_classes: number of class values per anchor; 4 box values are added
            on top of it.
        B: number of anchors per grid cell.
        S: grid size of the output. The spatial size after ``convT1`` must
            equal ``S``, otherwise the reshape in ``forward`` fails.
    """

    def __init__(self ,
                 in_channels ,
                 num_classes = 1 ,
                 B = 5 ,
                 S = 28):
        super(Pred , self).__init__()

        self.conv1 = Conv(in_channels , in_channels // 2)
        self.conv2 = Conv(in_channels // 2 , in_channels // 4)
        self.convT1 = ConvT(in_channels // 4 , in_channels //2)
        num_classes_ = num_classes + 4
        self.conv3 = Conv(in_channels //2 , num_classes_ * B)

        self.B = B
        self.S = S
        # Note: stores the per-anchor vector length (num_classes + 4).
        self.num_classes = num_classes_

    def forward(self , x):
        """Return predictions of shape ``(batch, B, S, S, num_classes + 4)``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.convT1(x)
        x = self.conv3(x)  # (batch, B * (num_classes + 4), S, S)
        # Split the channel axis into (anchor, value) while it is still
        # channels-first, then move the value axis last. A direct
        # view(batch, B, S, S, C) would interleave channel and spatial data.
        x = x.view(x.shape[0] , self.B , self.num_classes , self.S , self.S)
        return x.permute(0 , 1 , 3 , 4 , 2).contiguous()

In [None]:
# Disabled shape check for Pred: the snippet is kept inside a string literal,
# so none of it is executed.
'''x = torch.randn(2 , 3840  , 14 , 14).to(device)
pred = Pred(3840).to(device)
z = pred(x)
z.shape'''

In [20]:
class PAN_Net_Model(nn.Module):
    """Full model: ``PAN_Net`` features feeding a ``Pred`` head and a ``Mask_Branch`` head.

    Both heads are built for 3840 input channels.

    Args:
        in_channels: accepted but not used by this class.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        self.pannet = PAN_Net()
        self.class_branch = Pred(3840)
        self.mask_branch = Mask_Branch(3840)

    def forward(self, x):
        """Return ``(class_branch_output, mask_branch_output)`` for image batch ``x``."""
        features = self.pannet(x)
        return self.class_branch(features), self.mask_branch(features)

In [None]:
# Disabled end-to-end shape check for PAN_Net_Model: the snippet is kept inside
# a string literal, so none of it is executed.
'''x = torch.randn(2 , 3 , 224 , 224).to(device)
pannet = PAN_Net_Model().to(device)
z = pannet(x)
classes , mask = z
print(classes.shape , mask.shape)'''

In [22]:
class Dataset_(torch.utils.data.Dataset):
    """Detection dataset yielding an image, per-anchor box targets and a mask target.

    The CSV file lists an image file name in column 0 and a label file name in
    column 1. Every non-blank line of a label file holds five
    whitespace-separated numbers: ``class x y width height``.
    (Presumably normalised to [0, 1], since they are scaled by ``S`` to obtain
    grid coordinates - confirm against the data.)

    Args:
        img_dir: directory containing the images.
        label_dir: directory containing the label files.
        csv_file: path of the CSV index file.
        anchors: sequence of ``(width, height)`` anchor sizes.
        transforms: optional callable applied to the image tensor.
        S: grid size.
        B: number of anchors per grid cell.
        C: number of classes (stored, not used by this class).
    """

    def __init__(self ,
                 img_dir ,
                 label_dir ,
                 csv_file ,
                 anchors ,
                 transforms = None ,
                 S = 28 ,
                 B = 5 ,
                 C = 20):
        super(Dataset_ , self).__init__()

        self.img_dir = img_dir
        self.label_dir = label_dir
        self.df = pd.read_csv(csv_file)
        self.anchors = torch.from_numpy(np.array(anchors))
        self.transforms = transforms
        self.number_of_anchors_per_cell = 5
        self.ignore_iou_thresh = 0.5
        self.C = C
        self.S = S
        self.B = B
        self.mask_size = 28

    def __len__(self):
        """Number of rows in the CSV index."""
        return len(self.df)

    def __getitem__(self , idx):
        """Load sample ``idx``.

        Returns:
            Tuple ``(image, targets, target_mask)`` where ``targets`` has shape
            ``(B, S, S, 5)`` laid out as ``[x_cell, y_cell, width, height,
            class]`` and ``target_mask`` has shape ``(B, 1, S, S)``.
        """
        label_path = os.path.join(self.label_dir , self.df.iloc[idx , 1])
        img_path = os.path.join(self.img_dir , self.df.iloc[idx , 0])

        image = np.asarray(plt.imread(img_path))
        image = torch.from_numpy(image).permute(2 , 0 , 1)
        # Turns a crop into a single-channel 5x5 patch.
        transform_mask = transforms.Compose([
                                             transforms.ToPILImage() ,
                                             transforms.Resize((5 , 5)) ,
                                             transforms.Grayscale() ,
                                             transforms.ToTensor()
        ])
        patch_size = 5  # must match the Resize((5 , 5)) above

        boxes = []
        # The handle is not called `f`, to avoid shadowing the module-level alias.
        with open(label_path) as label_file:
            for line in label_file:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines
                # Parse every field as float: int("1.0") would raise ValueError.
                class_label , x , y , width , height = (float(v) for v in fields)
                boxes.append([ x , y , width , height , class_label])
        boxes = torch.tensor(boxes , dtype=torch.float32)

        if self.transforms:
            image = self.transforms(image)

        S = self.S
        targets = torch.zeros((self.B , S , S , 5))
        target_mask = torch.zeros((self.B , S , S , 1))
        for box in boxes:
            iou_anchors = iou_width_height(box[2:4] , self.anchors)
            anchors_indices = iou_anchors.argsort(descending=True, dim=0)
            x , y , width , height , class_label = box
            # Clamp so a coordinate of exactly 1.0 lands in the last cell
            # instead of indexing past the grid.
            i , j = min(int(S * y) , S - 1) , min(int(S * x) , S - 1)
            has_anchor = [False for _ in range(self.B)]
            # NOTE(review): each anchor index is visited once per box, so
            # has_anchor never blocks an assignment and the box is written to
            # every anchor whose slot is free - confirm this is intended.
            for anchor_idx in anchors_indices:
                anchor_on_scale = anchor_idx % self.B
                # NOTE(review): slot 0 doubles as the "taken"/ignore marker and
                # as x_cell, so a stored x_cell of 0 reads as "free".
                anchor_taken = targets[anchor_on_scale , i , j , 0]
                if not anchor_taken and not has_anchor[anchor_on_scale]:
                    x_cell , y_cell = S * x - j , S * y - i
                    width_cell , height_cell = (
                        width * S ,
                        height * S
                    )
                    box_coordinate = torch.tensor([x_cell , y_cell , width_cell , height_cell])
                    targets[anchor_on_scale , i , j , :4] = box_coordinate
                    targets[anchor_on_scale , i , j , 4] = int(class_label)
                    # NOTE(review): F.crop takes (img, top, left, height, width)
                    # in pixels, but receives cell-relative offsets and sizes in
                    # grid units - this looks unintended; confirm before relying
                    # on the mask target.
                    target_mask_ = F.crop(image , int(x_cell) , int(y_cell) , int(width_cell) , int(height_cell))
                    target_mask_ = transform_mask(target_mask_)
                    # Clip the patch at the grid border; an unclipped 5x5
                    # assignment raises a shape mismatch near the edges.
                    rows = min(patch_size , S - i)
                    cols = min(patch_size , S - j)
                    target_mask[anchor_on_scale , i:i+rows , j:j+cols , 0:1] = (
                        target_mask_.permute(1 , 2 , 0)[:rows , :cols]
                    )
                    has_anchor[anchor_on_scale] = True
                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    targets[anchor_on_scale , i , j , 0] = -1
        return image , targets , target_mask.view(self.B , 1 , self.S , self.S)

In [23]:
# Anchor sizes as (width, height) pairs; Dataset_ compares them with the
# width/height of every labelled box.
anchors = [
    [0.28, 0.22],
    [0.38, 0.48],
    [0.9, 0.78],
    [0.07, 0.15],
    [0.15, 0.11],
]

# Image pipeline: tensor -> PIL image -> 224x224 -> float tensor.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# NOTE: the paths below point at a mounted Google Drive folder.
dataset = Dataset_(
    img_dir='/content/drive/MyDrive/Yolo_Dataset/images/',
    label_dir='/content/drive/MyDrive/Yolo_Dataset/labels',
    csv_file='/content/drive/MyDrive/Yolo_Dataset/train.csv',
    anchors=anchors,
    transforms=transform,
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

In [None]:
# Preview a single batch: show the image and print both target shapes.
for x, y, z in dataloader:
    show_tensor_images(x)
    print(y.shape, z.shape, sep='\n')
    break

In [25]:
# Loss functions.
recon_criterion = nn.L1Loss()
adv_criterion = nn.BCEWithLogitsLoss()
ce_criterion = nn.CrossEntropyLoss()

# Optimiser settings (Adam learning rate and betas).
lr = 0.002
betas = (0.5, 0.999)

# Training schedule: number of epochs and how often (in steps) to report.
n_epochs, display_steps = 200, 1

lambda_recon = 200

In [26]:
# Build the full model, move it to the selected device, then hand its
# parameters to Adam.
pannet = PAN_Net_Model()
pannet = pannet.to(device)
opt = torch.optim.Adam(params=pannet.parameters(), lr=lr, betas=betas)

In [27]:
def train():
    """Train the global ``pannet`` on ``dataloader`` for ``n_epochs`` epochs.

    Each step computes the L1 loss (``recon_criterion``) of both model outputs
    against their targets, averages the two and takes one optimiser step.
    Every ``display_steps`` steps the mean loss over the steps since the last
    report is printed.

    Relies on the module-level ``pannet``, ``opt``, ``dataloader``, ``device``,
    ``recon_criterion``, ``n_epochs`` and ``display_steps``.
    """
    running_loss = 0.0
    steps_in_window = 0
    cur_step = 0
    for epoch in range(n_epochs):
        for img , label ,  mask_label in dataloader:
            img , label , mask_label = img.to(device) , label.to(device)  , mask_label.to(device)

            opt.zero_grad()
            cls_ , mask  = pannet(img)
            loss_1 = recon_criterion(cls_ , label)
            loss_2 = recon_criterion(mask , mask_label)

            loss = (loss_1 + loss_2) /2
            loss.backward()
            opt.step()

            running_loss += loss.item()
            steps_in_window += 1
            if cur_step % display_steps == 0:
                # Report a true mean and start a new window. Previously the
                # accumulator was only cleared at the end of an epoch, so the
                # printed value kept growing within the epoch.
                mean_pannet_loss = running_loss / steps_in_window
                print(f'Epoch {epoch} , Step {cur_step} , Mean Pannet Loss {mean_pannet_loss}')
                running_loss = 0.0
                steps_in_window = 0
            cur_step +=1


In [None]:
# Start the training run; this loops over the dataloader for n_epochs epochs.
train()