In [None]:
import os
import torch
import typing
from typing import List
import pandas as pd
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
import cv2
import torchvision
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from __future__ import print_function

from collections import defaultdict, deque
import datetime
import pickle
import time
import torch.distributed as dist
import errno

from torch.utils.tensorboard import SummaryWriter

from torchmetrics.detection.map import MeanAveragePrecision

from pprint import pprint

from tqdm import tqdm

from src.dataset import FoodDataset
from src.vis import read_image, show_image_coco

%matplotlib inline

In [None]:
# Seed torch and NumPy so that repeated runs draw the same random numbers.
RANDOM_SEED = 42

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Training split: image directory and annotation file (the latter is read with COCO()).
TRAIN_IMAGES_PATH = 'data/public_training_set_release_2.0/images/'
TRAIN_LABELS = 'data/public_training_set_release_2.0/annotations.json'

# Validation split, same layout as the training split.
VAL_IMAGES_PATH = 'data/public_validation_set_2.0/images/'
VAL_LABELS = 'data/public_validation_set_2.0/annotations.json'

# Use the GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Prefix prepended to checkpoint file names; '' saves into the working directory.
MODEL_SAVE_PATH = ''

In [None]:
def show_ind_img(ds: COCO, ind: int, ims_path: str) -> None:
    """Display the image found at position ``ind`` of ``ds.getImgIds()``.

    Args:
        ds: COCO annotation object the image id is looked up in.
        ind: Position inside the list returned by ``ds.getImgIds()``.
        ims_path: Directory passed to ``show_image_coco`` as the image location.

    Returns:
        Whatever ``show_image_coco`` returns.
    """
    chosen_id = ds.getImgIds()[ind]
    # The two trailing True flags are forwarded unchanged; their meaning is
    # defined by show_image_coco in src.vis.
    return show_image_coco(chosen_id, ims_path, ds, True, True)

def show_random_img(ds: COCO, ims_path: str) -> None:
    """Display one image of ``ds`` picked uniformly at random.

    The position is drawn with ``np.random.randint``, so the pick follows
    NumPy's global random state.

    Args:
        ds: COCO annotation object the image id is looked up in.
        ims_path: Directory passed to ``show_image_coco`` as the image location.

    Returns:
        Whatever ``show_image_coco`` returns.
    """
    all_ids = ds.getImgIds()
    picked_id = all_ids[np.random.randint(len(all_ids))]
    # The two trailing True flags are forwarded unchanged; their meaning is
    # defined by show_image_coco in src.vis.
    return show_image_coco(picked_id, ims_path, ds, True, True)

# Parse the training annotation file with pycocotools' COCO helper; the
# exploration cells below query this object.
ds_coco = COCO(TRAIN_LABELS)

In [None]:
# Look up every annotation attached to image id 131094 and display the records.
anns_obj = ds_coco.loadAnns(ids=ds_coco.getAnnIds(imgIds=131094))
anns_obj

In [None]:
# Exploration aid: list the attributes and methods exposed by the COCO object.
dir(ds_coco)

In [None]:
# This bare call is not the last expression of the cell, so its value is not shown.
ds_coco.getCatIds()
# Map every COCO category id to its 0-based position in ascending id order.
dict((coco_id, ind) for ind, coco_id in enumerate(sorted(ds_coco.getCatIds())))

In [None]:
# show_random_img(ds_coco, TRAIN_IMAGES_PATH)
# Display the image at position 1 of the training annotation set.
show_ind_img(ds=ds_coco, ind=1, ims_path=TRAIN_IMAGES_PATH)

In [None]:
# Wrap the training images and annotation file in the project's FoodDataset.
torch_ds = FoodDataset(TRAIN_IMAGES_PATH, TRAIN_LABELS)

# Display the dataset object's repr.
torch_ds

In [None]:
# show_ind_img requires the image directory as its third argument; without it
# the call raises TypeError.
show_ind_img(ds_coco, 8, TRAIN_IMAGES_PATH)

In [None]:
# Second element of dataset sample 8; the training loop further down treats
# this element of every sample as the target dict handed to the model.
torch_ds[8][1]

