# Install the requirements and packages

In [None]:
import torch
import os
import os
import cv2
import numpy as np
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import matplotlib.pyplot as plt
# Create a working directory for the dataset and move into it so the archive
# below is unpacked there.
!mkdir dataset
os.chdir("./dataset")
# Download the Roboflow export, unpack it in place, then delete the archive.
!curl -L "https://universe.roboflow.com/ds/qsAqxl1yWz?key=cc6BA6xJi0" > roboflow.zip; unzip roboflow.zip; rm roboflow.zip
# NOTE(review): pycocotools is imported at the top of this cell, before this
# install runs; on an environment without it preinstalled the import would
# fail first. gdown is used later to fetch the test video.
!pip install pycocotools
!pip install gdown

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   892  100   892    0     0    468      0  0:00:01  0:00:01 --:--:--   468
100 58.9M  100 58.9M    0     0  21.1M      0  0:00:02  0:00:02 --:--:--  111M
Archive:  roboflow.zip
 extracting: README.dataset.txt      
 extracting: README.roboflow.txt     
   creating: train/
 extracting: train/0055_png.rf.5867575f7facef50e8752469026e57f4.jpg  
 extracting: train/0057_png.rf.0e5457e8188d827b04b87cef5ef73384.jpg  
 extracting: train/0058_png.rf.377f6de8073af39d79c71aa3a8de0c58.jpg  
 extracting: train/0059_png.rf.8e115e93827bf5f9bafb9ef7afe25d3c.jpg  
 extracting: train/0061_png.rf.489fef22fcc5006bc158c84bd508ad4c.jpg  
 extracting: train/0063_png.rf.b00bf79c53fb176ea0f87b4026253bcb.jpg  
 extracting: train/0065_png.rf.e93c6edf94c38c67f6429e43e3b0da7e.jpg  
 extracting: train/0066_png.rf.517748c05ee68beac0ea072cc1678819.jpg  
 extr

## Settings

In [None]:
# Step back out of ./dataset (entered by the download cell).
# NOTE(review): not idempotent; re-running this cell moves up one more directory.
os.chdir("../")
import warnings
# Silence all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Optimizer hyper-parameters (passed to AdamW in the training cell).
LEARNING_RATE = 1e-5
WEIGHT_DECAY = 5e-4
# Batch size shared by the train and validation loaders.
BATCH_SIZE = 8
NUM_EPOCHS = 30
# NOTE(review): the train DataLoader hard-codes num_workers=2 instead of using
# this constant; confirm which value is intended.
NUM_WORKERS = 4
# Checkpoint that the training cell loads when LOAD_MODEL is True.
CHECKPOINT_FILE = "Best.pth.tar"
PIN_MEMORY = True
# When True, validate() writes a checkpoint each time the validation loss improves.
SAVE_MODEL = False
LOAD_MODEL = False
# Each split directory holds the images plus an _annotations.coco.json file.
TRAIN_DIR = './dataset/train'
VALID_DIR = './dataset/valid'
TEST_DIR = './dataset/test'
# Target size handed to cv2.resize at prediction time
# (presumably width, height, as cv2's dsize expects; confirm).
IMAGE_SIZE = [1152,648]

# Dataset

