### Downloading and Importing Packages

In [None]:
from google.colab import drive
drive.mount('/content/drive') 


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
import torch.nn.functional as F

from torch import nn
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torchvision
from functools import partial
import pytorch_lightning as pl
from pytorch_lightning.core.datamodule import LightningDataModule
from torch.utils.data import random_split, DataLoader
from matplotlib.patches import Rectangle
import random
from torch.autograd import Variable


### Utils

In [None]:
def MultiApply(func, *args, **kwargs):
    """Apply ``func`` element-wise over the given iterables and regroup the results.

    ``func`` is called once per position with one element from each iterable
    in ``args`` (plus the fixed ``kwargs``) and is expected to return a
    sequence of outputs. The i-th outputs of every call are gathered into the
    i-th list of the returned tuple.

    Returns:
        tuple of lists, one list per output of ``func``; an empty tuple when
        the iterables are empty.
    """
    if kwargs:
        bound = partial(func, **kwargs)
    else:
        bound = func

    per_call_outputs = map(bound, *args)

    # Transpose "one tuple per call" into "one list per output position".
    return tuple([*column] for column in zip(*per_call_outputs))

# Element-wise IoU between two equally sized sets of boxes.
def IOU(output, target):
    """Return the IoU of each box in ``output`` with the matching box in ``target``.

    Both arguments are indexed as ``[:, 0..3]`` and read as
    (center_x, center_y, width, height). Row i of ``output`` is compared with
    row i of ``target``; the result has one value per row. A small epsilon in
    the denominator avoids division by zero for degenerate boxes.
    """

    def to_corners(boxes):
        # (cx, cy, w, h) -> (x1, y1, x2, y2)
        half_w = boxes[:, 2] / 2
        half_h = boxes[:, 3] / 2
        return (boxes[:, 0] - half_w, boxes[:, 1] - half_h,
                boxes[:, 0] + half_w, boxes[:, 1] + half_h)

    out_x1, out_y1, out_x2, out_y2 = to_corners(output)
    tgt_x1, tgt_y1, tgt_x2, tgt_y2 = to_corners(target)

    # Overlap rectangle; negative extents mean "no overlap" and clamp to 0.
    overlap_w = (torch.minimum(out_x2, tgt_x2) - torch.maximum(out_x1, tgt_x1)).clamp(min=0)
    overlap_h = (torch.minimum(out_y2, tgt_y2) - torch.maximum(out_y1, tgt_y1)).clamp(min=0)
    intersection = overlap_w * overlap_h

    output_area = output[:, 2] * output[:, 3]
    target_area = target[:, 2] * target[:, 3]

    return intersection / (output_area + target_area - intersection + 1e-6)


# Decode grid-shaped regression outputs [t_x, t_y, t_w, t_h] against their anchors.
# Input:
#       out:     (4, A, B)  regression outputs, channel-first ([t_x,t_y,t_w,t_h])
#       anchors: (A, B, 4)  anchors, channel-last ([x,y,w,h])
#       device:  device on which the decoding is done and the result is returned
# Output:
#       box: (4, A, B) decoded boxes ([x,y,w,h], same layout as `out`)
def output_decoding(out,anchors, device='cpu'):
    """Decode regression offsets into boxes using the matching anchors.

    x/y offsets are scaled by the anchor width/height and shifted by the
    anchor position; w/h offsets are log-scale factors of the anchor size.

    Both inputs are moved to ``device`` first. Previously only ``anchors``
    was moved (always to the CPU) and ``device`` was ignored, so a GPU
    ``out`` raised a device-mismatch error. With the default ``device`` the
    result is unchanged for CPU inputs.
    """
    out = out.to(device)
    anchors = anchors.to(device)
    box = torch.zeros(out.shape, device=device)

    anchor_x, anchor_y = anchors[:, :, 0], anchors[:, :, 1]
    anchor_w, anchor_h = anchors[:, :, 2], anchors[:, :, 3]

    box[0, :, :] = out[0, :, :] * anchor_w + anchor_x
    box[1, :, :] = out[1, :, :] * anchor_h + anchor_y
    box[2, :, :] = torch.exp(out[2, :, :]) * anchor_w
    box[3, :, :] = torch.exp(out[3, :, :]) * anchor_h

    return box

def flattened_output_decoding(flatten_out,flatten_anchors, device='cpu'):
    """Decode flattened regression offsets into boxes using the matching anchors.

    Args:
        flatten_out: (N, 4) offsets in [t_x, t_y, t_w, t_h] order.
        flatten_anchors: (N, 4) anchors, read as [x, y, w, h].
        device: device on which decoding is done and the result is returned.

    Returns:
        (N, 4) tensor of decoded boxes in [x, y, w, h] order.

    Both inputs are moved to ``device`` first. Previously only the anchors
    were moved (always to the CPU) and ``device`` was ignored, so a GPU
    ``flatten_out`` raised a device-mismatch error. With the default
    ``device`` the result is unchanged for CPU inputs.
    """
    flatten_out = flatten_out.to(device)
    flatten_anchors = flatten_anchors.to(device)
    box = torch.zeros(flatten_out.shape, device=device)

    # Centre offsets are expressed in units of the anchor size.
    box[:, 0] = (flatten_out[:, 0] * flatten_anchors[:, 2]) + flatten_anchors[:, 0]
    box[:, 1] = (flatten_out[:, 1] * flatten_anchors[:, 3]) + flatten_anchors[:, 1]
    # Size offsets are log-scale factors of the anchor size.
    box[:, 2] = torch.exp(flatten_out[:, 2]) * flatten_anchors[:, 2]
    box[:, 3] = torch.exp(flatten_out[:, 3]) * flatten_anchors[:, 3]

    return box


def pretrained_models_680(checkpoint_file,eval=True):
    """Build a Mask R-CNN (ResNet-50 FPN) and load backbone/RPN weights from a checkpoint.

    Args:
        checkpoint_file: path (or file-like object) accepted by ``torch.load``.
            The checkpoint must contain the keys ``'backbone'`` and ``'rpn'``.
        eval: when true, the model, backbone and RPN are put in eval mode.

    Returns:
        (backbone, rpn) modules, placed on CUDA when available, else on CPU.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
    if eval:
        model.eval()
    model.to(device)

    backbone = model.backbone
    rpn = model.rpn
    if eval:
        backbone.eval()
        rpn.eval()

    rpn.nms_thresh = 0.6

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # host; without it torch.load tries to restore tensors on the device
    # they were saved from.
    checkpoint = torch.load(checkpoint_file, map_location=device)

    backbone.load_state_dict(checkpoint['backbone'])
    rpn.load_state_dict(checkpoint['rpn'])

    return backbone, rpn

### Defining dataloaders

In [None]:
def load_data(paths):
    """Load images, per-image masks, labels and bounding boxes from disk.

    Args:
        paths: sequence of four paths, in this order:
            [0] HDF5 file with the images under the key 'data',
            [1] HDF5 file with all masks concatenated under the key 'data',
            [2] .npy file with the per-image label arrays,
            [3] .npy file with the per-image bounding-box arrays.

    Returns:
        images: the HDF5 dataset itself (read lazily, so the images file is
            intentionally left open),
        label_based_masks: 1-D object array; entry i is a float64 array with
            one mask per object of image i,
        labels, bboxes: the arrays loaded from the .npy files.

    Raises:
        ValueError: if the masks file holds fewer masks than there are labels.
    """
    images_file = h5py.File(paths[0], 'r')
    images = images_file['data']

    # NOTE(review): allow_pickle=True executes pickled objects on load; only
    # use these files when they come from a trusted source.
    bboxes = np.load(paths[3], allow_pickle=True)
    labels = np.load(paths[2], allow_pickle=True)

    # The masks are stored back-to-back for all images; the number of labels
    # of an image tells how many consecutive masks belong to it. An explicit
    # object array is used because the per-image counts differ, and
    # np.array() on such a ragged list raises on recent NumPy versions.
    label_based_masks = np.empty(len(labels), dtype=object)
    cur_pointer = 0
    with h5py.File(paths[1], 'r') as masks_file:
        masks = masks_file['data']
        for image_idx, label in enumerate(labels):
            n_objects = len(label)
            cur_mask = np.asarray(masks[cur_pointer:cur_pointer + n_objects], dtype=np.float64)
            if cur_mask.shape[0] != n_objects:
                raise ValueError(
                    f"masks file ran out of masks at image {image_idx}: "
                    f"expected {n_objects}, found {cur_mask.shape[0]}"
                )
            label_based_masks[image_idx] = cur_mask
            cur_pointer += n_objects

    return images, label_based_masks, labels, bboxes

def collate_fn(batch):
    """Collate dataset samples into a batch.

    Each sample is a 6-tuple (image, label, mask, bounding box, index,
    old mask). Images are stacked into one tensor; every other field is
    returned as a tuple with one entry per sample, because their sizes
    differ between samples.
    """
    columns = list(zip(*batch))
    images, labels, masks, bounding_boxes, indices, old_mask = columns
    image_batch = torch.stack(images)
    return image_batch, labels, masks, bounding_boxes, indices, old_mask

class BuildDataset(Dataset):
    """Dataset of images with per-class mask targets and rescaled bounding boxes.

    Relies on the module-level ``device`` for the tensors it creates.
    """

    def __init__(self, path, image_transform=None, mask_transform = None):
        """Load the raw data and remember the transforms.

        Args:
            path: the four paths expected by ``load_data``.
            image_transform: callable applied to the HxWxC uint8 image.
            mask_transform: callable applied to the per-object mask tensor;
                it is called whenever ``image_transform`` is set, so it must
                be provided together with it.
        """
        self.images, self.masks, self.labels, self.bboxes = load_data(path)

        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __getitem__(self, idx):
        """Return the transformed sample at ``idx``.

        Returns:
            image: transformed image, asserted to be (3, 800, 1088),
            label: float32 array with the object labels,
            mask_target: (50, 3, 28, 28) expanded view holding one 28x28
                binary mask per class (first object of each class only),
            bbox: tensor of boxes [x1, y1, x2, y2] in network-input pixels,
            idx: the index that was requested,
            old_mask: copy of the transformed full-size masks.
        """
        image = self.images[idx].astype('uint8').transpose(1,2,0)
        label = self.labels[idx].astype('float32')
        mask = torch.tensor(self.masks[idx].astype('float64'))
        bbox = self.bboxes[idx].astype('float32')   # astype copies, the stored boxes stay untouched

        if self.image_transform:
            image = self.image_transform(image)
            mask = self.mask_transform(mask)

            # Boxes are [x1, y1, x2, y2]: x runs along the width (400 -> 1088)
            # and y along the height (300 -> 800). The factors used to be
            # swapped, stretching y by the width ratio and x by the height
            # ratio.
            # NOTE(review): this assumes the transforms resize straight to
            # 800x1088; if they resize to 800x1066 and pad 11 px per side,
            # x needs 1066/400 plus an 11 px offset instead - confirm.
            x_scale = 1088.0 / 400.0
            y_scale = 800.0 / 300.0
            if bbox.size:
                bbox[:, [0, 2]] *= x_scale
                bbox[:, [1, 3]] *= y_scale
        old_mask = mask.clone()

        bbox = torch.tensor(bbox, device = device)

        # One 28x28 target per class (3 classes); classes absent from the
        # image keep an all-zero target.
        target = torch.zeros(3, 28, 28).to(device)
        seen_classes = set()
        for obj_idx in range(bbox.shape[0]):
            # Only the first object of each class contributes a target.
            class_idx = int(label[obj_idx]) - 1
            if class_idx in seen_classes:
                continue
            seen_classes.add(class_idx)

            # Crop the object's mask to its bounding box.
            x1, y1, x2, y2 = bbox[obj_idx].type(torch.int)
            cropped = mask[obj_idx][y1:y2, x1:x2]
            cropped = cropped.unsqueeze(0).unsqueeze(0)          # [1, 1, h, w] for interpolate

            # Resize the crop to 28x28 and binarise it again, since bilinear
            # interpolation produces fractional values.
            resized = F.interpolate(cropped, (28, 28), mode='bilinear', align_corners=True)
            resized[resized > 0] = 1
            target[class_idx] = resized.squeeze()

        mask_target = target.unsqueeze(0).expand(50, -1, -1, -1)

        assert image.shape == (3,800,1088)
        assert bbox.shape[0] == mask.shape[0]

        return image, label, mask_target, bbox, idx, old_mask

    def __len__(self):
        """Number of samples, i.e. the number of per-image label arrays."""
        return self.labels.shape[0]


class BuildDataLoader(LightningDataModule):
    """Data module that splits one dataset into train / validation / test parts."""

    def __init__(self, dataset, batch_size=32):
        """Store the dataset to split and the batch size for train/val loaders."""
        super().__init__()

        self.dataset = dataset
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Split the dataset: 20% test, 5% validation, the rest for training.

        The split is random, so it is done only once. ``setup`` can be called
        again for another stage, and re-splitting there would move training
        samples into the test set.
        """
        if getattr(self, "train_data", None) is not None:
            return

        total = len(self.dataset)
        test_split = int(0.2 * total)    # 20% of the data held out for testing/prediction
        val_split = int(0.05 * total)    # 5% of the data used as validation set
        self.train_data, self.val_data, self.test_data = random_split(
            self.dataset, [total - test_split - val_split, val_split, test_split]
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_data, batch_size = self.batch_size, collate_fn=collate_fn, shuffle = True)

    def val_dataloader(self):
        """Loader over the validation split, in fixed order so epochs are comparable."""
        return DataLoader(self.val_data, batch_size = self.batch_size, collate_fn=collate_fn, shuffle = False)

    def predict_dataloader(self):
        """Loader over the held-out test split, one sample per batch."""
        return DataLoader(self.test_data, batch_size = 1, collate_fn=collate_fn, shuffle = False)


### BoxHead Utils