In [None]:
# NOTE(review): show_mask_bb is neither defined nor imported anywhere in the
# visible part of this notebook (only read_image and show_image_coco come from
# src.vis) — confirm where it lives, otherwise this cell raises NameError.
show_mask_bb(torch_ds, 22130)

In [None]:
def test_dataset(torch_ds):
    

def test_mAP(ds):
    mAP = MeanAveragePrecision(
                box_format='xyxy',
                iou_thresholds=None,
                rec_thresholds=[1, 10, 100],
                class_metrics=False,
                )

    metr.update(trgt, trgt)
    pprint(metr.compute())

In [None]:
# test_mAP takes the dataset to score; calling it with no argument raised TypeError.
test_mAP(torch_ds)

In [None]:
# a = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
# Show the model's preprocessing module. Calling it with no images, as this
# cell used to do, raises TypeError.
# NOTE(review): model_ft is created in the model-construction cell further
# down, so that cell has to run before this one.
model_ft.transform

In [None]:
# Run training sample 1000 through the model's preprocessing as a batch of one
# and display the image sizes recorded on the resulting image list.
z = model_ft.transform(train_ds[1000][0].unsqueeze(dim=0))[0]
z.image_sizes

In [None]:
# Shape of the first element (the image) of training sample 1000.
# NOTE(review): train_ds is only created in the "Dataset" cell further down,
# so on a fresh kernel this cell fails unless that cell has run first.
train_ds[1000][0].shape

In [None]:
# Replace the detector's built-in preprocessing with a fixed 700-pixel resize.
# image_mean / image_std take one entry per input channel. A single value is
# broadcast over all channels, which normalises green and blue with the red
# statistics. The pretrained backbone consumes 3-channel input, so the full
# ImageNet per-channel statistics are supplied.
grcnn = torchvision.models.detection.transform.GeneralizedRCNNTransform(
    min_size=700,
    max_size=700,
    image_mean=[0.485, 0.456, 0.406],
    image_std=[0.229, 0.224, 0.225],
)
model_ft.transform = grcnn

# Spot-check the resized image size on the first few samples only; walking the
# whole training set prints one line per image and loads every file.
N_SIZE_CHECKS = 10
for idx in range(min(N_SIZE_CHECKS, len(train_ds))):
    z = model_ft.transform([train_ds[idx][0]])[0]
    print(z.image_sizes)


In [None]:
# im_val is a plain Python list of tensors, so it has no .shape of its own;
# show the shape of each element instead.
# NOTE(review): im_val is built in the inference cell below — run that first.
[im.shape for im in im_val]

In [None]:
# Run the detector on the first ten dataset samples, on the CPU in eval mode.
model_ft.cpu().eval()

# Fetch every sample once and split it afterwards, instead of indexing the
# dataset separately for the image and for the target.
samples = [torch_ds[i] for i in range(10)]
raw_val = [sample[0] for sample in samples]
trgt = [sample[1] for sample in samples]
# First channel of every image, rescaled by 255.
im_val = [torch.mul(255, image[0]) for image in raw_val]

# Inference only: no_grad stops autograd from recording the forward pass, so
# the outputs carry no graph and memory use stays low.
with torch.no_grad():
    res = model_ft(raw_val)
pprint(res)

In [None]:
# Score the ground truth against itself as a sanity check of the metric.
# [1, 10, 100] are per-image detection caps, so they go in
# max_detection_thresholds; rec_thresholds expects recall values in [0, 1].
metr = MeanAveragePrecision(
    box_format='xyxy',
    iou_thresholds=None,
    max_detection_thresholds=[1, 10, 100],
    class_metrics=False,
)

# Predictions must carry a `scores` entry, which targets do not have, so give
# every ground-truth box a confidence of 1.0.
# Assumes every target dict has a `labels` tensor — TODO confirm against FoodDataset.
preds_from_gt = [
    dict(t, scores=torch.ones_like(t['labels'], dtype=torch.float32))
    for t in trgt
]

metr.update(preds_from_gt, trgt)
pprint(metr.compute())

In [None]:
# Build a pretrained Mask R-CNN (ResNet-50 FPN) and give it new prediction
# heads sized for this label set.
num_classes = 498

model_ft = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