In [None]:
class CCDataset(Dataset):
    """COCO-annotated instance-segmentation dataset for Mask R-CNN.

    Each item is ``(image, target)`` where ``image`` is a float32 tensor of
    shape (3, H, W) scaled by 1/255 and ``target`` is a dict with:

    * ``"boxes"``  - float32 tensor (N, 4) in ``[x1, y1, x2, y2]`` form,
    * ``"labels"`` - int64 tensor (N,), all ones (single foreground class),
    * ``"masks"``  - uint8 tensor (N, H, W), one binary mask per instance.

    The image keeps the channel order produced by ``cv2.imread`` (no colour
    conversion is applied).

    Args:
        mode: Which split to read: ``'train'``, ``'valid'`` or ``'test'``.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., masks=...)`` that returns a dict with
            ``'image'`` and ``'masks'`` entries (e.g. an albumentations
            ``Compose``).

    Raises:
        ValueError: If ``mode`` is not one of the supported splits.
    """

    _SPLIT_NAMES = ('train', 'valid', 'test')

    def __init__(self, mode = 'train', augmentation=None):
        # Validate first so a typo fails with a clear message instead of an
        # UnboundLocalError further down.
        if mode not in self._SPLIT_NAMES:
            raise ValueError(
                f"mode must be one of {self._SPLIT_NAMES}, got {mode!r}")

        split_dirs = {'train': TRAIN_DIR, 'valid': VALID_DIR, 'test': TEST_DIR}
        self.dataset_path = split_dirs[mode]
        ann_path = os.path.join(self.dataset_path, '_annotations.coco.json')

        self.coco = COCO(ann_path)
        self.cat_ids = self.coco.getCatIds()
        # Map dataset positions to COCO image ids so indexing does not depend
        # on the ids being exactly 0..N-1.
        self.img_ids = sorted(self.coco.getImgIds())
        self.augmentation = augmentation

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.img_ids)

    def get_masks(self, index):
        """Return one binary mask (H, W) per annotation of the image at ``index``."""
        ann_ids = self.coco.getAnnIds(imgIds=[self.img_ids[index]])
        return [self.coco.annToMask(ann) for ann in self.coco.loadAnns(ann_ids)]

    def get_boxes(self, masks):
        """Return an (N, 4) array of ``[x1, y1, x2, y2]`` boxes enclosing ``masks``.

        An empty ``masks`` yields an array of shape (0, 4) rather than (0,).
        """
        boxes = []
        for mask in masks:
            x, y, w, h = cv2.boundingRect(mask)
            boxes.append([x, y, x + w, y + h])

        return np.array(boxes).reshape(-1, 4)

    def __getitem__(self, index):
        """Load, optionally augment, and tensorise the sample at ``index``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        # Load image
        img_info = self.coco.loadImgs([self.img_ids[index]])[0]
        img_path = os.path.join(self.dataset_path, img_info['file_name'])
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        masks = self.get_masks(index)

        if self.augmentation:
            if masks:
                augmented = self.augmentation(image=image, masks=masks)
                image, masks = augmented['image'], augmented['masks']
            else:
                # No instances: augment the image alone.
                image = self.augmentation(image=image)['image']

        height, width = image.shape[:2]
        image = image.transpose(2, 0, 1) / 255.

        # Stack masks; keep a well-formed (0, H, W) array when there are none.
        if len(masks):
            masks = np.array(masks)
        else:
            masks = np.zeros((0, height, width), dtype=np.uint8)
        boxes = self.get_boxes(masks)

        # Create target dict
        num_objs = len(masks)
        data = {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32),
            "labels": torch.ones((num_objs,), dtype=torch.int64),
            "masks": torch.as_tensor(masks, dtype=torch.uint8),
        }
        image = torch.as_tensor(image, dtype=torch.float32)

        return image, data

In [None]:
# Custom collation: images hold different numbers of instances, so the
# per-image target dicts are kept in a plain list instead of being batched.
def collate_fn(batch):
    """Turn a list of ``(image, target)`` samples into ``(images, targets)``.

    Images are stacked along a new leading dimension; targets are returned
    as a list in the same order as the samples.
    """
    image_list = [sample[0] for sample in batch]
    target_list = [sample[1] for sample in batch]
    return torch.stack(image_list, dim=0), target_list

## Transform

In [None]:
import albumentations as A
# Training-time augmentation pipeline; handed to CCDataset as `augmentation`,
# which calls it with image=... and masks=... for every training sample.
transform = A.Compose([
    # Random flips, each applied with probability 0.5.
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(
        contrast_limit=0.2, brightness_limit=0.3, p=0.5),
    # Pick one of four pixel-level transforms; the group itself has p=1.0.
    A.OneOf([
        A.ImageCompression(p=0.8),
        A.RandomGamma(p=0.8),
        A.Blur(p=0.8),
        A.Equalize(mode='cv',p=0.8)
    ], p=1.0),
    # NOTE(review): this group lists the same four transforms as the one
    # above, so a second pixel-level transform is drawn; confirm the
    # repetition is intentional.
    A.OneOf([
        A.ImageCompression(p=0.8),
        A.RandomGamma(p=0.8),
        A.Blur(p=0.8),
        A.Equalize(mode='cv',p=0.8),
    ], p=1.0)
])

# Model

## Functions

In [None]:
import torch
from tqdm import tqdm
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

In [None]:
def get_model():
    """Create a pretrained Mask R-CNN with a two-class box head, placed on DEVICE.

    Only the box predictor is swapped out (background + one foreground
    class); the mask predictor is left exactly as ``maskrcnn_resnet50_fpn``
    builds it.

    Returns:
        The model, already moved to ``DEVICE``.
    """
    net = maskrcnn_resnet50_fpn(pretrained=True)

    # Reuse the width of the existing classification layer for the new head.
    head_width = net.roi_heads.box_predictor.cls_score.in_features
    net.roi_heads.box_predictor = FastRCNNPredictor(head_width, num_classes=2)

    net.to(DEVICE)
    return net

In [None]:
def save_checkpoint(state, filename="mask_rcnn.pth.tar"):
    """Announce the save on stdout, then write ``state`` to ``filename`` via ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