In [None]:
def iou(anchors, gt):               #Dimensions Anchors: (n_proposals,4), gt (ground_truth_boxes, 4)
    """Pairwise smoothed IoU between proposals and ground-truth boxes.

    Args:
        anchors: (n_proposals, 4) boxes in [x1, y1, x2, y2] format.
        gt: (ground_truth_boxes, 4) boxes in [x1, y1, x2, y2] format.

    Returns:
        (n_proposals, ground_truth_boxes) tensor with
        (intersection + 1) / (union + 1) for every pair.

    All work happens on the device of the inputs. The zero tensors used for
    clamping were previously created on CUDA whenever it was available, which
    failed for CPU inputs on a GPU machine.
    """
    # Proposals as column vectors, ground truth as row vectors, so every
    # binary operation broadcasts to [n_proposals, ground_truth_boxes].
    px1, py1, px2, py2 = (anchors[:, k].reshape(-1, 1) for k in range(4))
    gx1, gy1, gx2, gy2 = (gt[:, k].reshape(1, -1) for k in range(4))

    xA = torch.max(px1, gx1)
    yA = torch.max(py1, gy1)
    xB = torch.min(px2, gx2)
    yB = torch.min(py2, gy2)

    # Negative extents mean the boxes do not overlap.
    area_intersection = (xB - xA).clamp(min=0) * (yB - yA).clamp(min=0)

    area_union = (px2 - px1) * (py2 - py1) + (gx2 - gx1) * (gy2 - gy1) - area_intersection

    # The +1 on both sides keeps the ratio defined for zero-area boxes.
    return torch.div(area_intersection + 1, area_union + 1)


# This function decodes the output of the box head, given in the [t_x,t_y,t_w,t_h] format,
# into box coordinates: it returns the upper-left and lower-right corners of each bbox.
# Input:
#       regressed_boxes_t: (total_proposals,4) ([t_x,t_y,t_w,t_h] format)
#       flatten_proposals: (total_proposals,4) ([x1,y1,x2,y2] format)
# Output:
#       box: (total_proposals,4) ([x1,y1,x2,y2] format)

def output_decodingd(regressed_boxes_t,flatten_proposals, device='cpu'):
    """Turn box-head offsets into corner-format boxes.

    ``regressed_boxes_t`` holds [t_x, t_y, t_w, t_h] per proposal and
    ``flatten_proposals`` the matching proposals as [x1, y1, x2, y2]. The
    result has the shape of ``regressed_boxes_t``, is allocated on
    ``device`` and contains [x1, y1, x2, y2] boxes.
    """
    prop_x1, prop_y1 = flatten_proposals[:, 0], flatten_proposals[:, 1]
    prop_x2, prop_y2 = flatten_proposals[:, 2], flatten_proposals[:, 3]

    # Proposal size and centre.
    wp = prop_x2 - prop_x1
    hp = prop_y2 - prop_y1
    x_p = (prop_x2 + prop_x1)/2
    y_p = (prop_y2 + prop_y1)/2

    # Decoded centre and half extents of every box.
    centre_x = regressed_boxes_t[:, 0]*wp + x_p
    centre_y = regressed_boxes_t[:, 1]*hp + y_p
    half_w = torch.exp(regressed_boxes_t[:, 2])*wp/2
    half_h = torch.exp(regressed_boxes_t[:, 3])*hp/2

    box = torch.zeros(regressed_boxes_t.shape, device=device)
    box[:, 0] = centre_x - half_w
    box[:, 1] = centre_y - half_h
    box[:, 2] = centre_x + half_w
    box[:, 3] = centre_y + half_h

    return box

### BoxHead Definition

In [None]:
#BoxHead.py