# Box head: same input width as the stock classifier, num_classes outputs.
in_features = model_ft.roi_heads.box_predictor.cls_score.in_features
model_ft.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Mask head: same input channels as the stock one, 256 hidden channels.
hidden_layer = 256
in_features_mask = model_ft.roi_heads.mask_predictor.conv5_mask.in_channels
model_ft.roi_heads.mask_predictor = MaskRCNNPredictor(
    in_features_mask, hidden_layer, num_classes
)
model_ft.to(DEVICE)

# Freeze every parameter, then switch gradients back on for the RoI heads only.
model_ft.requires_grad_(False)
model_ft.roi_heads.requires_grad_(True)

# all without backbone 19792571

# Number of parameters that will actually be trained.
print('my', sum(p.numel() for p in model_ft.parameters() if p.requires_grad))

In [None]:
# Show which optimizer class is in use.
# NOTE(review): optimizer is only created in the "Dataset" cell below, so on a
# fresh kernel this cell fails unless that cell has run first.
type(optimizer)

# Dataset

In [None]:
def collate_detection_batch(batch):
    """Regroup a list of ``(image, target)`` samples into ``(images, targets)``.

    Samples are kept as tuples rather than stacked, so the images in one
    batch may differ in size.

    Args:
        batch: List of samples produced by the dataset.

    Returns:
        A tuple with one inner tuple per sample field.
    """
    return tuple(zip(*batch))


train_ds = FoodDataset(TRAIN_IMAGES_PATH, TRAIN_LABELS)
val_ds = FoodDataset(VAL_IMAGES_PATH, VAL_LABELS)

train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=2, shuffle=True, num_workers=6,
    collate_fn=collate_detection_batch)

# Validation batches are read in a fixed order so that partial evaluations
# (e.g. only the first few batches) are comparable from run to run.
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=2, shuffle=False, num_workers=6,
    collate_fn=collate_detection_batch)

# Optimise only the parameters left trainable by the model-construction cell.
params = [p for p in model_ft.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=0.0005, weight_decay=0.0005)
# Lowers the learning rate when the value passed to .step() stops decreasing.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

In [None]:
# NOTE(review): `wandb` does not appear in the notebook's import cell as
# visible here — add `import wandb` there before running this cell.
wandb.init(project="food", entity="alarnti")
# Record the hyper-parameters that are actually in effect, read from the
# optimizer and loader, instead of hand-typed values that had drifted away
# from them (0.001 / 16 versus the real 0.0005 / 2).
wandb.config = {
  "learning_rate": optimizer.param_groups[0]["lr"],
  "epochs": 100,
  "batch_size": train_loader.batch_size
}


In [None]:

