In [None]:
import os
import torch
import typing
from typing import List
import pandas as pd
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
import cv2
import torchvision
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from __future__ import print_function

from collections import defaultdict, deque
import datetime
import pickle
import time
import torch.distributed as dist
import errno

from torch.utils.tensorboard import SummaryWriter

from torchmetrics.detection.map import MeanAveragePrecision

from pprint import pprint

from tqdm import tqdm

from src.dataset import FoodDataset
from src.vis import read_image, show_image_coco

%matplotlib inline

In [None]:
# Single seed shared by every RNG this notebook relies on.
RANDOM_SEED = 42

# Seed NumPy *and* PyTorch: the shuffling DataLoader and the freshly
# initialised predictor heads draw from torch's generator, so seeding NumPy
# alone does not make a run repeatable.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Locations of the training images and their COCO-format annotation file,
# relative to the notebook's working directory.
TRAIN_IMAGES_PATH = 'data/public_training_set_release_2.0/images/'
TRAIN_LABELS = 'data/public_training_set_release_2.0/annotations.json'

In [None]:
# Parse the COCO-format annotation file once; the cells below query this index.
labels = COCO(annotation_file=TRAIN_LABELS)

In [None]:
# dir(labels)

In [None]:
# Collect every image id known to the annotation index.
img_ids = labels.getImgIds()
#184135  (NOTE(review): presumably an example image id kept for reference - confirm)
# Display the annotation records attached to the second image id in the list.
labels.imgToAnns[img_ids[1]]

In [None]:
# Number of category ids present in the annotation file.
len(labels.getCatIds())

In [None]:
# Visualise the second image with the project helper from src.vis.
# NOTE(review): the meaning of the trailing `True` flag is not visible from
# this notebook - check show_image_coco's signature.
show_image_coco(img_ids[1], TRAIN_IMAGES_PATH, labels, True)

# Dataset

In [None]:
# a = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
# Display the model's module tree.
# NOTE(review): in the visible notebook `model_ft` is first assigned in a later
# cell ("create mask rcnn model"), so on a fresh top-to-bottom run this cell
# raises NameError; it only works after that cell has been executed.
model_ft

In [None]:
# Sanity-check inference on the first few training samples.
# NOTE(review): `model_ft` and `train_ds` are assigned in later cells of this
# notebook, so this cell only works once those cells have been executed.
N_EVAL_SAMPLES = 10

# Run on CPU in inference mode (the training cell moves the model back).
model_ft.cpu().eval()

# Index the dataset once per sample; each sample is an (image, target) pair.
raw_val = [train_ds[i] for i in range(N_EVAL_SAMPLES)]
trgt = [sample[1] for sample in raw_val]

# Feed the images exactly as the dataset yields them, i.e. the same
# representation the training loop passes to the model. The previous x255
# rescale made inference inputs differ from training inputs (torchvision's
# detection models document a 0-1 input range).
# NOTE(review): assumes FoodDataset already yields float image tensors in the
# range the model was trained on - confirm against src.dataset.
im_val = [sample[0] for sample in raw_val]

# No gradients are needed for inference; skipping autograd saves memory.
with torch.no_grad():
    res = model_ft(im_val)
pprint(res)

In [None]:
# COCO-style mean average precision over the predictions from the previous cell.
# [1, 10, 100] are maximum-detections-per-image cut-offs, so they belong in
# `max_detection_thresholds`. `rec_thresholds` are recall levels in [0, 1];
# leaving it as None selects the library default.
metr = MeanAveragePrecision(
    box_format='xyxy',
    iou_thresholds=None,
    rec_thresholds=None,
    max_detection_thresholds=[1, 10, 100],
    class_metrics=False,
)

# `res` holds the model predictions and `trgt` the matching ground-truth
# dicts, both produced by the inference cell.
metr.update(res, trgt)
pprint(metr.compute())

In [None]:
# create mask rcnn model
# NOTE(review): 498 is presumably the number of food categories (plus
# background, if the dataset reserves label 0) - compare with
# len(labels.getCatIds()) and the label mapping in FoodDataset.
num_classes = 498
# Prefer the first GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Start from the pretrained Mask R-CNN and swap both prediction heads for
# fresh ones sized to `num_classes`.
model_ft = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
in_features = model_ft.roi_heads.box_predictor.cls_score.in_features
model_ft.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
in_features_mask = model_ft.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256  # channel width of the new mask head's hidden conv layer
model_ft.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
model_ft = model_ft.to(device)

# Train everything except the backbone: enable gradients globally first,
# then freeze the backbone parameters again.
for param in model_ft.parameters():
    param.requires_grad = True
for param in model_ft.backbone.parameters():
    param.requires_grad = False

In [None]:
def _collate_batch(batch):
    """Transpose a list of per-sample tuples into one tuple per field."""
    return tuple(zip(*batch))


# Training dataset built from the image folder and its annotation file.
train_ds = FoodDataset(TRAIN_IMAGES_PATH, TRAIN_LABELS)

# Shuffled mini-batches; samples are kept as tuples rather than stacked.
data_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=6,
    shuffle=True,
    num_workers=8,
    collate_fn=_collate_batch,
)

# Optimise only the parameters that are still trainable.
params = list(filter(lambda p: p.requires_grad, model_ft.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.0005,
    momentum=0.9,
    weight_decay=0.0005,
)
# Step-wise decay: multiply the learning rate by 0.1 every 5 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.1)

In [None]:
# TensorBoard writer for the training-loss curve (default log directory).
writer = SummaryWriter()

num_epochs = 1
# Move the model to the same `device` the batches are sent to below, rather
# than a hard-coded .cuda(), so model and data cannot end up on different devices.
model_ft.to(device)
# NOTE(review): `lr_scheduler` is never stepped in the visible part of this
# loop; if no later code does it, add `lr_scheduler.step()` once per epoch.
for epoch in range(num_epochs):
    model_ft.train()

    for i_iter, (images, targets) in enumerate(tqdm(data_loader)):
        # Move the batch to the training device.
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In train mode the model returns a dict of individual loss terms.
        loss_dict = model_ft(images, targets)
        losses = sum(loss_dict.values())

        # Use a step that keeps increasing across epochs; the bare batch
        # index restarts at 0 every epoch and would overwrite earlier points.
        global_step = epoch * len(data_loader) + i_iter
        writer.add_scalar('Loss/train', losses.item(), global_step)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()