In [None]:
def load_checkpoint(checkpoint, model, optimizer, lr):
    """Restore model weights from ``checkpoint`` and force the optimizer's learning rate.

    Args:
        checkpoint: Mapping with a ``"state_dict"`` entry holding the model weights.
        model: Module whose ``load_state_dict`` receives those weights.
        optimizer: Optimizer whose parameter groups get their ``"lr"`` overwritten.
        lr: Learning rate to set on every parameter group.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

    # The optimizer state stored under checkpoint["optimizer"] is not restored
    # here; only the learning rate is set explicitly so that every parameter
    # group runs with the requested value.
    for group in optimizer.param_groups:
        group.update(lr=lr)

In [None]:
def train_one_epoch(loader, model, optimizer, device):
    """Run one optimisation pass over ``loader`` and report the epoch's mean loss.

    For every batch the model's loss dict is summed into a single scalar,
    back-propagated, and the optimizer is stepped. The value printed (and
    returned) is the mean of those per-batch totals over the whole epoch,
    not just the last batch.

    Args:
        loader: Iterable yielding ``(images, targets)`` batches, where
            ``targets`` is a list of dicts of tensors.
        model: Detection model that returns a dict of losses when called
            with ``(images, targets)`` in training mode.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        float: Mean per-batch total loss, or ``nan`` if ``loader`` yielded
        no batches.
    """
    # The loss dict is only produced in training mode.
    model.train()

    running_loss = 0.0
    num_batches = 0

    for images, targets in tqdm(loader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += losses.item()
        num_batches += 1

    # Guard against an empty loader instead of failing on an unbound name.
    mean_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Total loss: {mean_loss}")
    return mean_loss

In [None]:
# Lowest average validation loss seen so far (updated by validate()).
best_vloss = np.inf
def validate(loader, model, optimizer, device, epoch):
    """Compute the average validation loss and checkpoint on improvement.

    The model is called with targets under ``torch.no_grad()`` and its loss
    dict is summed per batch. The model's mode is not changed here: the
    caller is expected to leave it in training mode so that it returns
    losses (presumably it returns detections otherwise; see the torchvision
    detection docs).

    When the average beats the module-level ``best_vloss``, the record is
    updated and, if ``SAVE_MODEL`` is set, a checkpoint with the model and
    optimizer state is written.

    Args:
        loader: Iterable yielding ``(images, targets)`` validation batches.
        model: Detection model returning a dict of losses.
        optimizer: Optimizer whose state is stored in the checkpoint.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used in the checkpoint file name.

    Returns:
        float: Average per-batch total loss, or ``nan`` if ``loader``
        yielded no batches.
    """
    global best_vloss
    running_vloss = 0.0
    num_batches = 0

    for images, targets in tqdm(loader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        with torch.no_grad():
            loss_dict = model(images, targets)

        # .item() keeps the running total (and best_vloss) a plain Python
        # float rather than a tensor living on the device.
        running_vloss += sum(loss_dict.values()).item()
        num_batches += 1

    avg_vloss = running_vloss / num_batches if num_batches else float("nan")

    print(f"Avg Valid Loss: {avg_vloss}")
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        if SAVE_MODEL:
            print("Model improved, saving...")
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint, filename=f"1152KaggleBest_second_{epoch}.pth.tar")
    print('\n')
    return avg_vloss

## Data Loader

In [None]:
# Training split: augmented with `transform`, reshuffled every epoch.
# NOTE(review): num_workers is hard-coded to 2 although the settings cell
# defines NUM_WORKERS = 4; confirm which value is intended.
train_dataset = CCDataset(mode='train', augmentation=transform)
train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=PIN_MEMORY,
                              collate_fn=collate_fn)

# Validation split: no augmentation, fixed order. num_workers is not passed,
# so DataLoader's default applies.
valid_dataset = CCDataset(mode='valid')
valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              pin_memory=PIN_MEMORY,
                              collate_fn=collate_fn)

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


## Start training

In [None]:
# Build the model and its optimizer from the settings above.
model = get_model()
optimizer = torch.optim.AdamW(params=model.parameters(),
                              lr=LEARNING_RATE,
                              weight_decay=WEIGHT_DECAY)

# Optionally resume from a saved checkpoint.
if LOAD_MODEL:
    if os.path.isfile(CHECKPOINT_FILE):
        print("Loading checkpoint")
        # map_location lets a checkpoint saved on a GPU load when DEVICE has
        # fallen back to "cpu".
        load_checkpoint(torch.load(CHECKPOINT_FILE, map_location=DEVICE),
                        model, optimizer, LEARNING_RATE)
    else:
        # Make the fallback visible instead of silently training from scratch.
        print(f"Checkpoint {CHECKPOINT_FILE!r} not found; training from scratch")

# The model stays in training mode throughout: validate() also relies on the
# loss dict the model returns when called with targets.
model.train()
for epoch in range(NUM_EPOCHS):
    print(f"Epoch: {epoch}")
    train_one_epoch(train_loader, model, optimizer, DEVICE)
    vloss = validate(valid_loader, model, optimizer, DEVICE, epoch)

Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth


  0%|          | 0.00/170M [00:00<?, ?B/s]

Epoch: 0


100%|██████████| 38/38 [00:23<00:00,  1.65it/s]


Total loss: 0.526561975479126


100%|██████████| 10/10 [00:03<00:00,  2.64it/s]


Avg Valid Loss: 0.577857255935669


Epoch: 1


100%|██████████| 38/38 [00:15<00:00,  2.52it/s]


Total loss: 0.5471883416175842


100%|██████████| 10/10 [00:03<00:00,  2.65it/s]


Avg Valid Loss: 0.3765439987182617


Epoch: 2


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.31462401151657104


100%|██████████| 10/10 [00:03<00:00,  2.63it/s]


Avg Valid Loss: 0.3229812681674957


Epoch: 3


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.31913185119628906


100%|██████████| 10/10 [00:03<00:00,  2.62it/s]


Avg Valid Loss: 0.28875789046287537


Epoch: 4


100%|██████████| 38/38 [00:14<00:00,  2.53it/s]


Total loss: 0.20269134640693665


100%|██████████| 10/10 [00:03<00:00,  2.67it/s]


Avg Valid Loss: 0.27000173926353455


Epoch: 5


100%|██████████| 38/38 [00:15<00:00,  2.52it/s]


Total loss: 0.28099292516708374


100%|██████████| 10/10 [00:03<00:00,  2.62it/s]


Avg Valid Loss: 0.27243462204933167


Epoch: 6


100%|██████████| 38/38 [00:15<00:00,  2.52it/s]


Total loss: 0.2975877821445465


100%|██████████| 10/10 [00:03<00:00,  2.62it/s]


Avg Valid Loss: 0.2585557699203491


Epoch: 7


100%|██████████| 38/38 [00:15<00:00,  2.52it/s]


Total loss: 0.1763504296541214


100%|██████████| 10/10 [00:03<00:00,  2.63it/s]


Avg Valid Loss: 0.250700443983078


Epoch: 8


100%|██████████| 38/38 [00:15<00:00,  2.52it/s]


Total loss: 0.24043013155460358


100%|██████████| 10/10 [00:03<00:00,  2.64it/s]


Avg Valid Loss: 0.2495812475681305


Epoch: 9


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.19931788742542267


100%|██████████| 10/10 [00:03<00:00,  2.63it/s]


Avg Valid Loss: 0.2349766045808792


Epoch: 10


100%|██████████| 38/38 [00:15<00:00,  2.49it/s]


Total loss: 0.19346323609352112


100%|██████████| 10/10 [00:03<00:00,  2.65it/s]


Avg Valid Loss: 0.23915250599384308


Epoch: 11


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.21433904767036438


100%|██████████| 10/10 [00:03<00:00,  2.63it/s]


Avg Valid Loss: 0.23462621867656708


Epoch: 12


100%|██████████| 38/38 [00:14<00:00,  2.55it/s]


Total loss: 0.22139377892017365


100%|██████████| 10/10 [00:03<00:00,  2.68it/s]


Avg Valid Loss: 0.22361384332180023


Epoch: 13


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.18270261585712433


100%|██████████| 10/10 [00:03<00:00,  2.66it/s]


Avg Valid Loss: 0.22755010426044464


Epoch: 14


100%|██████████| 38/38 [00:15<00:00,  2.49it/s]


Total loss: 0.14953789114952087


100%|██████████| 10/10 [00:03<00:00,  2.63it/s]


Avg Valid Loss: 0.2213132232427597


Epoch: 15


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.17518439888954163


100%|██████████| 10/10 [00:03<00:00,  2.65it/s]


Avg Valid Loss: 0.2222278118133545


Epoch: 16


100%|██████████| 38/38 [00:15<00:00,  2.50it/s]


Total loss: 0.15723823010921478


100%|██████████| 10/10 [00:03<00:00,  2.61it/s]


Avg Valid Loss: 0.22091026604175568


Epoch: 17


100%|██████████| 38/38 [00:15<00:00,  2.50it/s]


Total loss: 0.12685523927211761


100%|██████████| 10/10 [00:03<00:00,  2.62it/s]


Avg Valid Loss: 0.22350111603736877


Epoch: 18


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.1683207005262375


100%|██████████| 10/10 [00:03<00:00,  2.60it/s]


Avg Valid Loss: 0.21880309283733368


Epoch: 19


100%|██████████| 38/38 [00:15<00:00,  2.50it/s]


Total loss: 0.19649185240268707


100%|██████████| 10/10 [00:03<00:00,  2.61it/s]


Avg Valid Loss: 0.21665377914905548


Epoch: 20


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.21359780430793762


100%|██████████| 10/10 [00:03<00:00,  2.63it/s]


Avg Valid Loss: 0.21475115418434143


Epoch: 21


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.15560394525527954


100%|██████████| 10/10 [00:03<00:00,  2.63it/s]


Avg Valid Loss: 0.21510379016399384


Epoch: 22


100%|██████████| 38/38 [00:14<00:00,  2.59it/s]


Total loss: 0.16793131828308105


100%|██████████| 10/10 [00:03<00:00,  2.82it/s]


Avg Valid Loss: 0.21316173672676086


Epoch: 23


100%|██████████| 38/38 [00:15<00:00,  2.52it/s]


Total loss: 0.14015989005565643


100%|██████████| 10/10 [00:03<00:00,  2.65it/s]


Avg Valid Loss: 0.20668378472328186


Epoch: 24


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.14142665266990662


100%|██████████| 10/10 [00:03<00:00,  2.63it/s]


Avg Valid Loss: 0.20569513738155365


Epoch: 25


100%|██████████| 38/38 [00:15<00:00,  2.50it/s]


Total loss: 0.18324725329875946


100%|██████████| 10/10 [00:03<00:00,  2.62it/s]


Avg Valid Loss: 0.2056959718465805


Epoch: 26


100%|██████████| 38/38 [00:15<00:00,  2.52it/s]


Total loss: 0.15472009778022766


100%|██████████| 10/10 [00:03<00:00,  2.63it/s]


Avg Valid Loss: 0.20643262565135956


Epoch: 27


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.13715757429599762


100%|██████████| 10/10 [00:03<00:00,  2.65it/s]


Avg Valid Loss: 0.20111969113349915


Epoch: 28


100%|██████████| 38/38 [00:15<00:00,  2.52it/s]


Total loss: 0.1166934221982956


100%|██████████| 10/10 [00:03<00:00,  2.64it/s]


Avg Valid Loss: 0.20828230679035187


Epoch: 29


100%|██████████| 38/38 [00:15<00:00,  2.51it/s]


Total loss: 0.16076403856277466


100%|██████████| 10/10 [00:03<00:00,  2.65it/s]

Avg Valid Loss: 0.20253515243530273







# Prediction

In [None]:
# Fetch the demo clip from Google Drive; it is read back below as ./test_video.mp4.
!gdown https://drive.google.com/uc?id=1VZdpRVDPCjh3Ro9KbJtb9CN2lakHnUy_

Downloading...
From: https://drive.google.com/uc?id=1VZdpRVDPCjh3Ro9KbJtb9CN2lakHnUy_
To: /content/test_video.mp4
  0% 0.00/7.15M [00:00<?, ?B/s] 66% 4.72M/7.15M [00:00<00:00, 38.7MB/s]100% 7.15M/7.15M [00:00<00:00, 54.6MB/s]


In [None]:
import cv2
def predict_single_frame(frame, score_threshold=0.9, mask_threshold=0.87):
    """Run the global ``model`` on one frame and draw its confident detections.

    The frame is resized to ``IMAGE_SIZE``, scaled by 1/255 and passed to
    ``model`` on ``DEVICE``. Every detection whose score exceeds
    ``score_threshold`` gets a box and a percentage label drawn on the
    frame, and its mask (pixels above ``mask_threshold``) is filled with a
    random colour on an overlay that is blended in at 20 % opacity.

    Args:
        frame: H x W x 3 image array as produced by ``cv2.VideoCapture.read``.
        score_threshold: Minimum detection score to draw (default 0.9).
        mask_threshold: Mask value above which a pixel is coloured
            (default 0.87).

    Returns:
        uint8 array of shape (H', W', 3) at the resized resolution, in the
        same channel order as ``frame`` (no colour conversion is applied).
    """
    # `interpolation` must be passed by keyword: the third positional
    # parameter of cv2.resize is `dst`, not the interpolation flag.
    resized = cv2.resize(frame, tuple(IMAGE_SIZE),
                         interpolation=cv2.INTER_LINEAR) / 255
    batch = torch.as_tensor(resized, dtype=torch.float32).unsqueeze(0)
    batch = batch.permute(0, 3, 1, 2)  # (1, H, W, C) -> (1, C, H, W)
    images = [image.to(DEVICE) for image in batch]

    with torch.no_grad():
        pred = model(images)[0]

    # Back to (H, W, C). Make the buffer C-contiguous so the OpenCV drawing
    # calls below can write into it in place.
    im = np.ascontiguousarray(
        images[0].permute(1, 2, 0).cpu().numpy(), dtype=np.float32)
    im2 = np.zeros_like(im)

    masks = pred['masks'].cpu().numpy()
    scores = pred['scores'].cpu().numpy()
    boxes = pred['boxes'].cpu().numpy()

    for msk, scr, box in zip(masks, scores, boxes):
        if scr > score_threshold:
            cv2.rectangle(im, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0,0,1), 2)
            cv2.putText(im, "{0:.2f}%".format(scr*100), (int(box[0]+5), int(box[1])+15), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0,0,1), 2, cv2.LINE_AA)
            # One random colour per instance; msk has a leading channel axis.
            im2[msk[0] > mask_threshold] = np.random.uniform(0, 1, size=3)

    return (cv2.addWeighted(im, 0.8, im2, 0.2,0)*255).astype(np.uint8)

In [None]:
cap = cv2.VideoCapture('./test_video.mp4')
# Switch to inference mode so the model returns detections.
model.train(False)

if not cap.isOpened():
    print("Error opening video stream or file")

# Annotated frames, collected in playback order.
images = []
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # End of stream (or a read failure): stop collecting.
            break
        result_frame = predict_single_frame(frame)
        images.append(result_frame)
finally:
    # Release the capture even if prediction raises part-way through.
    cap.release()

In [None]:
import imageio
# The frames come from cv2.VideoCapture and are never colour-converted, so
# they are in OpenCV's BGR order, while imageio interprets arrays as RGB.
# Convert on the way out so the GIF's colours are not swapped.
imageio.mimsave('./result.gif',
                [cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) for frame_bgr in images])

In [None]:
from google.colab import drive
# Mount Google Drive so the result can be stored outside the Colab VM.
drive.mount('/content/drive')
# The frames come from cv2.VideoCapture and are never colour-converted, so
# they are in OpenCV's BGR order, while imageio interprets arrays as RGB.
# Convert on the way out so the GIF's colours are not swapped.
imageio.mimsave('/content/drive/MyDrive/result.gif',
                [cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) for frame_bgr in images])

Mounted at /content/drive