def _batch_to_device(images, targets, device):
    """Move one collated batch (images plus a list of target dicts) onto ``device``."""
    images = [image.to(device) for image in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    return images, targets


val_score = float('inf')  # lowest mean validation loss seen so far
num_epochs = 100
# The model has to live on the same device the batches are moved to.
model_ft.to(DEVICE)
for epoch in range(num_epochs):
    model_ft.train()
    for images, targets in tqdm(train_loader):
        images, targets = _batch_to_device(images, targets, DEVICE)

        # In training mode the detector returns a dict of loss components.
        loss_dict = model_ft(images, targets)
        losses = sum(loss for loss in loss_dict.values())

#         wandb.log({**{key: l.item() for key, l in loss_dict.items()},
#                    "all_losses": losses.item()})

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

    # Validation pass. The model deliberately stays in train mode, because
    # torchvision detection models only return losses in that mode. no_grad
    # keeps autograd from recording graphs, and .item() turns each loss into
    # a float so nothing from the batch stays alive on the device.
    val_loss_sum = 0.0
    with torch.no_grad():
        for images, targets in tqdm(val_loader):
            images, targets = _batch_to_device(images, targets, DEVICE)

            loss_dict = model_ft(images, targets)
            val_loss_sum += sum(loss.item() for loss in loss_dict.values())

#             wandb.log({**{key + "_val": l.item() for key, l in loss_dict.items()}})

    mean_val_loss = val_loss_sum / len(val_loader)

#     wandb.log({'mean_val_loss': mean_val_loss})
    lr_scheduler.step(mean_val_loss)

    # Keep a checkpoint every time the validation loss improves. The epoch is
    # converted with str() (str + int raises TypeError) and the loss is written
    # as a plain 4-decimal number rather than a tensor repr.
    if mean_val_loss < val_score:
        torch.save(model_ft.state_dict(),
                   MODEL_SAVE_PATH + 'maskrcnn_' + str(epoch) + '_' + 'val_' + f'{mean_val_loss:.4f}')
        val_score = mean_val_loss
    


In [None]:
def do_validation(model, val_loader, max_batches: typing.Optional[int] = 100, verbose: bool = False):
    """Measure the mean loss and mAP of ``model`` on batches from ``val_loader``.

    Every batch is run twice under ``torch.no_grad()``: once in train mode,
    because torchvision detection models only return their loss dict in that
    mode, and once in eval mode to obtain the predictions fed to the metric.
    The model is left in eval mode on return.

    Args:
        model: Detection model called as ``model(images, targets)`` in train
            mode and ``model(images)`` in eval mode.
        val_loader: Iterable yielding ``(images, targets)`` batches.
        max_batches: Stop after this many batches; ``None`` evaluates all of
            them. A leftover unconditional ``break`` used to stop after the
            first batch and made the intended cap unreachable.
        verbose: When true, also pretty-print every batch's losses and
            predictions.

    Returns:
        dict with ``'mean_val_loss'`` (float, averaged over the batches that
        were actually evaluated; NaN when there were none) and ``'map'`` (the
        result of the metric's ``compute()``). The same dict is pretty-printed.
    """
    # [1, 10, 100] are per-image detection caps, so they go in
    # max_detection_thresholds; rec_thresholds expects recall values in [0, 1].
    metr = MeanAveragePrecision(
        box_format='xyxy',
        iou_thresholds=None,
        max_detection_thresholds=[1, 10, 100],
        class_metrics=False,
    )

    loss_sum = 0.0
    n_batches = 0
    with torch.no_grad():
        for i_iter, (images, targets) in enumerate(tqdm(val_loader)):
            if max_batches is not None and i_iter >= max_batches:
                break

            images = [image.to(DEVICE) for image in images]
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

            # Train mode: the forward pass yields the loss components.
            model.train()
            loss_dict = model(images, targets)
            loss_sum += sum(loss.item() for loss in loss_dict.values())
            n_batches += 1

            # Eval mode: the forward pass yields per-image predictions.
            model.eval()
            res = model(images)
            metr.update(res, targets)

            if verbose:
                pprint({key: l.item() for key, l in loss_dict.items()})
                pprint(res)

    # Average over the batches that were evaluated, not over the loader length.
    mean_val_loss = loss_sum / n_batches if n_batches else float('nan')

    report = {'mean_val_loss': mean_val_loss, 'map': metr.compute()}
    pprint(report)
    return report

In [None]:
# Deserialize the weights onto the CPU first, so the checkpoint opens even
# when the device it was saved from is not available, then move the model to
# the configured device instead of assuming CUDA.
model_ft.load_state_dict(
    torch.load("maskrcnn_9_val_tensor(0.7673, device='cuda_0')", map_location='cpu'))
model_ft.to(DEVICE)

In [None]:

# do_validation(model_ft, train_loader)

In [None]:
# Evaluate the loaded checkpoint. Note that the batches come from train_loader,
# not val_loader, so the reported numbers describe the training data.
do_validation(model_ft, train_loader)

In [None]:
# Run the loaded checkpoint on the first ten dataset samples, on the CPU in
# eval mode.
model_ft.cpu().eval()

# Fetch every sample once and split it afterwards, instead of indexing the
# dataset separately for the image and for the target.
samples = [torch_ds[i] for i in range(10)]
raw_val = [sample[0] for sample in samples]
trgt = [sample[1] for sample in samples]
# First channel of every image, rescaled by 255.
im_val = [torch.mul(255, image[0]) for image in raw_val]

# Inference only: no_grad stops autograd from recording the forward pass, so
# the outputs carry no graph and memory use stays low.
with torch.no_grad():
    res = model_ft(raw_val)
pprint(res)