class BoxHead(pl.LightningModule):
    def __init__(self,Classes=3,P=7, eval_ = False):
        """Build the box head: a shared two-layer MLP feeding a classifier and a box regressor.

        Args:
            Classes: number of object classes C (background is added by the classifier).
            P: side length of the pooled feature map; the MLP input is 256*P*P wide.
            eval_: flag stored on the instance as ``self.eval_``.
        """
        super().__init__()
        self.C = Classes
        self.P = P
        self.train_losses = []
        self.validation_losses = []
        self.eval_ = eval_

        hidden_dim = 1024

        # Shared trunk applied to every flattened RoI feature.
        self.intermediate_layer = nn.Sequential(
            nn.Linear(in_features=256*self.P*self.P, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
        )

        # C classes plus background.
        self.classifier_head = nn.Sequential(
            nn.Linear(in_features=hidden_dim, out_features=self.C+1),
        )

        # Four regression values per class.
        self.regressor_head = nn.Sequential(
            nn.Linear(in_features=hidden_dim, out_features=4*(self.C)),
        )

        # (A convolutional mask head used to be defined here; it is disabled.)

        # Re-initialise every layer of the three sub-networks, in definition order.
        for sub_network in (self.intermediate_layer, self.classifier_head, self.regressor_head):
            for layer in sub_network:
                self._init_weights(layer)

    def _init_weights(self, module):
        """Initialise a linear layer with N(0, 0.01) weights and zero bias; ignore other modules."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.normal_(module.weight, mean=0.0, std=0.01)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    # Forward the pooled feature vectors through the intermediate layer and the classifier, regressor of the box head
    # Input:
    #        feature_vectors: (total_proposals, 256*P*P)
    # Outputs:
    #        class_logits: (total_proposals,(C+1)) (we assume classes are C classes plus background, notice if you want to use
    #                                               CrossEntropyLoss you should not pass the output through softmax here)
    #        box_pred:     (total_proposals,4*C)
    def forward(self, feature_vectors):
        """Run pooled feature vectors through the trunk, classifier and regressor.

        Args:
            feature_vectors: (total_proposals, 256*P*P)

        Returns:
            class_logits: (total_proposals, C+1) class probabilities. Despite
                the name, softmax has already been applied.
            box_pred: (total_proposals, 4*C) raw regression outputs.
        """
        X = self.intermediate_layer(feature_vectors)

        # Softmax over the class axis. The axis is given explicitly:
        # nn.Softmax() without `dim` relies on a deprecated implicit choice
        # and emits a warning on every call.
        # NOTE(review): the comment above this method says softmax must not
        # be applied when CrossEntropyLoss is used - confirm that
        # compute_loss expects probabilities rather than logits.
        class_logits = F.softmax(self.classifier_head(X), dim=-1)

        box_pred = self.regressor_head(X)

        return class_logits, box_pred

    # def forward_mask(self, x):

    #   x = self.mask_head(x)

    #   return x


    #  This function assigns to each proposal either a ground truth box or the background class (we assume background class is 0)
    #  Input:
    #       proposals: list:len(bz){(per_image_proposals,4)} ([x1,y1,x2,y2] format)
    #       gt_labels: list:len(bz) {(n_obj)}
    #       bbox: list:len(bz){(n_obj, 4)}
    #  Output: (make sure the ordering of the proposals are consistent with MultiScaleRoiAlign)
    #       labels: (total_proposals,1) (the class that the proposal is assigned)
    #       regressor_target: (total_proposals,4) (target encoded in the [t_x,t_y,t_w,t_h] format)
    def create_ground_truth(self,proposals,gt_labels,bboxes):
        """Assign every proposal a class label and an encoded regression target.

        A proposal takes the label and box of the ground-truth box it
        overlaps most, provided that IoU exceeds 0.5; otherwise it is
        labelled background (0). Proposals are processed image by image, in
        the order given.

        Args:
            proposals: list (one per image) of (n_proposals, 4) tensors, [x1,y1,x2,y2].
            gt_labels: list (one per image) of per-object labels.
            bboxes: list (one per image) of (n_obj, 4) tensors, [x1,y1,x2,y2].

        Returns:
            labels: list of ints, one per proposal over the whole batch.
            regressor_target: (total_proposals, 4) tensor in [t_x,t_y,t_w,t_h]
                format. Rows of background proposals are encoded from a
                placeholder box and carry no meaning.

        Relies on the module-level ``device`` and on ``IOU``.
        """

        def corners_to_centers(boxes):
            # [x1, y1, x2, y2] -> [center_x, center_y, width, height]
            converted = torch.zeros_like(boxes, device = device)
            converted[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2
            converted[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2
            converted[:, 2] = boxes[:, 2] - boxes[:, 0]
            converted[:, 3] = boxes[:, 3] - boxes[:, 1]
            return converted

        labels = []
        regressor_target = []

        for cur_batch, corner_proposals in enumerate(proposals):
            cur_image_proposals = corners_to_centers(corner_proposals)
            cur_bbox = corners_to_centers(bboxes[cur_batch])

            for cur_proposal in cur_image_proposals:
                # IOU compares boxes row by row, so the proposal is repeated
                # once per ground-truth box.
                replicated_proposal = cur_proposal.repeat((len(cur_bbox), 1))
                cur_IOU = IOU(cur_bbox, replicated_proposal)

                max_iou, max_iou_ind = torch.max(cur_IOU, dim = 0)

                if max_iou > 0.5:
                    cur_proposal_label = gt_labels[cur_batch][max_iou_ind]
                    cur_proposal_bbox = cur_bbox[max_iou_ind]
                else:
                    cur_proposal_label = 0
                    # Placeholder box for background. It is created with the
                    # proposal's dtype: the former integer placeholder made
                    # the target below an integer tensor, truncating its
                    # values and mixing dtypes in the final stack.
                    cur_proposal_bbox = torch.tensor([1, 2, 3, 4], dtype=cur_proposal.dtype, device = device) #TODO

                labels.append(cur_proposal_label)

                # Encode the assigned box relative to the proposal.
                box = torch.zeros_like(cur_proposal)
                box[0] = (cur_proposal_bbox[0] - cur_proposal[0]) / cur_proposal[2]
                box[1] = (cur_proposal_bbox[1] - cur_proposal[1]) / cur_proposal[3]
                box[2] = torch.log(cur_proposal_bbox[2] / cur_proposal[2])
                box[3] = torch.log(cur_proposal_bbox[3] / cur_proposal[3])

                regressor_target.append(box)

        regressor_target = torch.stack(regressor_target)
        return [int(x) for x in labels], regressor_target



    # This function for each proposal finds the appropriate feature map to sample and using RoIAlign it samples
    # a (256,P,P) feature map. This feature map is then flattened into a (256*P*P) vector
    # Input:
    #      fpn_feat_list: list:len(FPN){(bz,256,H_feat,W_feat)}
    #      proposals: list:len(bz){(per_image_proposals,4)} ([x1,y1,x2,y2] format)
    #      P: scalar
    # Output:
    #      feature_vectors: (total_proposals, 256*P*P)  (make sure the ordering of the proposals are the same as the ground truth creation)
    def MultiScaleRoiAlign(self, fpn_feat_list,proposals,P=7):
        """Pool a (256*P*P) feature vector for every proposal from its FPN level.

        For each proposal an FPN level K in [2, 5] is chosen from the
        proposal's area; the box is rescaled from image pixels (800x1088) to
        that level's feature-map resolution and sampled with RoIAlign. The
        vectors are returned image by image, proposal by proposal, in input
        order.

        Args:
            fpn_feat_list: list of per-level feature maps, indexed by K-2.
            proposals: list (one per image) of (n_proposals, 4) tensors, [x1,y1,x2,y2].
            P: output side length of RoIAlign.

        Returns:
            Tensor with one flattened RoIAlign output per proposal.
        """
        img_height, img_width = 800, 1088
        pooled = []

        for image_idx, image_proposals in enumerate(proposals):
            for roi in image_proposals:
                roi_w = roi[2] - roi[0]
                roi_h = roi[3] - roi[1]
                # FPN level assignment: larger boxes go to coarser levels.
                K = torch.clip(torch.floor(4 + torch.log2(torch.sqrt(roi_w*roi_h)/224)), 2, 5).int()
                level_features = fpn_feat_list[K-2]

                # Ratio between image pixels and feature-map cells at this level.
                scale_x = img_width/level_features.shape[3]
                scale_y = img_height/level_features.shape[2]

                # Bring the box into feature-map coordinates (on a copy).
                box = roi.clone().reshape(1, -1)
                for column, scale in enumerate((scale_x, scale_y, scale_x, scale_y)):
                    box[:, column] = box[:, column] / scale

                feature_map = level_features[image_idx].unsqueeze(0)
                aligned = torchvision.ops.roi_align(feature_map, [box], output_size=P, spatial_scale=1, sampling_ratio=-1)

                pooled.append(torch.flatten(aligned))

        return torch.stack(pooled, dim=0)


    def training_step(self, batch, batch_idx):
        """Run one training batch: box-head losses plus the mask loss.

        Relies on the module-level ``device``, ``backbone``, ``rpn``,
        ``ImageList`` and ``keep_topK``, and on ``compute_loss``,
        ``postprocess_detections``, ``MaskMultiScaleRoiAlign``,
        ``forward_mask`` and ``loss_bce`` of this class.

        Returns:
            dict with "loss", "loss_c", "loss_r" and "loss_m", or None when
            the first image has no detections after post-processing.
        """
        # collate_fn yields six fields per batch; anything after the indices
        # (the full-size masks) is not needed here.
        images, labels, masks, bboxes, indices, *_ = batch
        masks = list(masks)
        images = images.to(device)
        backout = backbone(images)

        # The RPN implementation takes as first argument the following image list
        im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
        # Then we pass the image list and the backbone output through the rpn
        rpnout = rpn(im_lis, backout)
        proposals=[proposal[0:keep_topK,:] for proposal in rpnout[0]]
        # A list of features produces by the backbone's FPN levels: list:len(FPN){(bz,256,H_feat,W_feat)}
        fpn_feat_list= list(backout.values())

        feature_vectors = self.MultiScaleRoiAlign(fpn_feat_list, proposals)

        # Per-proposal labels and regression targets. This call had been
        # commented out, leaving `regressor_targets` undefined and `labels`
        # as the raw per-image ground truth.
        labels, regressor_targets = self.create_ground_truth(proposals, labels, bboxes)

        class_logits, box_preds = self.forward(feature_vectors)

        labels = torch.tensor(labels)
        loss, loss_c, loss_r = self.compute_loss(class_logits, box_preds, labels, regressor_targets,l=10,effective_batch=150)

        # Head outputs are concatenated image by image, so they are split by
        # each image's actual proposal count (an image may have fewer than
        # keep_topK proposals).
        boxes = []
        start = 0
        for proposal in proposals:
            stop = start + proposal.shape[0]
            class_logit = class_logits[start:stop]
            box_pred = box_preds[start:stop]
            score, box, _ = self.postprocess_detections(images.squeeze(0), class_logits = class_logit, box_regression = box_pred, proposals = proposal, conf_thresh = 0.5, keep_num_preNMS=500, keep_num_postNMS=100, plot_=False)
            boxes.append(box)
            start = stop

        if boxes[0] is None:
            # No detections to train the mask branch on.
            return None

        feature_vectors_2 = self.MaskMultiScaleRoiAlign(fpn_feat_list, boxes)

        # NOTE(review): `mask_pred` was used without being defined; it is now
        # computed the same way validation_step does - confirm forward_mask
        # is available on this class.
        mask_pred = self.forward_mask(feature_vectors_2)
        loss_m = self.loss_bce(mask_pred, masks[0])
        loss = loss + loss_m

        self.log("train_class_loss", loss_c, prog_bar=True)
        self.log("train_regr_loss", loss_r, prog_bar=True)
        self.log("train_mask_loss", loss_m, prog_bar=True)
        self.log("train_loss", loss, prog_bar=True)
        print("train_loss", loss)
        print("train_mask_loss", loss_m)
        return {"loss": loss, "loss_c": loss_c, "loss_r": loss_r, "loss_m": loss_m}
    
    
    def training_epoch_end(self, outputs):
        """Append this epoch's mean (total, class, regression, mask) losses to ``self.train_losses``."""
        epoch_means = tuple(
            torch.tensor([step_output[key] for step_output in outputs]).mean().item()
            for key in ("loss", "loss_c", "loss_r", "loss_m")
        )
        self.train_losses.append(epoch_means)

    def validation_step(self, batch, batch_idx):
        """Run one validation batch: box-head losses plus the mask loss.

        Relies on the module-level ``device``, ``backbone``, ``rpn``,
        ``ImageList`` and ``keep_topK``, and on ``compute_loss``,
        ``postprocess_detections``, ``MaskMultiScaleRoiAlign``,
        ``forward_mask`` and ``loss_bce`` of this class.

        Returns:
            dict with "loss", "loss_c", "loss_r" and "loss_m", or None when
            the first image has no detections after post-processing.
        """
        # collate_fn yields six fields per batch; anything after the indices
        # (the full-size masks) is not needed here.
        images, labels, masks, bboxes, indices, *_ = batch
        # collate_fn returns tuples; masks[0] is re-assigned below, which
        # needs a list.
        masks = list(masks)

        images = images.to(device)
        backout = backbone(images)

        # The RPN implementation takes as first argument the following image list
        im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
        # Then we pass the image list and the backbone output through the rpn
        rpnout = rpn(im_lis, backout)
        proposals=[proposal[0:keep_topK,:] for proposal in rpnout[0]]
        # A list of features produces by the backbone's FPN levels: list:len(FPN){(bz,256,H_feat,W_feat)}
        fpn_feat_list= list(backout.values())

        feature_vectors = self.MultiScaleRoiAlign(fpn_feat_list, proposals)

        # Per-proposal labels and regression targets. This call had been
        # commented out, leaving `regressor_targets` undefined and `labels`
        # as the raw per-image ground truth.
        labels, regressor_targets = self.create_ground_truth(proposals, labels, bboxes)

        class_logits, box_preds = self.forward(feature_vectors)

        labels = torch.tensor(labels)
        loss, loss_c, loss_r = self.compute_loss(class_logits, box_preds, labels, regressor_targets,l=10,effective_batch=150)

        # Head outputs are concatenated image by image, so they are split by
        # each image's actual proposal count (an image may have fewer than
        # keep_topK proposals).
        boxes = []
        start = 0
        for proposal in proposals:
            stop = start + proposal.shape[0]
            class_logit = class_logits[start:stop]
            box_pred = box_preds[start:stop]
            score, box, _ = self.postprocess_detections(images.squeeze(0), class_logits = class_logit, box_regression = box_pred, proposals = proposal, conf_thresh = 0.8, keep_num_preNMS=500, keep_num_postNMS=100, plot_=False)
            boxes.append(box)
            start = stop

        if boxes[0] is None:
            # No detections to evaluate the mask branch on.
            return None

        feature_vectors_2 = self.MaskMultiScaleRoiAlign(fpn_feat_list, boxes)

        mask_pred = self.forward_mask(feature_vectors_2)
        # Drop surplus targets when fewer masks were predicted than provided.
        if mask_pred.shape[0] < masks[0].shape[0]:
            masks[0] = masks[0][:mask_pred.shape[0]]
        loss_m = self.loss_bce(mask_pred, masks[0])

        loss = loss + loss_m

        # Logged under val_* keys; these used the train_* keys and so
        # overwrote the training metrics.
        self.log("val_class_loss", loss_c, prog_bar=True)
        self.log("val_regr_loss", loss_r, prog_bar=True)
        self.log("val_mask_loss", loss_m, prog_bar=True)
        self.log("val_loss", loss, prog_bar=True)

        return {"loss": loss, "loss_c": loss_c, "loss_r": loss_r, "loss_m": loss_m}
    
    def validation_epoch_end(self, outputs):
        """Append this epoch's mean (total, class, regression, mask) losses to ``self.validation_losses``."""
        epoch_means = tuple(
            torch.tensor([step_output[key] for step_output in outputs]).mean().item()
            for key in ("loss", "loss_c", "loss_r", "loss_m")
        )
        self.validation_losses.append(epoch_means)
    
    def configure_optimizers(self):
        """SGD (lr 1e-3, weight decay 5e-4) with the learning rate divided by 10 at milestones 20 and 30."""
        sgd = torch.optim.SGD(self.parameters(), lr=0.001, weight_decay = 5e-4)
        step_decay = torch.optim.lr_scheduler.MultiStepLR(sgd, milestones=[20, 30], gamma=0.1)
        return {"optimizer": sgd, 'lr_scheduler': {"scheduler": step_decay}}


    # This function does the post processing for the results of the Box Head for a batch of images
    # Use the proposals to distinguish the outputs from each image
    # Input:
    #       class_logits: (total_proposals,(C+1))
    #       box_regression: (total_proposal,4*C)           ([t_x,t_y,t_w,t_h] format)
    #       proposals: list:len(bz)(per_image_proposals,4) (the proposals are produced from RPN [x1,y1,x2,y2] format)
    #       conf_thresh: scalar
    #       keep_num_preNMS: scalar (number of boxes to keep pre NMS)
    #       keep_num_postNMS: scalar (number of boxes to keep post NMS)
    # Output:
    #       boxes: list:len(bz){(post_NMS_boxes_per_image,4)}  ([x1,y1,x2,y2] format)
    #       scores: list:len(bz){(post_NMS_boxes_per_image)}   ( the score for the top class for the regressed box)
    #       labels: list:len(bz){(post_NMS_boxes_per_image)}   (top class of each regressed box)
    # def postprocess_detections(self, decoded_image, class_logits, box_regression, proposals, conf_thresh=0.5, keep_num_preNMS=500, keep_num_postNMS=50, plot_ = True):

    #   class_scores, class_labels = torch.max(class_logits, dim =1)
    #   valid_indices = class_labels > 0
    #   # print(proposals.shape)
    #   # print(valid_indices.shape)
    #   proposals = proposals[valid_indices]
    #   box_regression = box_regression[valid_indices]
    #   class_labels = class_labels[valid_indices]
    #   class_scores = class_scores[valid_indices]
      

    #   boxes = torch.zeros((len(class_labels), 4))
    #   for i in range(len(class_labels)):
    #     boxes[i,:] = box_regression[i, (class_labels[i]-1)*4:class_labels[i]*4 ]
    #   boxes = flattened_output_decoding(boxes, proposals)

    #   boxes[boxes[:,2] > 1088, 2] = 1088
    #   boxes[boxes[:,3] > 800, 3] = 800
    #   boxes[boxes[:,0] < 0, 0] = 0
    #   boxes[boxes[:,1] < 0, 1] = 0


    #   sorted_scores, sorted_indices = torch.sort(class_scores, descending = True) 

    #   keep_num_preNMS = min(keep_num_preNMS, len(sorted_scores))
    #   pre_nms_indices = sorted_indices[0:keep_num_preNMS]
    #   pre_nms_scores = sorted_scores[0:keep_num_preNMS] 
    #   boxes = boxes[pre_nms_indices,:]
    #   class_labels = class_labels[pre_nms_indices]

      
    #   boxes = boxes.detach().numpy()

    #   if plot_:
    #     label_to_color = {}
    #     label_to_color[1] = 'r'
    #     label_to_color[2] = 'g'
    #     label_to_color[3] = 'b'
    #     gt_label_color = 'c'

    #     label_to_class = {1: 'Vehicle', 2: 'Person', 3: 'Animal'}
    #     label_to_color_name = {1: 'Red', 2: 'Green', 3: 'Blue'}
    #     if keep_num_postNMS > 0:
    #       print('------------------- Predicted Boxes pre-NMS ----------------')
          
    #     figure, axis = plt.subplots(1, 3)
    #     figure.set_size_inches(18.5, 10.5)

    #     axis[0].imshow(decoded_image.permute(1, 2, 0))
    #     axis[0].set_title('Cyan = Ground Truth Box, ' + label_to_color_name[1] + ' = ' +label_to_class[1] + ' box.') 
    #     axis[1].imshow(decoded_image.permute(1, 2, 0))
    #     axis[1].set_title('Cyan = Ground Truth Box, ' + label_to_color_name[2] + ' = ' +label_to_class[2] + ' box.') 
    #     axis[2].imshow(decoded_image.permute(1, 2, 0))
    #     axis[2].set_title('Cyan = Ground Truth Box, ' + label_to_color_name[3] + ' = ' +label_to_class[3] + ' box.') 
        
    #     proposals = boxes
    #     for i, label in enumerate(class_labels):
            
    #         # proposal box
    #         proposal = proposals[i, :]
    #         box = Rectangle((proposal[0], proposal[1]), proposal[2] - proposal[0], proposal[3] - proposal[1],  linewidth=1, edgecolor=label_to_color[label.item()], facecolor='none')
    #         axis[int(label) - 1].add_patch(box)

    #     plt.show()

    #   if keep_num_postNMS > 0:
    #     if pre_nms_scores.shape[0] == 0:
    #       return None, None, None
    #     # print('------------------- Predicted Boxes post-NMS ----------------')
    #     # post_nms_score, post_nms_boxes, post_nms_labels  = self.NMS(pre_nms_scores, torch.tensor(boxes, device = device), torch.tensor(class_labels, device= device), conf_thresh)
    #     post_nms_score, post_nms_boxes, post_nms_labels  = self.py_soft_nms(torch.cat((torch.tensor(boxes, device = device), pre_nms_scores), dim = 1), iou_thr = conf_thresh)
    #     sorted_post_nms_scores, sorted_post_nms_indices = torch.sort(post_nms_score, descending = True) 

    #     keep_num_postNMS = min(keep_num_postNMS, len(post_nms_score))
  
    #     post_nms_indices = sorted_post_nms_indices[0:keep_num_postNMS]
    #     post_nms_scores = sorted_post_nms_scores[0:keep_num_postNMS]
        
    #     post_nms_boxes = post_nms_boxes[post_nms_indices,:]
    #     post_nms_labels = post_nms_labels[post_nms_indices]
        
    #     if plot_:
    #       boxes = post_nms_boxes.cpu().numpy()
          
    #       figure, axis = plt.subplots(1, 3)
    #       figure.set_size_inches(18.5, 10.5)

    #       axis[0].imshow(decoded_image.permute(1, 2, 0))
    #       axis[0].set_title('Cyan = Ground Truth Box, ' + label_to_color_name[1] + ' = ' +label_to_class[1] + ' box.') 
    #       axis[1].imshow(decoded_image.permute(1, 2, 0))
    #       axis[1].set_title('Cyan = Ground Truth Box, ' + label_to_color_name[2] + ' = ' +label_to_class[2] + ' box.') 
    #       axis[2].imshow(decoded_image.permute(1, 2, 0))
    #       axis[2].set_title('Cyan = Ground Truth Box, ' + label_to_color_name[3] + ' = ' +label_to_class[3] + ' box.') 
          
    #       proposals = boxes
     
    #       for i, label in enumerate(post_nms_labels):
 
    #         if label != 0:
              
    #           # proposal box
    #           proposal = proposals[i, :]
    
    #           box = Rectangle((proposal[0], proposal[1]), proposal[2] - proposal[0], proposal[3] - proposal[1],  linewidth=1, edgecolor=label_to_color[label.item()], facecolor='none')
    #           axis[int(label) - 1].add_patch(box)

    #       plt.show()
    #       print('\n')
     

    #     return post_nms_scores, post_nms_boxes, post_nms_labels
    #   else:
    #     return 0,0,0


    def postprocess_detections(self, class_logits, box_regression, proposals, conf_thresh=0.0, keep_num_preNMS=500, keep_num_postNMS=50):
        """Convert raw head outputs into per-image boxes, scores and labels.

        Each proposal is assigned the class with the largest logit; that raw
        logit is used as its score, and proposals whose best class is index 0
        (background) get score 0. The four regression columns belonging to the
        predicted class are decoded against the proposals with
        ``output_decodingd``, boxes leaving the 1088x800 image get score 0, the
        best ``keep_num_preNMS`` boxes are kept, overlapping boxes are
        suppressed and at most ``keep_num_postNMS`` survivors are returned.

        Args:
            class_logits: 2-D tensor, one row per proposal, column 0 = background.
            box_regression: 2-D tensor, one row per proposal; columns
                ``4*(c-1) .. 4*(c-1)+3`` are read for predicted class ``c``.
            proposals: sequence of per-image proposal tensors; the rows of the
                two tensors above are consumed image by image in this order.
            conf_thresh: accepted but not used by this method.
            keep_num_preNMS: maximum number of boxes kept before suppression.
            keep_num_postNMS: maximum number of boxes returned per image.

        Returns:
            ``(boxes, scores, labels)``: three lists with one tensor per image.
        """
        iou_thresh = 0.5

        logits_cpu = class_logits.cpu()
        deltas_cpu = box_regression.cpu()

        class_scores, class_idx = torch.max(logits_cpu, dim=1)
        class_scores[class_idx == 0] = 0

        # Column offset of the predicted class inside the regression output;
        # background rows fall back to the first class (their score is 0 anyway).
        first_col = 4 * torch.clamp(class_idx - 1, min=0)
        col_index = first_col.unsqueeze(1) + torch.arange(4)
        boxes_regr = torch.gather(deltas_cpu, 1, col_index)

        boxes, scores, labels = [], [], []

        start = 0
        for image_proposals in proposals:
            stop = start + image_proposals.shape[0]

            decoded = output_decodingd(boxes_regr[start:stop], image_proposals.cpu(), device='cpu')
            image_scores = class_scores[start:stop]
            image_labels = class_idx[start:stop]

            # Boxes that leave the image are not removed, only pushed to score 0.
            outside = (
                (decoded[:, 0] < 0)
                | (decoded[:, 1] < 0)
                | (decoded[:, 2] > 1088)
                | (decoded[:, 3] > 800)
            )
            image_scores[outside] = 0

            ranked_scores, ranking = torch.sort(image_scores, descending=True)
            pre_nms_scores = ranked_scores[:keep_num_preNMS]
            ranking = ranking[:keep_num_preNMS]
            pre_nms_boxes = decoded[ranking]
            pre_nms_labels = image_labels[ranking]

            # `iou` is presumably a pairwise IoU matrix - TODO confirm. Keeping
            # the strict upper triangle means column k only sees boxes ranked
            # above k, so a box is dropped as soon as any higher-ranked box
            # overlaps it by more than iou_thresh.
            overlaps = iou(pre_nms_boxes.to(self.device), pre_nms_boxes.to(self.device)).triu(diagonal=1)
            survivors = ((overlaps > iou_thresh).sum(dim=0) == 0).cpu()

            boxes.append(pre_nms_boxes[survivors][:keep_num_postNMS])
            scores.append(pre_nms_scores[survivors][:keep_num_postNMS])
            labels.append(pre_nms_labels[survivors][:keep_num_postNMS])

            start = stop

        return boxes, scores, labels

    # Compute the loss of the classifier
    # Input:
    #      p_out:     (positives_on_mini_batch)  (output of the classifier for sampled anchors with positive gt labels)
    #      n_out:     (negatives_on_mini_batch) (output of the classifier for sampled anchors with negative gt labels
    def loss_class(self, p_out, n_out, p_label, n_label):
        """Cross-entropy loss of the classifier over the sampled proposals.

        Args:
            p_out: logits of the sampled positive proposals.
            n_out: logits of the sampled negative proposals.
            p_label: ground-truth class indices matching ``p_out``.
            n_label: ground-truth class indices matching ``n_out``.

        Returns:
            Scalar tensor: mean cross-entropy over positives and negatives.
        """
        output = torch.cat((p_out, n_out), 0)
        # Follow the logits' device rather than a module-level `device` global,
        # so the loss works wherever the model happens to live.
        target = torch.cat((p_label, n_label), 0).to(output.device)

        return torch.nn.functional.cross_entropy(output, target)

    # Compute the loss of the regressor
    # Input:
    #       pos_target_coord: (positive_on_mini_batch,4) (ground truth of the regressor for sampled anchors with positive gt labels)
    #       pos_out_r: (positive_on_mini_batch,4)        (output of the regressor for sampled anchors with positive gt labels)
    def loss_reg(self,pos_target_coord,pos_out_r):
        """Summed smooth-L1 loss between regressor outputs and their targets.

        Args:
            pos_target_coord: regression targets of the sampled positives.
            pos_out_r: regressor outputs of the same positives; a singleton
                dimension 1, if present, is squeezed away before comparison.

        Returns:
            Scalar tensor: smooth-L1 loss summed over every element.
        """
        prediction = pos_out_r.squeeze(1)
        return torch.nn.functional.smooth_l1_loss(prediction, pos_target_coord, reduction='sum')

    def loss_bce(self, mask_pred, masks):
        """Mean binary cross-entropy between predicted probabilities and masks.

        Args:
            mask_pred: predicted probabilities (values in [0, 1]).
            masks: target tensor of the same shape as ``mask_pred``.

        Returns:
            Scalar tensor: binary cross-entropy averaged over all elements.
        """
        return nn.functional.binary_cross_entropy(mask_pred, masks)


    # Compute the total loss of the classifier and the regressor
    # Input:
    #      class_logits: (total_proposals,(C+1)) (as output by forward, not passed through softmax so we can use CrossEntropyLoss)
    #      box_preds: (total_proposals,4*C)      (as output by forward)
    #      labels: (total_proposals,1)
    #      regression_targets: (total_proposals,4)
    #      l: scalar (weighting of the two losses)
    #      effective_batch: scalar
    # Outputs:
    #      loss: scalar
    #      loss_class: scalar
    #      loss_regr: scalar
    def compute_loss(self,class_logits, box_preds, labels, regression_targets,l=1,effective_batch=150):
        """Sample a mini-batch of proposals and compute the box-head loss.

        Up to ``3*effective_batch//4`` positives (label != 0) are drawn at
        random with ``np.random.choice``; the rest of the mini-batch is filled
        with random negatives (label == 0), as far as enough exist. The
        classification loss is taken over all sampled proposals, the
        regression loss over the sampled positives only.

        Args:
            class_logits: (total_proposals, C+1) raw classifier outputs.
            box_preds: (total_proposals, 4*C) regressor outputs, four columns
                per foreground class; a class-agnostic (total_proposals, 4)
                tensor is also accepted.
            labels: ground-truth class per proposal, 0 = background; any shape
                that flattens to (total_proposals,).
            regression_targets: (total_proposals, 4) regression targets.
            l: weight of the regression term in the total loss.
            effective_batch: size of the sampled mini-batch; also the
                normaliser of the regression loss.

        Returns:
            ``(loss, loss_c, loss_r)`` with ``loss = loss_c + l * loss_r``.
        """
        # Boolean-mask row selection below needs a 1-D label vector.
        labels = labels.reshape(-1)
        pos_mask = labels != 0
        neg_mask = labels == 0

        pos_n = int(pos_mask.sum().item())
        neg_n = labels.shape[0] - pos_n

        if pos_n < (3*effective_batch/4):
            # Few positives: take them all and fill up with negatives.
            pos_idx = np.random.choice(pos_n, pos_n, replace=False)
            neg_idx = np.random.choice(neg_n, min(effective_batch - pos_n, neg_n), replace=False)
            if effective_batch - pos_n > neg_n:
                print('!!!ALERT!!! \n effective_batch - pos_n:', effective_batch - pos_n, ' neg_n:', neg_n, ' labels.shape[0]:', labels.shape[0])
        else:
            # Enough positives: cap them at 3/4 of the mini-batch.
            pos_idx = np.random.choice(pos_n, (3*effective_batch)//4, replace=False)
            neg_idx = np.random.choice(neg_n, min(effective_batch//4, neg_n), replace=False)
            if effective_batch//4 > neg_n:
                print('!!!ALERT!!! \n effective_batch//4:', effective_batch//4, ' neg_n:', neg_n, ' labels.shape[0]:', labels.shape[0])

        p_out = class_logits[pos_mask][pos_idx]
        n_out = class_logits[neg_mask][neg_idx]
        p_label = labels[pos_mask][pos_idx]
        n_label = labels[neg_mask][neg_idx]
        loss_c = self.loss_class(p_out, n_out, p_label, n_label)

        regr_pos = box_preds[pos_mask][pos_idx]
        if regr_pos.shape[1] > 4:
            # Train the four columns that belong to each positive's own class
            # (class c occupies columns 4*(c-1) .. 4*(c-1)+3), i.e. the same
            # columns postprocess_detections reads at inference time. Always
            # using columns 0:4 would only ever train the first class' slots.
            first_col = 4 * (p_label.to(regr_pos.device).long() - 1)
            cols = first_col.unsqueeze(1) + torch.arange(4, device=regr_pos.device)
            box_pred = regr_pos.gather(1, cols)
        else:
            box_pred = regr_pos[:, :4]

        targ_regr_pos = regression_targets[pos_mask][pos_idx]
        loss_r = self.loss_reg(targ_regr_pos, box_pred)/effective_batch

        loss = loss_c + l*loss_r

        return loss, loss_c, loss_r




    def NMS(self,clas,prebox, labels, thresh):
        """Suppress boxes that overlap a higher-scoring box.

        Boxes are ranked by score and a box is dropped as soon as ANY
        higher-ranked box overlaps it with IoU > 0.5 (the same one-shot matrix
        rule used in ``postprocess_detections``; unlike greedy NMS, a box that
        was itself suppressed can still suppress others). Box extents use the
        inclusive-pixel convention (``x2 - x1 + 1``).

        Args:
            clas: (N,) scores, one per box.
            prebox: (N, 4) boxes in ``[x1, y1, x2, y2]`` format.
            labels: (N,) class labels, one per box.
            thresh: accepted for interface compatibility; the IoU threshold is
                fixed at 0.5 and this argument is not used.

        Returns:
            ``(scores, boxes, labels)`` of the surviving boxes, ordered by
            decreasing score.
        """
        iou_thresh = 0.5

        clas = clas.reshape(-1)
        if clas.numel() == 0:
            return clas, prebox, labels

        # Rank by score so that row index == rank (0 is the best box).
        order = torch.argsort(clas, descending=True)
        scores = clas[order]
        boxes = prebox[order]
        ranked_labels = labels[order]

        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)

        # Pairwise intersection: entry [r, c] compares box r with box c.
        inter_w = (torch.min(x2[:, None], x2[None, :]) - torch.max(x1[:, None], x1[None, :]) + 1).clamp(min=0)
        inter_h = (torch.min(y2[:, None], y2[None, :]) - torch.max(y1[:, None], y1[None, :]) + 1).clamp(min=0)
        intersection = inter_w * inter_h
        union = areas[:, None] + areas[None, :] - intersection

        # Strict upper triangle: column c only keeps overlaps with boxes
        # ranked above c.
        ious = (intersection / union).triu(diagonal=1)
        keep = (ious > iou_thresh).sum(dim=0) == 0

        return scores[keep], boxes[keep], ranked_labels[keep]


    def py_soft_nms(dets, method=None, iou_thr=0.5, sigma=0.5, score_thr=0.001):
        """Pure python implementation of soft NMS as described in the paper
        `Improving Object Detection With One Line of Code`_.
        Args:
            dets (numpy.array): Detection results with shape `(num, 5)`,
                data in second dimension are [x1, y1, x2, y2, score] respectively.
            method (str): Rescore method. Only can be `linear`, `gaussian`
                or 'greedy'.
            iou_thr (float): IOU threshold. Only work when method is `linear`
                or 'greedy'.
            sigma (float): Gaussian function parameter. Only work when method
                is `gaussian`.
            score_thr (float): Boxes that score less than the.
        Returns:
            numpy.array: Retained boxes.
        .. _`Improving Object Detection With One Line of Code`:
            https://arxiv.org/abs/1704.04503
        """
        if method not in ('linear', 'gaussian', 'greedy'):
            raise ValueError('method must be linear, gaussian or greedy')

        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        # expand dets with areas, and the second dimension is
        # x1, y1, x2, y2, score, area
        dets = torch.cat((dets, areas[:, None]), dim=1)

        retained_box = []
        while dets.size > 0:
            max_idx = torch.argmax(dets[:, 4], axis=0)
            dets[[0, max_idx], :] = dets[[max_idx, 0], :]
            retained_box.append(dets[0, :-1])

            xx1 = torch.maximum(dets[0, 0], dets[1:, 0])
            yy1 = torch.maximum(dets[0, 1], dets[1:, 1])
            xx2 = torch.minimum(dets[0, 2], dets[1:, 2])
            yy2 = torch.minimum(dets[0, 3], dets[1:, 3])

            w = torch.maximum(xx2 - xx1 + 1, 0.0)
            h = torch.maximum(yy2 - yy1 + 1, 0.0)
            inter = w * h
            iou = inter / (dets[0, 5] + dets[1:, 5] - inter)

            if method == 'linear':
                weight = torch.ones_like(iou)
                weight[iou > iou_thr] -= iou[iou > iou_thr]
            elif method == 'gaussian':
                weight = np.exp(-(iou * iou) / sigma)
            else:  # traditional nms
                weight = np.ones_like(iou)
                weight[iou > iou_thr] = 0

            dets[1:, 4] *= weight
            retained_idx = torch.where(dets[1:, 4] >= score_thr)[0]
            dets = dets[retained_idx + 1, :]

        return torch.vstack(retained_box)

    def MaskMultiScaleRoiAlign(self, fpn_feat_list,proposals,P=14):
        """RoI-align every proposal on the FPN level chosen from its size.

        For a proposal of width w and height h the level index is
        ``clip(floor(4 + log2(sqrt(w*h)/224)), 2, 5)`` and the feature map
        ``fpn_feat_list[level - 2]`` is sampled. The box is first rescaled from
        image coordinates (800 x 1088, h x w) to that map's resolution.

        Args:
            fpn_feat_list: list of feature maps indexed ``[image, channel, y, x]``.
            proposals: sequence with one ``(n, 4)`` tensor of ``[x1, y1, x2, y2]``
                boxes per image.
            P: side length of the aligned output.

        Returns:
            Tensor stacking, per image, the aligned maps of all its proposals
            (this requires every image to contribute the same number of
            proposals).
        """
        img_h, img_w = 800, 1088
        per_image = []

        for img_idx, rois in enumerate(proposals):
            aligned = []
            for roi in rois:
                roi_w = roi[2] - roi[0]
                roi_h = roi[3] - roi[1]
                K = torch.clamp(torch.floor(4 + torch.log2(torch.sqrt(roi_w*roi_h)/224)), 2, 5).int()
                level_feats = fpn_feat_list[K-2]

                # Image -> feature-map coordinates for the selected level.
                scale_x = img_w/level_feats.shape[3]
                scale_y = img_h/level_feats.shape[2]
                box = roi.clone().reshape(1, -1)
                box[:, [0, 2]] = box[:, [0, 2]] / scale_x
                box[:, [1, 3]] = box[:, [1, 3]] / scale_y

                aligned.append(
                    torchvision.ops.roi_align(
                        level_feats[img_idx].unsqueeze(0),
                        [box],
                        output_size=P,
                        spatial_scale=1,
                        sampling_ratio=-1,
                    )
                )

            per_image.append(torch.cat(aligned, dim=0))

        return torch.stack(per_image, dim=0)

    def postprocess_detections_mask(self, decoded_image, class_logits, box_regression, proposals, conf_thresh=0.8, keep_num_preNMS=100, keep_num_postNMS=50, plot_ = False):
        """Decode, clip and NMS-filter the foreground detections of one image.

        Proposals whose arg-max class is 0 (background) are discarded. For the
        remaining ones the four regression columns of the predicted class are
        decoded with ``flattened_output_decoding``, clipped to the 1088x800
        image, ranked by score, cut to ``keep_num_preNMS`` and filtered with
        ``torchvision.ops.nms`` (IoU threshold 0.5).

        Args:
            decoded_image: accepted but not used by this method.
            class_logits: 2-D tensor, one row per proposal, column 0 = background.
            box_regression: 2-D tensor, one row per proposal, four columns per
                foreground class.
            proposals: tensor with one row per proposal.
            conf_thresh: accepted but not used by this method.
            keep_num_preNMS: maximum number of boxes passed to NMS.
            keep_num_postNMS: accepted but not used by this method.
            plot_: accepted but not used by this method.

        Returns:
            ``(boxes, scores, labels)`` of the boxes that survive NMS; entry k
            of each output refers to the same detection.
        """
        class_scores, class_labels = torch.max(class_logits, dim=1)

        valid_indices = class_labels > 0

        proposals = proposals[valid_indices]
        box_regression = box_regression[valid_indices]
        class_labels = class_labels[valid_indices]
        class_scores = class_scores[valid_indices]

        # Pick the four regression values of each proposal's predicted class.
        boxes = torch.zeros((len(class_labels), 4))
        for i, label in enumerate(class_labels):
            boxes[i, :] = box_regression[i, (label - 1) * 4:label * 4]
        boxes = flattened_output_decoding(boxes, proposals)

        # Clip to the image.
        boxes[boxes[:, 2] > 1088, 2] = 1088
        boxes[boxes[:, 3] > 800, 3] = 800
        boxes[boxes[:, 0] < 0, 0] = 0
        boxes[boxes[:, 1] < 0, 1] = 0

        sorted_scores, sorted_indices = torch.sort(class_scores, descending=True)

        keep_num_preNMS = min(keep_num_preNMS, len(sorted_scores))
        pre_nms_indices = sorted_indices[0:keep_num_preNMS]
        pre_nms_scores = sorted_scores[0:keep_num_preNMS]
        boxes = boxes[pre_nms_indices, :]
        class_labels = class_labels[pre_nms_indices]
        boxes = boxes.to(device)

        inds = torchvision.ops.nms(boxes, pre_nms_scores, 0.5)

        # Apply the NMS selection to all three outputs; filtering only the
        # boxes would leave scores/labels pointing at different detections.
        boxes = boxes[inds]
        post_nms_scores = pre_nms_scores[inds.to(pre_nms_scores.device)]
        post_nms_labels = class_labels[inds.to(class_labels.device)]

        return boxes, post_nms_scores, post_nms_labels

###Box Head Training

In [None]:
# Train the box head with PyTorch Lightning, then persist its weights and the
# loss histories it recorded.
# NOTE(review): `device`, `num_epochs` and `COCODataLoader` are not assigned in
# this cell; in this part of the notebook they are only assigned in a later
# cell, so this relies on that cell (or an earlier one) having been run.
model = BoxHead(Classes=3,P=7)
model.to(device)
keep_topK=500
# One GPU, no sanity-check validation pass before training starts.
trainer = pl.Trainer(max_epochs=num_epochs, gpus=1, num_sanity_val_steps=0)
trainer.fit(model, COCODataLoader)
# Saved relative to the current working directory.
torch.save(model.state_dict(), 'boxhead_new.pth')
np.save('train_losses.npy', model.train_losses)
np.save('val_losses.npy', model.validation_losses)

In [None]:
!cp -r boxhead_new.pth /content/drive/MyDrive/Mask/

###Caching rpn and boxhead


In [None]:
# Put the path where you saved the given pretrained model
pretrained_path='/content/drive/MyDrive/Pretrained_Models/checkpoint680.pth'
# pretrained_path='/content/checkpoint680.pth'
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Backbone and RPN come from the checkpoint and are only used for inference.
backbone, rpn = pretrained_models_680(pretrained_path)
backbone = backbone.to(device).eval()
rpn = rpn.to(device).eval()

# we will need the ImageList from torchvision
from torchvision.models.detection.image_list import ImageList


# imgs_path = '/content/drive/MyDrive/Mask/COCOdataset2017/NewData5/finalTrainImages.h5'
# masks_path = '/content/drive/MyDrive/Mask/COCOdataset2017/NewData5/finalTrainMasks.h5'
# labels_path = '/content/drive/MyDrive/Mask/COCOdataset2017/NewData5/finalTrainLabels.npy'
# bboxes_path = '/content/drive/MyDrive/Mask/COCOdataset2017/NewData5/finalTrainBboxes.npy'

# Dataset files: images and masks as HDF5, labels and boxes as .npy.
imgs_path = '/content/drive/MyDrive/Mask/data_rusk/hw3_mycocodata_img_comp_zlib.h5'
masks_path = '/content/drive/MyDrive/Mask/data_rusk/hw3_mycocodata_mask_comp_zlib.h5'
labels_path = '/content/drive/MyDrive/Mask/data_rusk/hw3_mycocodata_labels_comp_zlib.npy'
bboxes_path = '/content/drive/MyDrive/Mask/data_rusk/hw3_mycocodata_bboxes_comp_zlib.npy'

# BuildDataset receives the four paths in exactly this order.
paths = [imgs_path, masks_path, labels_path, bboxes_path]

batch_size = 1
num_epochs = 36

# Resize to 800 x 1066 and pad 11 px on the left and right, which yields the
# 800 x 1088 (h x w) image size used by the heads (1066 + 2*11 = 1088).
image_transform = transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Resize((800, 1066)),
                            transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
                            transforms.Pad([11,0])
                            ])

# Masks get the same geometry as the images, but no tensor conversion or
# normalisation is applied here.
mask_transform = transforms.Compose([       
                          transforms.Resize((800, 1066)),
                          transforms.Pad([11,0])
                          ])

COCOdata = BuildDataset(paths, image_transform, mask_transform)

COCODataLoader = BuildDataLoader(COCOdata, batch_size=batch_size)

# setup() is called by hand here so the loaders can be fetched without a Trainer.
COCODataLoader.setup()
test_loader = COCODataLoader.predict_dataloader() 
train_loader = COCODataLoader.train_dataloader()
val_loader = COCODataLoader.val_dataloader()

  label_based_masks = np.array(label_based_masks)


In [None]:
boxhead_weights_path = '/content/boxhead.pth'

# Rebuild the box head in its evaluation configuration and restore the trained
# weights.
boxhead = BoxHead(eval_ = True)
boxhead.to(device)
# map_location keeps the checkpoint loadable when it was written from a GPU but
# `device` resolved to the CPU fallback on this machine.
box_weights = torch.load(boxhead_weights_path, map_location=device)
boxhead.load_state_dict(box_weights)
boxhead.eval()

BoxHead(
  (intermediate_layer): Sequential(
    (0): Linear(in_features=12544, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=1024, bias=True)
    (3): ReLU()
  )
  (classifier_head): Sequential(
    (0): Linear(in_features=1024, out_features=4, bias=True)
  )
  (regressor_head): Sequential(
    (0): Linear(in_features=1024, out_features=12, bias=True)
  )
)

In [None]:
# Run backbone -> RPN -> box head over the prediction split and cache, per
# image, the class logits, box regressions and proposals on disk.
keep_topK = 500
confidence_scores_path = '/content/scores/'
box_preds_path = '/content/boxes/'
proposal_path = '/content/proposals/'

# torch.save does not create missing directories.
for cache_dir in (confidence_scores_path, box_preds_path, proposal_path):
    os.makedirs(cache_dir, exist_ok=True)

# Pure inference: no autograd graph is needed for tensors that are only saved.
with torch.no_grad():
    for j, batch in enumerate(test_loader):

        images, labels, masks, bboxes, indices = batch
        images = images.to(device)
        backout = backbone(images)

        # The RPN implementation takes as first argument the following image list
        im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
        # Then we pass the image list and the backbone output through the rpn
        rpnout = rpn(im_lis, backout)
        proposals = [proposal[0:keep_topK, :] for proposal in rpnout[0]]

        # A list of features produced by the backbone's FPN levels: list:len(FPN){(bz,256,H_feat,W_feat)}
        fpn_feat_list = list(backout.values())

        feature_vectors = boxhead.MultiScaleRoiAlign(fpn_feat_list, proposals)

        # NOTE(review): the ground truth computed here is not used below; the
        # call is kept only to leave the original sequence of calls untouched.
        labels, regressor_targets = boxhead.create_ground_truth(proposals, labels, bboxes)

        class_logits, box_preds = boxhead.forward(feature_vectors)

        # The head outputs are concatenated over the images of the batch.
        # Slice them with the actual number of proposals of each image rather
        # than a fixed 500, so images with fewer than keep_topK proposals do
        # not shift the rows of the images that follow.
        start = 0
        for i, idx in enumerate(indices):
            end = start + proposals[i].shape[0]
            torch.save(class_logits[start:end, :], confidence_scores_path + 'cls_logit_' + str(idx) + '.pt')
            torch.save(box_preds[start:end, :], box_preds_path + 'box_pred_' + str(idx) + '.pt')
            torch.save(proposals[i], proposal_path + 'prop_pred_' + str(idx) + '.pt')
            start = end





### MaskHead Utils

In [None]:
!unzip /content/scores.zip
!unzip /content/boxes.zip
!unzip /content/proposals.zip

Archive:  /content/scores.zip
   creating: content/scores/
  inflating: content/scores/box_pred_1697.pt  
  inflating: content/scores/box_pred_40.pt  
  inflating: content/scores/box_pred_2739.pt  
  inflating: content/scores/box_pred_1197.pt  
  inflating: content/scores/box_pred_2946.pt  
  inflating: content/scores/box_pred_731.pt  
  inflating: content/scores/box_pred_360.pt  
  inflating: content/scores/box_pred_2816.pt  
  inflating: content/scores/box_pred_2149.pt  
  inflating: content/scores/box_pred_2547.pt  
  inflating: content/scores/box_pred_1950.pt  
  inflating: content/scores/box_pred_2831.pt  
  inflating: content/scores/box_pred_1057.pt  
  inflating: content/scores/box_pred_1363.pt  
  inflating: content/scores/box_pred_2788.pt  
  inflating: content/scores/box_pred_3006.pt  
  inflating: content/scores/box_pred_2423.pt  
  inflating: content/scores/box_pred_113.pt  
  inflating: content/scores/box_pred_1498.pt  
  inflating: content/scores/box_pred_1853.pt  
  infl

In [None]:
batch_size = 16


def loss_bce(mask_pred, masks):
    """Return the mean binary cross-entropy between predictions and targets.

    Args:
        mask_pred: predicted probabilities (values in [0, 1]).
        masks: target tensor of the same shape as ``mask_pred``.
    """
    return nn.functional.binary_cross_entropy(mask_pred, masks)

In [None]:
# Directories holding the cached box-head outputs (class logits, box
# regressions, proposals) that are read back when working on the mask head.
# NOTE(review): in the commented-out Drive variant below, the scores variable
# points at 'boxes/' and the boxes variable at 'scores/' - confirm which
# directory holds what before re-enabling it.
# confidence_scores_path = '/content/drive/MyDrive/Mask/boxes/'
# box_preds_path = '/content/drive/MyDrive/Mask/scores/'
# proposal_path = '/content/drive/MyDrive/Mask/proposals/'
confidence_scores_path = '/content/content/scores/'
box_preds_path = '/content/content/boxes/'
proposal_path = '/content/content/proposals/'

###Focal Loss (Extra)

In [None]:
class FocalLoss(nn.Module):
    """Focal loss: cross-entropy down-weighted by ``(1 - p_t) ** gamma``.

    Args:
        gamma: focusing exponent; ``0`` reduces the loss to (alpha-weighted)
            cross-entropy.
        alpha: optional class weights. A float/int ``a`` is expanded to the
            two-class weights ``[a, 1 - a]``; a list gives one weight per class.
        size_average: return the mean over all elements if true, else the sum.
    """

    def __init__(self, gamma=3.0, alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: raw logits, either ``(N, C)`` or ``(N, C, ...)`` with the
                classes in dimension 1.
            target: class indices. A 4-D ``(N, 1, H, W)`` target is flattened
                in the same pixel order as the logits; any other shape is
                flattened to one class index per row.

        Returns:
            Scalar tensor: mean or sum of the per-element focal loss.
        """
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))   # N,H*W,C => N*H*W,C

        if target.dim() == 4:
            # N,1,H,W => N,H,W,1 => N*H*W,1 (same ordering as the logits above)
            target = target.permute((0, 2, 3, 1))
            target = target.reshape(-1, target.shape[3])
        else:
            # Plain class-index targets, e.g. (N,) for (N, C) logits.
            target = target.reshape(-1, 1)
        target = target.type(torch.int64)

        # Log-probability of the target class of every element.
        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target).view(-1)
        # The modulating factor is treated as a constant: no gradient flows
        # through pt.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.reshape(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()

In [None]:
# Focal loss with its default settings (gamma=3.0, no alpha weighting, mean reduction).
criterion = FocalLoss()

###MaskHead Definition

In [None]:
class MaskHead(pl.LightningModule):
  """Mask branch predicting one mask per class for every RoI.

  Four (3x3 conv + BatchNorm + ReLU) stages, a stride-2 transposed
  convolution and a 1x1 convolution followed by a sigmoid turn a
  (256, P, P) RoI feature into a (3, 2P, 2P) map of probabilities.

  The training/validation steps rely on notebook-level globals defined in
  other cells: ``device``, ``backbone``, ``boxhead``, ``loss_bce``,
  ``batch_size`` and the ``proposal_path`` / ``confidence_scores_path`` /
  ``box_preds_path`` directories of cached per-image tensors.
  """

  def __init__(self):
    super(MaskHead, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256))
    self.conv2 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256))
    self.conv3 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256))
    self.conv4 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256))
    self.deconv = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
    self.conv6 = nn.Conv2d(256, 3, kernel_size=1)
    self.sigmoid = nn.Sigmoid()
    self.relu = nn.ReLU()

    # Per-epoch mean losses, appended by the *_epoch_end hooks.
    self.train_loss = []
    self.val_loss = []

  def forward(self, x):
    """Map RoI features (n, 256, P, P) to mask probabilities (n, 3, 2P, 2P)."""
    for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.deconv):
      x = self.relu(stage(x))
    # Sigmoid (not softmax): each class channel is scored independently,
    # which is what the BCE loss used in the steps expects.
    return self.sigmoid(self.conv6(x))

  # For each proposal this picks an FPN level from the box size and samples a
  # (256, P, P) feature with RoIAlign.
  # Input:
  #      fpn_feat_list: list:len(FPN){(bz,256,H_feat,W_feat)}
  #      proposals: list:len(bz){(per_image_proposals,4)} ([x1,y1,x2,y2] format)
  #      P: scalar
  # Output:
  #      feature_vectors: (len(proposals), per_image_proposals, 256, P, P)
  #      Entry i of `proposals` is sampled from image i of every FPN level, and
  #      every image must contribute the same number of boxes because the
  #      per-image results are concatenated along a new leading axis.
  def MultiScaleRoiAlign(self, fpn_feat_list,proposals,P=14):
    image_size = (800, 1088) # (h, w)
    feature_vectors = []

    for i, cur_proposal in enumerate(proposals):
      per_box_feats = []
      for j in range(cur_proposal.shape[0]):
        width = cur_proposal[j][2] - cur_proposal[j][0]
        height = cur_proposal[j][3] - cur_proposal[j][1]
        # FPN level K in [2, 5]; level K lives at list index K-2.
        K = torch.clip(torch.floor( 4 + torch.log2(torch.sqrt(width*height)/224)),2,5).int()
        level_feats = fpn_feat_list[K-2]

        # Rescale the box from image pixels to this level's feature grid,
        # so roi_align can run with spatial_scale=1.
        scale_x = image_size[1]/level_feats.shape[3]
        scale_y = image_size[0]/level_feats.shape[2]
        box = cur_proposal[j].clone().reshape(1,-1)
        box[:,0] = box[:,0] / scale_x
        box[:,1] = box[:,1] / scale_y
        box[:,2] = box[:,2] / scale_x
        box[:,3] = box[:,3] / scale_y

        feature_map = level_feats[i].unsqueeze(0)
        aligned_box_map = torchvision.ops.roi_align(feature_map, [box], output_size=P, spatial_scale=1, sampling_ratio=-1)
        per_box_feats.append(aligned_box_map)

      feature_vectors.append(torch.cat(per_box_feats, dim=0).unsqueeze(0))

    return torch.cat(feature_vectors, dim=0)

  def _batch_mask_loss(self, batch):
    """Sum the BCE mask loss over the images of a batch, divided by batch_size.

    For every image the cached proposals, class logits and box predictions
    are loaded from disk, turned into detections by the box head, and the
    mask channel matching each detection's class is compared with the
    corresponding target mask. Background detections (label 0) are ignored.
    """
    images, _gt_labels, masks, _bboxes, indices = batch
    images = images.to(device)
    backout = backbone(images)
    masks = list(masks)
    # A list of features produced by the backbone's FPN levels: list:len(FPN){(bz,256,H_feat,W_feat)}
    fpn_feat_list = list(backout.values())

    running_loss = 0
    for idx_num, idx in enumerate(indices):
      proposals = [torch.load(proposal_path + 'prop_pred_' + str(idx) + '.pt')]
      class_logits = torch.load(confidence_scores_path + 'cls_logit_' + str(idx) + '.pt')
      box_preds = torch.load(box_preds_path + 'box_pred_' + str(idx) + '.pt')
      box_preds = box_preds.to(device)

      boxes, scores, det_labels = boxhead.postprocess_detections(class_logits, box_preds, proposals, keep_num_postNMS=50)
      boxes = [box.to(device) for box in boxes]
      if not boxes or boxes[0].shape[0] == 0:
        # No detections for this image: nothing to align or score.
        continue

      # `boxes` has a single entry, so MultiScaleRoiAlign reads image index 0
      # of the features it is given. Hand it only this image's feature maps;
      # passing the whole batch made every image sample from batch image 0.
      image_feats = [feat[idx_num:idx_num + 1] for feat in fpn_feat_list]
      roi_feats = self.MultiScaleRoiAlign(image_feats, boxes)
      mask_pred = self.forward(roi_feats.squeeze(0))

      # Pick, for detection r, channel (label_r - 1) of both prediction and
      # target. Indexing rows 0..n-1 also limits the targets to the number
      # of detections when fewer boxes survive NMS than targets exist.
      ind = torch.tensor(np.indices(det_labels[0].unsqueeze(1).shape))
      ind[-1] = det_labels[0].unsqueeze(1) - 1
      mask_pred = mask_pred[tuple(ind)]
      masks_ = masks[idx_num][tuple(ind)]

      # Drop background detections (label 0 selected channel -1 above).
      idx12 = det_labels[0].unsqueeze(1) > 0
      mask_pred = mask_pred[idx12]
      masks_ = masks_[idx12]

      if mask_pred.shape[0] != 0:
        running_loss += loss_bce(mask_pred, masks_)

    # NOTE(review): divides by the global batch_size, not len(indices); a
    # smaller final batch is therefore weighted down — confirm this is intended.
    return running_loss / batch_size

  def training_step(self, batch, batch_idx):
    return {"loss": self._batch_mask_loss(batch)}

  def training_epoch_end(self, outputs):
    self.train_loss.append((torch.tensor([output['loss'] for output in outputs]).mean().item()))

  def validation_step(self, batch, batch_idx):
    return {"loss": self._batch_mask_loss(batch)}

  def validation_epoch_end(self, outputs):
    self.val_loss.append((torch.tensor([output['loss'] for output in outputs]).mean().item()))

  def configure_optimizers(self):
    # Adam with step decay (x0.1) at epochs 11, 18, 25 and 30.
    optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, weight_decay=0.0001)
    scheduler = {"scheduler": torch.optim.lr_scheduler.MultiStepLR(optimizer, [11, 18, 25, 30], gamma=0.1)}
    return {"optimizer": optimizer, 'lr_scheduler': scheduler}

  def postprocess_detections(self, decoded_image, class_logits, box_regression, proposals, conf_thresh=0.8, keep_num_preNMS=100, keep_num_postNMS=50, plot_ = False):
    """Turn box-head outputs for one image into NMS-filtered detections.

    Args:
        decoded_image, conf_thresh, plot_: accepted for call compatibility;
            not used.
        class_logits: (num_proposals, num_classes) scores, column 0 = background.
        box_regression: (num_proposals, 4 * num_foreground_classes) encoded boxes.
        proposals: (num_proposals, 4) proposal boxes.
        keep_num_preNMS: number of top-scoring boxes passed to NMS.
        keep_num_postNMS: maximum number of detections returned.

    Returns:
        (boxes, scores, labels) for the kept detections, row-aligned and
        ordered as returned by torchvision.ops.nms.
    """
    class_scores, class_labels = torch.max(class_logits, dim =1)

    # Keep only proposals whose best class is not background.
    valid_indices = class_labels > 0
    proposals = proposals[valid_indices]
    box_regression = box_regression[valid_indices]
    class_labels = class_labels[valid_indices]
    class_scores = class_scores[valid_indices]

    # Select the 4 regression values that belong to the predicted class.
    boxes = torch.zeros((len(class_labels), 4))
    for i in range(len(class_labels)):
      boxes[i,:] = box_regression[i, (class_labels[i]-1)*4:class_labels[i]*4 ]
    boxes = flattened_output_decoding(boxes, proposals)

    # Clip to the 1088 x 800 image.
    boxes[boxes[:,2] > 1088, 2] = 1088
    boxes[boxes[:,3] > 800, 3] = 800
    boxes[boxes[:,0] < 0, 0] = 0
    boxes[boxes[:,1] < 0, 1] = 0

    # (A pasted per-class NMS block that read names not defined in this
    # method and whose result was never used has been removed here.)

    sorted_scores, sorted_indices = torch.sort(class_scores, descending = True)

    keep_num_preNMS = min(keep_num_preNMS, len(sorted_scores))
    pre_nms_indices = sorted_indices[0:keep_num_preNMS]
    pre_nms_scores = sorted_scores[0:keep_num_preNMS]
    boxes = boxes[pre_nms_indices,:]
    class_labels = class_labels[pre_nms_indices]
    boxes = boxes.to(device)

    inds = torchvision.ops.nms(boxes, pre_nms_scores, 0.5)
    inds = inds[:keep_num_postNMS]

    # Index scores and labels with the same NMS result as the boxes so the
    # three outputs stay row-aligned.
    return boxes[inds], pre_nms_scores[inds], class_labels[inds]



### Loading any trained model

In [None]:
# Build a MaskHead, restore previously saved weights and switch to eval mode.
model = MaskHead() 
# model = MaskHead.load_from_checkpoint('/content/epoch=21-step=53878.ckpt')
model.to(device)
# Number of RPN proposals kept per image by the evaluation cell further down.
keep_topK=500
# Despite the name, `box_weights` holds the state_dict loaded into this MaskHead.
box_weights = torch.load('/content/maskhead_bce_n_.pth')
model.load_state_dict(box_weights)
model.eval()

MaskHead(
  (conv1): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv2): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv4): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (deconv): ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
  (conv6): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
  (sigmoid): Sigmoid()
  (relu): ReLU()
)

### MaskHead Training

In [None]:
# Train a freshly initialised MaskHead for 30 epochs and save its weights.
model = MaskHead()
model.to(device)
keep_topK=500
# num_sanity_val_steps=0 skips Lightning's pre-training validation pass.
# NOTE(review): `gpus=1` is presumably what triggers the deprecation notice in
# this cell's output — confirm against the installed Lightning version.
trainer = pl.Trainer(max_epochs=30, gpus=1, num_sanity_val_steps=0, default_root_dir="/content/ckpts/")
# COCODataLoader is defined in another cell.
trainer.fit(model, COCODataLoader)
torch.save(model.state_dict(), 'maskhead_bce_n_.pth')

  rank_zero_deprecation(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type            | Params
--------------------------------------------
0 | conv1   | Sequential      | 590 K 
1 | conv2   | Sequential      | 590 K 
2 | conv3   | Sequential      | 590 K 
3 | conv4   | Sequential      | 590 K 
4 | deconv  | ConvTranspose2d | 262 K 
5 | conv6   | Conv2d          | 771   
6 | sigmoid | Sigmoid         | 0     
7 | relu    | ReLU            | 0     
--------------------------------------------
2.6 M     Trainable params
0         Non-trainable params
2.6 M     Total param

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


In [None]:
# Show the per-epoch mean training loss appended by training_epoch_end.
model.train_loss

[0.5706157684326172,
 0.5402987003326416,
 0.5339582562446594,
 0.5321882963180542,
 0.5310494303703308,
 0.5301476120948792,
 0.5289710164070129,
 0.5294167399406433,
 0.5282508730888367,
 0.5265479683876038,
 0.526225209236145,
 0.5234083533287048,
 0.5222313404083252,
 0.5213072896003723,
 0.521565318107605,
 0.5215277075767517,
 0.520297110080719,
 0.5204072594642639,
 0.5203791856765747,
 0.5192059874534607,
 0.5194277763366699,
 0.5192421674728394,
 0.5193824768066406,
 0.5196021795272827,
 0.5187771320343018,
 0.518444836139679,
 0.5194950103759766,
 0.5193286538124084,
 0.5190631747245789,
 0.5191900730133057]

In [None]:
# train_losses = np.load("train_losses.npy")
# # validation_losses = np.load("val_losses.npy")

# Plot the per-epoch mean losses recorded on the model by its *_epoch_end
# hooks. Each curve gets an x-axis of its own length, so the cell keeps
# working when the run did not last exactly 30 epochs or when the two lists
# differ in length (a hard-coded np.arange(30) made plt.plot raise then).
plt.figure()
epochs = np.arange(len(model.train_loss))
plt.plot(epochs,  model.train_loss, label = 'Train')
plt.plot(np.arange(len(model.val_loss)),  model.val_loss, label = 'Val')
plt.title('Loss curve')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Prepare the datamodule and fetch its prediction loader.
# NOTE(review): the evaluation cell below iterates `train_loader`, not this
# `test_loader` — confirm which one is intended.
COCODataLoader.setup()
test_loader = COCODataLoader.predict_dataloader() 

### Visualising Masks and mAP calculation

In [None]:
def apply_mask(image, mask, color, alpha=0.5):
    """Blend `color` into `image` in place wherever `mask` equals 1.

    image: (H, W, 3) array, modified in place and also returned.
    mask: (H, W) array; only entries equal to 1 are tinted.
    color: three channel intensities in [0, 1], scaled by 255 when blended.
    alpha: weight of the tint; the original pixel keeps weight 1 - alpha.
    """
    selected = mask == 1
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * (1 - alpha) + alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, tinted, plane)
    return image
     

In [None]:
def display_instance(data, plt, scores=None, masks=None, labels=None, image=None, boxes=None):
    '''
    Draw bounding boxes and colour-blended instance masks on one image.

    data: falsy, or {'image', 'target': {'bbox', 'mask', 'label'}} holding
        tensors that support .numpy(); when given it overrides the
        image/boxes/masks/labels arguments.
    plt: the pyplot-like object used for gca() and show().
    scores: accepted for call compatibility; not drawn.
    masks: per-instance masks indexed alongside boxes, e.g. [100, 800, 1088].
    labels: class ids; 1 = vehicle (red), 2 = people (green),
        3 = animal (blue). Instances with any other id are skipped.
    image: channel-first 3-channel tensor, used when data is falsy.
    boxes: (N, 4) tensor in [x1, y1, x2, y2] order, used when data is falsy.
    '''
    if data:
        image = data['image'].numpy()
        image = np.transpose(image, (1, 2, 0))
        boxes = data['target']['bbox'].numpy()
        masks = data['target']['mask'].numpy()
        labels = data['target']['label'].numpy()
    else:
        image = np.transpose(image.detach().cpu().numpy(), (1, 2, 0))
        boxes = boxes.detach().cpu().numpy()
        masks = masks.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

    # label id -> (box edge colour, RGB tint for the mask)
    class_styles = {
        1: ('r', (1.0, 0.0, 0.0)),  # vehicle
        2: ('g', (0.0, 1.0, 0.0)),  # people
        3: ('b', (0.0, 0.0, 1.0)),  # animal
    }

    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    else:
        assert boxes.shape[0] == masks.shape[0] == labels.shape[0]

    # Stretch pixel values to 0..255; a flat image maps to all zeros instead
    # of dividing by zero.
    value_range = image.max() - image.min()
    scale = (1 / value_range) * 255 if value_range else 0.0
    image_convert = (image - image.min()) * scale
    masked_image = image_convert.astype(np.uint32).copy()

    ax = plt.gca()
    for i in range(N):
        style = class_styles.get(int(labels[i]))
        if style is None:
            # Unknown/background label: no colour is defined for it.
            continue
        color_box, color = style

        # Bounding box
        x1, y1, x2, y2 = boxes[i,:]
        rect = patches.Rectangle((x1,y1),x2 - x1,y2 - y1,linewidth=2,edgecolor=color_box,facecolor='none')
        ax.add_patch(rect)

        # Mask
        mask = masks[i].astype(np.uint32)
        masked_image = apply_mask(masked_image, mask, color)

    # Draw the composited image once, after all masks are blended; this also
    # shows the plain image when there are no instances.
    ax.imshow(masked_image.astype(np.uint8))

    plt.show()

In [None]:
# Evaluation / visualisation cell. For every batch it runs
# backbone -> RPN -> box head -> mask head, pastes the predicted masks back
# into 800 x 1088 image space, collects prediction/target dicts in `preds` /
# `targets` (consumed by the mAP cell) and draws the top detection(s).
# Relies on notebook globals defined in other cells: device, backbone, rpn,
# boxhead, ImageList, train_loader, keep_topK.
model.to(device) 
model.eval()
masks_out = []
labels_out = []

mask_tar = []
labels_tar = []
preds = []
targets = []

# NOTE(review): iterates `train_loader` although an earlier cell builds
# `test_loader` — confirm which split is intended.
for j, batch in enumerate(train_loader):
  # Both `j == j` checks are always true; presumably leftovers from filtering
  # on a specific batch index.
  if j == j:
    if j == j:
      
      
      images, labels_ori, masks, bboxes, indices, old_mask = batch
      labels_ori = torch.Tensor(labels_ori[0])
      # print(old_mask[0].shape)
      print(j)
      # images = (images.unsqueeze(0)).to(device)
      
      images = images.to(device)
      # images = torch.FloatTensor(images)
      # backbone = backbone.double()
      # rpn = rpn.double()
      # boxhead = boxhead.double()
      backout = backbone(images)

      # The RPN implementation takes as first argument the following image list
      im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
      # Then we pass the image list and the backbone output through the rpn
      rpnout = rpn(im_lis, backout)
      proposals=[proposal[0:keep_topK,:] for proposal in rpnout[0]]
      # A list of features produces by the backbone's FPN levels: list:len(FPN){(bz,256,H_feat,W_feat)}
      fpn_feat_list= list(backout.values())

      feature_vectors = boxhead.MultiScaleRoiAlign(fpn_feat_list, proposals)

      # labels, regressor_targets = boxhead.create_ground_truth(proposals, labels, bboxes)

      class_logits, box_preds = boxhead.forward(feature_vectors)


      MEAN = torch.tensor([0.485, 0.456, 0.406])
      STD = torch.tensor([0.229, 0.224, 0.225])

      # squeeze() removes the batch axis only when the batch holds one image.
      # NOTE(review): the rest of this cell appears to assume batch size 1.
      images = images.cpu().squeeze()

      # De-normalised copy of the image; not used later in this cell.
      x = images * STD[:, None, None] + MEAN[:, None, None]

      

      # Not used later in this cell.
      resize = transforms.Resize((800, 1088))
      
      for idx in indices:
            # proposals = [torch.load(proposal_path + 'prop_pred_' + str(idx) + '.pt')]
            # class_logits = torch.load(confidence_scores_path + 'cls_logit_' + str(idx) + '.pt')
            # box_preds = torch.load(box_preds_path + 'box_pred_' + str(idx) + '.pt')
            # print(box_preds.shape)
            boxes = []
            # scores = []
            # labels = []
            # for i, proposal in enumerate(proposals):
            #   class_logit = class_logits[i*keep_topK: i*keep_topK + keep_topK]
            #   box_pred = box_preds[i*keep_topK: i*keep_topK + keep_topK]
            #   box, score, _ = self.postprocess_detections(images.squeeze(0), class_logits = class_logit, box_regression = box_pred, proposals = proposal, conf_thresh = 0.8, keep_num_preNMS=500, keep_num_postNMS=100)
              
            #   boxes.append(box)

            # print(len(box))
            # print(box_preds)
            boxes, scores, labels = boxhead.postprocess_detections(class_logits, box_preds, proposals, keep_num_postNMS=50)
            boxes = [box.to(device) for box in boxes]
            feature_vectors_2 = model.MultiScaleRoiAlign(fpn_feat_list, boxes) 

            fv2_shape = feature_vectors_2.shape
            # print(fv2_shape)
            # feature_vectors_2 = feature_vectors_2.reshape((fv2_shape[0]*fv2_shape[1], fv2_shape[2], fv2_shape[3], fv2_shape[4]))
            mask_pred = model.forward(feature_vectors_2.squeeze(0))
            boxes = boxes[0]
            num_plot = 1
            # print(labels[0])
            # The class channel of the first detection is used for every RoI.
            cls_lab = labels[0][num_plot-1].item()-1
            mask_one_cls = mask_pred[:,cls_lab,:,:].unsqueeze(1)
            mask_total = torch.tensor([]).to(device)
            # NOTE: this inner loop reuses the name `idx` of the enclosing loop.
            for idx in range(mask_pred.shape[0]):

              
              x1, y1, x2, y2 = tuple(boxes[idx].detach().cpu().numpy()) # x1, x2, y1, y2
              x1, y1, x2, y2 = max(0, int(x1)), max(0, int(y1)), min(int(x2), 1088), min(int(y2), 800)
              h = y2 - y1
              w = x2 - x1
              # NOTE(review): skipped boxes leave mask_total with fewer rows
              # than boxes/labels/scores, so row k of mask_total may no longer
              # belong to box k — confirm this is acceptable.
              if int(h) < 2 or int(w) < 2:
                  continue
              mask_one = mask_one_cls[idx].unsqueeze(0)               # [1, 1, 28, 28]
              # Resize the fixed-size mask to the box size in pixels.
              mask_rescaled = F.interpolate(mask_one, size = (h, w), mode='bilinear', align_corners=True).squeeze()   # [h, w]

              # Padding
              # (left, right, top, bottom) zero padding that places the box-sized
              # mask at its position inside the 800 x 1088 canvas.
              p2d = (int(x1), 1088-int(x1)-int(w), int(y1), 800-int(y1)-int(h))
              # print(p2d)
              mask_padded = F.pad(mask_rescaled, p2d).unsqueeze(0)    # [800, 1088]
              mask_total = torch.cat((mask_total,mask_padded), dim=0)
            # Binarise the pasted masks in place at 0.5.
            threshold = 0.5
            positive_idx = torch.where(mask_total >= threshold)
            negative_idx = torch.where(mask_total < threshold)
            mask_total[positive_idx] = 1
            mask_total[negative_idx] = 0
            # print(torch.where(mask_total[1]==1))
            plt.figure(j)
            # print(scores[0])  

            # Use the top two detections when they are foreground and of
            # different classes, otherwise only the top one.
            # NOTE(review): reads labels[0][1]; this would raise if only one
            # detection was returned.
            if labels[0][0]!=labels[0][1] and labels[0][1] != 0 and labels[0][0]!=0:
              num_plot = 2    
            # print(labels[0])
            if labels[0][0]!=0:
              pred_image_dict = {'masks':mask_total[:num_plot].detach().cpu().type(torch.uint8),
                            'labels':labels[0][:num_plot].detach().cpu(),
                            'scores':scores[0][:num_plot].detach().cpu()}
              preds.append(pred_image_dict)
              target_dict = {'masks':old_mask[0][:num_plot].detach().cpu().type(torch.uint8),
                            'labels':labels_ori[:num_plot].detach().cpu()}
              targets.append(target_dict)
            # masks_out.append(mask_total[:num_plot].detach().cpu())
            # labels_out.append(labels[0][:num_plot].detach().cpu())
            # mask_tar.append(old_mask[0][:num_plot].detach().cpu())
            # labels_tar.append(labels_ori[:num_plot].detach().cpu())
            display_instance(None, plt, scores[0][:num_plot], mask_total[:num_plot], labels[0][:num_plot], images, boxes[:num_plot])
    
          
      # break

In [None]:
def apply_mask(image, mask, color, alpha=0.5):
    """Blend `color` into `image` in place wherever `mask` equals 1.

    image: (H, W, 3) array, modified in place and also returned.
    mask: (H, W) array; only entries equal to 1 are tinted.
    color: three channel intensities in [0, 1], scaled by 255 when blended.
    alpha: weight of the tint; the original pixel keeps weight 1 - alpha.
    """
    selected = mask == 1
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * (1 - alpha) + alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, tinted, plane)
    return image
     

In [None]:
def display_instance(data, plt, scores=None, masks=None, labels=None, image=None, boxes=None):
    '''
    Draw bounding boxes and colour-blended instance masks on one image.

    data: falsy, or {'image', 'target': {'bbox', 'mask', 'label'}} holding
        tensors that support .numpy(); when given it overrides the
        image/boxes/masks/labels arguments.
    plt: the pyplot-like object used for gca() and show().
    scores: accepted for call compatibility; not drawn.
    masks: per-instance masks indexed alongside boxes, e.g. [100, 800, 1088].
    labels: class ids; 1 = vehicle (red), 2 = people (green),
        3 = animal (blue). Instances with any other id are skipped.
    image: channel-first 3-channel tensor, used when data is falsy.
    boxes: (N, 4) tensor in [x1, y1, x2, y2] order, used when data is falsy.
    '''
    if data:
        image = data['image'].numpy()
        image = np.transpose(image, (1, 2, 0))
        boxes = data['target']['bbox'].numpy()
        masks = data['target']['mask'].numpy()
        labels = data['target']['label'].numpy()
    else:
        image = np.transpose(image.detach().cpu().numpy(), (1, 2, 0))
        boxes = boxes.detach().cpu().numpy()
        masks = masks.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

    # label id -> (box edge colour, RGB tint for the mask)
    class_styles = {
        1: ('r', (1.0, 0.0, 0.0)),  # vehicle
        2: ('g', (0.0, 1.0, 0.0)),  # people
        3: ('b', (0.0, 0.0, 1.0)),  # animal
    }

    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    else:
        assert boxes.shape[0] == masks.shape[0] == labels.shape[0]

    # Stretch pixel values to 0..255; a flat image maps to all zeros instead
    # of dividing by zero.
    value_range = image.max() - image.min()
    scale = (1 / value_range) * 255 if value_range else 0.0
    image_convert = (image - image.min()) * scale
    masked_image = image_convert.astype(np.uint32).copy()

    ax = plt.gca()
    for i in range(N):
        style = class_styles.get(int(labels[i]))
        if style is None:
            # Unknown/background label: no colour is defined for it.
            continue
        color_box, color = style

        # Bounding box
        x1, y1, x2, y2 = boxes[i,:]
        rect = patches.Rectangle((x1,y1),x2 - x1,y2 - y1,linewidth=2,edgecolor=color_box,facecolor='none')
        ax.add_patch(rect)

        # Mask
        mask = masks[i].astype(np.uint32)
        masked_image = apply_mask(masked_image, mask, color)

    # Draw the composited image once, after all masks are blended; this also
    # shows the plain image when there are no instances.
    ax.imshow(masked_image.astype(np.uint8))

    plt.show()

In [None]:
# Segmentation mAP over the `preds` / `targets` lists filled by the
# evaluation cell (one dict per image with 'masks', 'labels' and, for
# predictions, 'scores').
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# preds = [
#   dict(
#     masks=torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
#     scores=torch.tensor([0.536]),
#     labels=torch.tensor([0]),
#   )
# ]
# target = [dict(boxes=torch.tensor([[214.0, 41.0, 562.0, 285.0]]), labels=torch.tensor([0]),)]
# Only IoU 0.5 is evaluated, which is why the printed result below has
# map == map_50 and reports map_75 as -1.
metric = MeanAveragePrecision(iou_type='segm', class_metrics=True, iou_thresholds = [0.5])
metric.update(preds, targets)
from pprint import pprint
pprint(metric.compute())

{'map': tensor(0.6727),
 'map_50': tensor(0.6727),
 'map_75': tensor(-1),
 'map_large': tensor(0.6791),
 'map_medium': tensor(0.4185),
 'map_per_class': tensor([0.5912, 0.6873, 0.7396]),
 'map_small': tensor(-1.),
 'mar_1': tensor(0.7634),
 'mar_10': tensor(0.7634),
 'mar_100': tensor(0.7634),
 'mar_100_per_class': tensor([0.6910, 0.7663, 0.8329]),
 'mar_large': tensor(0.7719),
 'mar_medium': tensor(0.5093),
 'mar_small': tensor(-1.)}


### Real Time Segmentation

In [None]:
import PIL
import io

In [None]:
from IPython.display import display, Javascript
from google.colab.output import eval_js
from base64 import b64decode, b64encode
import cv2
import numpy as np

def take_photo(filename='photo.jpg', quality=0.8):
  """Capture one webcam frame in the browser and save it as a JPEG.

  Injects a small page with a video preview and a 'Capture' button, waits
  for the click, then receives the frame as a JPEG data URL.

  filename: path the raw JPEG bytes are written to.
  quality: JPEG quality forwarded to canvas.toDataURL.

  Returns (filename, img), where img is the frame decoded by js_to_image.
  """
  capture_script = '''
    async function takePhoto(quality) {
      const div = document.createElement('div');
      const capture = document.createElement('button');
      capture.textContent = 'Capture';
      div.appendChild(capture);

      const video = document.createElement('video');
      video.style.display = 'block';
      const stream = await navigator.mediaDevices.getUserMedia({video: true});

      document.body.appendChild(div);
      div.appendChild(video);
      video.srcObject = stream;
      await video.play();

      // Resize the output to fit the video element.
      google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);

      // Wait for Capture to be clicked.
      await new Promise((resolve) => capture.onclick = resolve);

      const canvas = document.createElement('canvas');
      canvas.width = video.videoWidth;
      canvas.height = video.videoHeight;
      canvas.getContext('2d').drawImage(video, 0, 0);
      stream.getVideoTracks()[0].stop();
      div.remove();
      return canvas.toDataURL('image/jpeg', quality);
    }
    '''
  display(Javascript(capture_script))

  # Blocks until the button is clicked; the reply is "<header>,<base64 jpeg>".
  data_url = eval_js(f'takePhoto({quality})')
  img = js_to_image(data_url)

  jpeg_bytes = b64decode(data_url.split(',')[1])
  with open(filename, 'wb') as out_file:
    out_file.write(jpeg_bytes)

  return filename, img

In [None]:
def js_to_image(js_reply):
  """Turn a base64 image data URL from the browser into an OpenCV image.

  Args:
    js_reply: String of the form ``"<header>,<base64 payload>"``; only the
      part after the first comma is decoded.

  Returns:
    The array returned by ``cv2.imdecode`` in colour mode (BGR channel
    order, per OpenCV's convention).
  """
  # Drop the "data:image/...;base64" header and decode the payload.
  encoded_payload = js_reply.split(',')[1]
  raw_buffer = np.frombuffer(b64decode(encoded_payload), dtype=np.uint8)

  # IMREAD_COLOR is the named constant for flag value 1.
  return cv2.imdecode(raw_buffer, flags=cv2.IMREAD_COLOR)

In [None]:
def video_stream():
  """Inject the JavaScript that drives the live webcam demo into the cell output.

  The script defines, in the output frame:

  * ``createDom()``: builds the bordered container, status label, ``<video>``
    element fed by ``getUserMedia`` (rear-facing camera requested), an
    absolutely positioned ``<img>`` used as the overlay, and a 640x480
    capture canvas. It is a no-op when the DOM already exists.
  * ``onAnimationFrame()``: on each animation frame, if a capture is pending,
    draws the video into the canvas and resolves the pending promise with a
    JPEG data URL (quality 0.8), or with ``""`` when shutdown was requested.
  * ``stream_frame(label, imgData)``: updates the status label and overlay
    image (when non-empty), waits for the next captured frame and returns
    ``{'create', 'show', 'capture', 'img'}`` (three millisecond timings and
    the frame data URL). Returns ``''`` and tears the DOM down when the
    user has clicked the video, overlay or instruction text.
  * ``removeDom()``: stops the camera track and removes the elements.

  Returns:
    None. The function only displays the script; frames are fetched by
    evaluating ``stream_frame`` from Python.
  """
  js = Javascript('''
    var video;
    var div = null;
    var stream;
    var captureCanvas;
    var imgElement;
    var labelElement;
    
    var pendingResolve = null;
    var shutdown = false;
    
    function removeDom() {
       stream.getVideoTracks()[0].stop();
       video.remove();
       div.remove();
       video = null;
       div = null;
       stream = null;
       imgElement = null;
       captureCanvas = null;
       labelElement = null;
    }
    
    function onAnimationFrame() {
      if (!shutdown) {
        window.requestAnimationFrame(onAnimationFrame);
      }
      if (pendingResolve) {
        var result = "";
        if (!shutdown) {
          captureCanvas.getContext('2d').drawImage(video, 0, 0, 640, 480);
          result = captureCanvas.toDataURL('image/jpeg', 0.8)
        }
        var lp = pendingResolve;
        pendingResolve = null;
        lp(result);
      }
    }
    
    async function createDom() {
      if (div !== null) {
        return stream;
      }

      div = document.createElement('div');
      div.style.border = '2px solid black';
      div.style.padding = '3px';
      div.style.width = '100%';
      div.style.maxWidth = '600px';
      document.body.appendChild(div);
      
      const modelOut = document.createElement('div');
      modelOut.innerHTML = "<span>Status:</span>";
      labelElement = document.createElement('span');
      labelElement.innerText = 'No data';
      labelElement.style.fontWeight = 'bold';
      modelOut.appendChild(labelElement);
      div.appendChild(modelOut);
           
      video = document.createElement('video');
      video.style.display = 'block';
      video.width = div.clientWidth - 6;
      video.setAttribute('playsinline', '');
      video.onclick = () => { shutdown = true; };
      stream = await navigator.mediaDevices.getUserMedia(
          {video: { facingMode: "environment"}});
      div.appendChild(video);

      imgElement = document.createElement('img');
      imgElement.style.position = 'absolute';
      imgElement.style.zIndex = 1;
      imgElement.onclick = () => { shutdown = true; };
      div.appendChild(imgElement);
      
      const instruction = document.createElement('div');
      instruction.innerHTML = 
          '<span style="color: red; font-weight: bold;">' +
          'When finished, click here or on the video to stop this demo</span>';
      div.appendChild(instruction);
      instruction.onclick = () => { shutdown = true; };
      
      video.srcObject = stream;
      await video.play();

      captureCanvas = document.createElement('canvas');
      captureCanvas.width = 640; //video.videoWidth;
      captureCanvas.height = 480; //video.videoHeight;
      window.requestAnimationFrame(onAnimationFrame);
      
      return stream;
    }
    async function stream_frame(label, imgData) {
      if (shutdown) {
        removeDom();
        shutdown = false;
        return '';
      }

      var preCreate = Date.now();
      stream = await createDom();
      
      var preShow = Date.now();
      if (label != "") {
        labelElement.innerHTML = label;
      }
            
      if (imgData != "") {
        var videoRect = video.getClientRects()[0];
        imgElement.style.top = videoRect.top + "px";
        imgElement.style.left = videoRect.left + "px";
        imgElement.style.width = videoRect.width + "px";
        imgElement.style.height = videoRect.height + "px";
        imgElement.src = imgData;
      }
      
      var preCapture = Date.now();
      var result = await new Promise(function(resolve, reject) {
        pendingResolve = resolve;
      });
      shutdown = false;
      
      return {'create': preShow - preCreate, 
              'show': preCapture - preShow, 
              'capture': Date.now() - preCapture,
              'img': result};
    }
    ''')

  # Displaying the Javascript object runs it in the cell's output frame.
  display(js)
  
def video_frame(label, mask):
  """Push a status label and overlay to the browser and fetch the next frame.

  Evaluates the ``stream_frame(label, imgData)`` JavaScript function in the
  cell output and returns its result.

  Args:
    label: Text/HTML for the status label; an empty string leaves the
      current label unchanged.
    mask: Overlay image as a data URL; an empty string leaves the current
      overlay unchanged.

  Returns:
    Whatever ``stream_frame`` resolves to: an empty string once the user has
    stopped the demo, otherwise an object with ``create``/``show``/``capture``
    timings and the captured frame under ``img``.
  """
  import json

  # json.dumps yields a correctly quoted and escaped JS string literal, so a
  # label containing quotes, backslashes or newlines can no longer break (or
  # inject into) the evaluated expression. For plain text and base64 data
  # URLs the generated call is identical to the old '"{}"' formatting.
  js_label = json.dumps(str(label))
  js_mask = json.dumps(str(mask))
  data = eval_js('stream_frame({}, {})'.format(js_label, js_mask))
  return data

In [None]:
# Number of leading RPN proposals per image that the real-time loop keeps
# (it slices each proposal tensor with [0:keep_topK, :]).
keep_topK = 500

In [None]:
# Real-time demo: grab webcam frames from the browser, run backbone -> RPN ->
# box head -> mask head on each one, and send the top detection's mask back
# as a semi-transparent PNG overlay. Runs until the user clicks to stop.
# NOTE(review): relies on notebook globals defined in earlier cells
# (backbone, rpn, boxhead, model, ImageList, image_transform, device).
video_stream()
# label for video
label_html = 'Capturing...'
# initialize overlays to empty; stream_frame() ignores empty overlay strings
bbox = ''
mask = ''
count = 0
while True:
    js_reply = video_frame(label_html, mask)
    # stream_frame() returns '' once the demo has been stopped. If the stop
    # click lands while a capture is pending it instead returns a reply whose
    # 'img' is empty, which js_to_image() cannot decode, so stop on both.
    if not js_reply or not js_reply["img"]:
        break

    # convert JS response to OpenCV Image
    frame = js_to_image(js_reply["img"])
    img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # cv2.resize takes (width, height): the network input is 800 x 1088 (H x W).
    img_resized = cv2.resize(img_rgb, (1088, 800),
                             interpolation=cv2.INTER_LINEAR)
    image = img_resized.astype('uint8')

    # RGBA canvas for the box overlay, plus the factors that map network
    # coordinates (1088 x 800) back to the captured frame size.
    bbox_array = np.zeros([480,640,4], dtype=np.uint8)
    img_height, img_width, _ = frame.shape
    width_ratio = img_width/1088
    height_ratio = img_height/800
    # NOTE(review): if image_transform is falsy, `images` keeps its value from
    # the previous iteration (or is undefined on the first one) - confirm.
    if image_transform:
        images = image_transform(image)

    images = (images.unsqueeze(0)).to(device)

    backout = backbone(images)

    # The RPN implementation takes as first argument the following image list
    im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
    # Then we pass the image list and the backbone output through the rpn
    rpnout = rpn(im_lis, backout)
    proposals=[proposal[0:keep_topK,:] for proposal in rpnout[0]]
    # A list of features produces by the backbone's FPN levels: list:len(FPN){(bz,256,H_feat,W_feat)}
    fpn_feat_list= list(backout.values())

    feature_vectors = boxhead.MultiScaleRoiAlign(fpn_feat_list, proposals)

    class_logits, box_preds = boxhead.forward(feature_vectors)
    boxes, scores, labels = boxhead.postprocess_detections(class_logits, box_preds, proposals, keep_num_postNMS=50)
    num_plot = 1

    # No detection survived post-processing: the code below would fail on the
    # empty tensors, so send a fully transparent overlay (clearing the mask
    # from the previous frame) and move on to the next frame.
    if len(boxes[0]) == 0:
        iobuf = io.BytesIO()
        PIL.Image.fromarray(np.zeros([480,640,4], dtype=np.uint8), 'RGBA').save(iobuf, format='png')
        mask = 'data:image/png;base64,{}'.format((str(b64encode(iobuf.getvalue()), 'utf-8')))
        continue

    boxes_ = boxes[0][:num_plot]
    # Each of these is a length-num_plot tensor; int() below only works
    # because num_plot == 1.
    x1, y1, x2, y2 = boxes_[:, 0], boxes_[:, 1], boxes_[:, 2], boxes_[:, 3]

    bbox_array = cv2.rectangle(bbox_array, (int(x1 * width_ratio), int(y1 * height_ratio)), (int(x2 * width_ratio), int(y2 * height_ratio)), (255, 0, 0), 1)
    MEAN = torch.tensor([0.485, 0.456, 0.406])
    STD = torch.tensor([0.229, 0.224, 0.225])

    # Make every drawn pixel opaque and leave the rest transparent.
    bbox_array[:,:,3] = (bbox_array.max(axis = 2) > 0 ).astype(int) * 255
    # convert overlay of bbox into bytes
    bbox_PIL = PIL.Image.fromarray(bbox_array, 'RGBA')
    iobuf = io.BytesIO()
    # format bbox into png for return
    bbox_PIL.save(iobuf, format='png')
    # format return string
    bbox_bytes = 'data:image/png;base64,{}'.format((str(b64encode(iobuf.getvalue()), 'utf-8')))
    # NOTE(review): the box overlay is stored in `bbox` but only `mask` is
    # passed to video_frame(), so the rectangle is never displayed - confirm
    # whether that is intended.
    bbox = bbox_bytes

    # Mask head: pool features for the detected boxes and predict masks.
    boxes = [box.to(device) for box in boxes]
    feature_vectors_2 = model.MultiScaleRoiAlign(fpn_feat_list, boxes)

    fv2_shape = feature_vectors_2.shape
    mask_pred = model.forward(feature_vectors_2.squeeze(0))
    boxes = boxes[0]
    # Labels are shifted by one to index the mask channel; the same index is
    # later used as the colour channel of the overlay.
    # NOTE(review): assumes labels are 1..3 so that cls_lab is 0..2 - confirm.
    cls_lab = labels[0][num_plot-1].item()-1
    mask_one_cls = mask_pred[:,cls_lab,:,:].unsqueeze(1)
    mask_total = torch.tensor([]).to(device)
    for idx in range(mask_pred.shape[0]):
        x1, y1, x2, y2 = tuple(boxes[idx].detach().cpu().numpy())
        # Truncate to integer pixels and clip to the 1088 x 800 network input.
        x1, y1, x2, y2 = max(0, int(x1)), max(0, int(y1)), min(int(x2), 1088), min(int(y2), 800)
        h = y2 - y1
        w = x2 - x1
        if h < 1 or w < 1:
            # A box that collapses to zero height/width after truncation and
            # clipping cannot be interpolated to; give it an empty mask so the
            # rows of mask_total stay aligned with the detections.
            mask_padded = mask_one_cls.new_zeros((1, 800, 1088))
        else:
            mask_one = mask_one_cls[idx].unsqueeze(0)               # [1, 1, 28, 28]
            mask_rescaled = F.interpolate(mask_one, size = (h, w), mode='bilinear', align_corners=True).squeeze()   # [h, w]

            # Zero-pad (left, right, top, bottom) to paste the box-sized mask
            # into the full image.
            p2d = (int(x1), 1088-int(x1)-int(w), int(y1), 800-int(y1)-int(h))
            mask_padded = F.pad(mask_rescaled, p2d).unsqueeze(0)    # [1, 800, 1088]
        mask_total = torch.cat((mask_total,mask_padded), dim=0)

    # Binarise at 0.5 on the network-sized canvas.
    positive_idx = torch.where(mask_total >= 0.5)
    negative_idx = torch.where(mask_total < 0.5)
    mask_total[positive_idx] = 1
    mask_total[negative_idx] = 0
    # Keep the first num_plot masks, resize them to 480 x 640 and binarise
    # again at 0.6 (presumably because resizing reintroduces fractional values).
    mask_transfer = transforms.Resize((480,640))
    mask_total = mask_transfer(mask_total[:num_plot].unsqueeze(1)).squeeze(1).squeeze(0)
    positive_idx2 = torch.where(mask_total >= 0.6)
    negative_idx2 = torch.where(mask_total < 0.6)
    mask_total[positive_idx2] = 1
    mask_total[negative_idx2] = 0
    mask_total = mask_total.type(torch.uint8)

    # Paint the mask into the colour channel chosen by the class index, with
    # alpha 115 (of 255) on masked pixels so the video stays visible.
    mask_array = np.zeros([480,640,4], dtype=np.uint8)
    mask_array[:,:,cls_lab] = mask_total.detach().cpu().numpy()*255
    mask_array[:,:,3] = mask_total.detach().cpu().numpy()*115
    mask_PIL = PIL.Image.fromarray(mask_array, 'RGBA')
    iobuf = io.BytesIO()
    # format mask into png for return
    mask_PIL.save(iobuf, format='png')
    # format return string
    mask_bytes = 'data:image/png;base64,{}'.format((str(b64encode(iobuf.getvalue()), 'utf-8')))
    # update mask so next frame gets new overlay
    mask = mask_bytes
    images = images.cpu().squeeze()


    # x = images * STD[:, None, None] + MEAN[:, None, None]

    

<IPython.core.display.Javascript object>

  class_logits = softmax(class_logits)
