In [1]:
import torch
import torchvision
import lvis
import requests
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from constants import *

## Dataset


In [2]:
from lvis import LVIS

class FilteredLVIS(LVIS):
    """
    Lighter version of LVIS restricted to a set of category names.

    After the annotation file has been loaded by ``LVIS.__init__``, every index is
    trimmed so that only the requested categories, the annotations of these categories
    and the images containing these annotations are kept. Everything else is dropped
    for RAM efficiency.

    Args:
        annotation_path: annotation file path, forwarded to ``LVIS.__init__``.
        cat_names: collection of category names to keep.
    """

    def __init__(self, annotation_path, cat_names):
        super().__init__(annotation_path)
        cat_ids = self._get_cat_ids(cat_names)
        self.trim(cat_ids)

    def trim(self, cat_ids):
        """
        Restrict ``anns``, ``cats``, ``imgs``, ``img_ann_map``, ``cat_img_map`` and
        ``dataset`` to the given category ids.

        Args:
            cat_ids: iterable of category ids to keep.
        """
        cat_ids = set(cat_ids)
        ann_ids = set(self.get_ann_ids(cat_ids=cat_ids))
        img_ids = set(self._get_img_ids(cat_ids))
        self.anns = {k: v for k, v in self.anns.items() if k in ann_ids}
        self.cats = {k: v for k, v in self.cats.items() if k in cat_ids}
        self.imgs = {k: v for k, v in self.imgs.items() if k in img_ids}
        self.img_ann_map = {
            iid: [ann for ann in anns if ann['id'] in ann_ids]
            for iid, anns in self.img_ann_map.items()
            if iid in img_ids
        }
        self.cat_img_map = {
            cid: [iid for iid in imgs if iid in img_ids]
            for cid, imgs in self.cat_img_map.items()
            if cid in cat_ids
        }
        # Store the records themselves: iterating a dict yields its keys, which would
        # leave bare ids in `dataset` where consumers expect annotation/image/category dicts.
        self.dataset = {
            'annotations': list(self.anns.values()),
            'images': list(self.imgs.values()),
            'categories': list(self.cats.values()),
        }

    def _get_img_ids(self, cat_ids):
        """Return the de-duplicated ids of the images mapped to any of ``cat_ids``."""
        return list({
            iid for cat_id in cat_ids
            for iid in self.cat_img_map[cat_id]
        })

    def _get_cat_ids(self, names):
        """
        Return the ids of the categories whose name is in ``names``.

        Prints which of the requested names were found and which were not.
        """
        cat_ids = []
        cat_names = []
        for cat_id, cat in self.cats.items():
            if cat['name'] in names:
                cat_ids.append(cat_id)
                cat_names.append(cat['name'])
        print(f'category found for {[name for name in names if name in cat_names]}')
        print(f'category NOT found for {[name for name in names if name not in cat_names]}\n')
        return list(set(cat_ids))

In [3]:
from torch.utils.data import Dataset
from torchvision.io import decode_image
from torchvision import tv_tensors
from torchvision.ops import box_convert
from pycocotools import mask as maskUtils
from typing import List
from PIL import Image
from io import BytesIO

# with help of https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html


class LVISDataset(Dataset):
    """
    Torch dataset yielding ``(image, target)`` pairs built from LVIS annotations.

    Images are read from the given directories when all of them exist, otherwise they
    are downloaded from the ``coco_url`` stored with each image record.

    Args:
        coco2017_dirs: directories searched, in order, for ``<12-digit image id>.jpg``.
        lvis_gt: loaded LVIS annotations.
        transforms: optional callable applied as ``transforms(img, target)``.
        cat_ids: category ids to keep; ``None`` keeps every category and every image
            known to ``lvis_gt``.

    Each target is a dict with keys ``image_id``, ``labels``, ``area``, ``boxes`` (XYXY)
    and ``masks``. Labels are 1-based positions in ``cat_ids``, leaving 0 unused.
    """

    def __init__(self, coco2017_dirs: List[str], lvis_gt: lvis.LVIS|FilteredLVIS, transforms=None, cat_ids=None) -> None:
        self.images_dirs = coco2017_dirs
        self.lvis_gt = lvis_gt
        self.transforms = transforms
        self.img_ids = self.lvis_gt.get_img_ids() if cat_ids is None else self._get_img_ids(cat_ids)
        self.cat_ids = self.lvis_gt.get_cat_ids() if cat_ids is None else cat_ids
        self.cat_id_to_label = {cat_id: i+1 for i, cat_id in enumerate(self.cat_ids)}
        # Pick the image source once, so __getitem__ does not re-check the filesystem.
        if all(os.path.isdir(images_dir) for images_dir in self.images_dirs):
            self._get_image = self._get_image_from_file
            print("will load images from files")
        else:
            self._get_image = self._get_image_from_url
            print("will load images from urls")

    def _get_img_ids(self, cat_ids):
        """Return the de-duplicated ids of the images mapped to any of ``cat_ids``."""
        return list({
            iid for cat_id in cat_ids
            for iid in self.lvis_gt.cat_img_map[cat_id]
        })

    def _get_image_from_file(self, id):
        """
        Decode the image from the first configured directory that contains it.

        Raises:
            FileNotFoundError: if no directory contains the image.
        """
        image_paths = [os.path.join(images_dir, f'{str(id).zfill(12)}.jpg') for images_dir in self.images_dirs]
        for image_path in image_paths:
            if os.path.isfile(image_path):
                # Force 3 channels so file loading agrees with the URL path (`.convert("RGB")`).
                return decode_image(image_path, mode=torchvision.io.ImageReadMode.RGB)
        # Raise instead of exiting the interpreter so the caller sees a regular error.
        raise FileNotFoundError(f"image not found: {image_paths}")

    def _get_image_from_url(self, id):
        """Download the image from its ``coco_url`` and return it as an RGB PIL image."""
        url = self.lvis_gt.imgs[id]['coco_url']
        response = requests.get(url, timeout=30)
        # Fail on HTTP errors rather than trying to decode an error page as an image.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx: int):
        ### get image
        img_id = self.img_ids[idx]
        img = tv_tensors.Image(self._get_image(img_id))
        _, h, w = img.shape

        ### get annotations
        annot_ids = self.lvis_gt.get_ann_ids(img_ids=[img_id])
        annots = [
            annot for annot in self.lvis_gt.load_anns(annot_ids)
            if annot['category_id'] in self.cat_id_to_label  # dict lookup instead of list scan
        ]
        # labels (explicit dtype: an empty list would otherwise become a float tensor)
        labels = torch.tensor([self.cat_id_to_label[annot['category_id']] for annot in annots], dtype=torch.int64)
        # area
        areas = torch.tensor([annot['area'] for annot in annots], dtype=torch.float32)
        # boxes (reshape keeps a (0, 4) shape when there is no annotation)
        boxes = torch.tensor([annot['bbox'] for annot in annots], dtype=torch.float32).reshape(-1, 4)
        boxes_xyxy = box_convert(boxes, in_fmt='xywh', out_fmt='xyxy')
        boxes_tv = tv_tensors.BoundingBoxes(boxes_xyxy, format='XYXY', canvas_size=(h, w))
        # masks: one binary (h, w) mask per annotation
        masks = []
        for annot in annots:
            rle = maskUtils.merge(maskUtils.frPyObjects(annot['segmentation'], h, w))
            masks.append(torch.from_numpy(maskUtils.decode(rle)))
        if masks:
            stacked_masks = torch.stack(masks)
        else:
            # torch.stack rejects an empty list
            stacked_masks = torch.zeros((0, h, w), dtype=torch.uint8)
        masks_tv = tv_tensors.Mask(stacked_masks)

        target = {
            'image_id': img_id,
            'labels': labels,
            'area': areas,
            'boxes': boxes_tv,
            'masks': masks_tv,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

## Model

In [4]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# with help of https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html


# TODO check (sample code)
def get_model_instance_segmentation(num_classes, hidden_layer=256, weights="DEFAULT"):
    """
    Build a torchvision Mask R-CNN (ResNet-50 FPN) whose box and mask heads are
    replaced by fresh predictors sized for ``num_classes``.

    Args:
        num_classes: number of output classes of both new heads (callers in this
            notebook pass the number of categories + 1 for the background).
        hidden_layer: second argument given to ``MaskRCNNPredictor``.
        weights: forwarded to ``maskrcnn_resnet50_fpn``; the default loads the
            pre-trained weights, as before.

    Returns:
        The model with both predictors replaced.
    """
    # load an instance segmentation model (pre-trained with the default weights)
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=weights)

    # replace the box head, keeping the input width of the original classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # replace the mask head, keeping the input channels of the original mask predictor
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )

    return model

## Utils

#### Data

In [5]:
from torchvision.transforms import v2 as T
from typing import Tuple


def get_transform(img_size: Tuple[int, int]|None = None):
    """
    Compose the preprocessing pipeline applied to every sample.

    Args:
        img_size: target size handed to ``T.Resize``; ``None`` skips resizing.

    Returns:
        ``T.Compose`` of the optional resize, a scaled float32 conversion and
        ``T.ToPureTensor``.
    """
    resize_steps = [] if img_size is None else [T.Resize(img_size)]
    return T.Compose([
        *resize_steps,
        T.ToDtype(torch.float32, scale=True),
        T.ToPureTensor(),
    ])


def custom_collate_fn(batch):
    """
    Collate samples by transposing the batch instead of stacking it.

    A batch of ``(image, target)`` samples becomes ``(images, targets)``, each a tuple,
    so images with different shapes never have to be stacked.
    """
    return tuple(field for field in zip(*batch))


def get_cat_ids(cats, names):
    """
    Look up the ids of the categories whose name appears in ``names``.

    Args:
        cats: mapping of category id to a category dict with a ``'name'`` key.
        names: requested category names.

    Returns:
        The matching ids, in the iteration order of ``cats``. Also prints which of
        the requested names were found and which were not.
    """
    matches = [(cat_id, cat['name']) for cat_id, cat in cats.items() if cat['name'] in names]
    matched_ids = [cat_id for cat_id, _ in matches]
    matched_names = [cat_name for _, cat_name in matches]

    found = [name for name in names if name in matched_names]
    missing = [name for name in names if name not in matched_names]
    print(f'category found for {found}')
    print(f'category NOT found for {missing}\n')
    return matched_ids

#### Plot


In [6]:
def makeImagesPlotReady(msk, shape):
    """
    Convert a channel-first image into the channel-last layout used by ``plt.imshow``.

    Args:
        msk: image data holding ``shape[0] * shape[1] * shape[2]`` values, laid out
            as (channels, height, width).
        shape: the (channels, height, width) sizes of ``msk``.

    Returns:
        A numpy array of shape (height, width, channels).
    """
    channels_first = np.asarray(msk).reshape(shape[0], shape[1], shape[2])
    # A plain axis permutation: a full transpose followed by three 90-degree rotations
    # also ends up channel-last, but mirrors the image left-right.
    return np.transpose(channels_first, (1, 2, 0))


def show_image(image):
    """
    Display a single image without axes.

    Args:
        image: 3-D (channels, height, width) tensor; for any other rank the first
            element along dimension 0 is shown.
    """
    _image = image if image.ndim == 3 else image[0]
    # size() is (channels, height, width); figsize is (width, height), so a wide
    # image must get a wide figure.
    _, h, w = _image.size()
    plt.figure(figsize=(5 * (w / h), 5))
    _image = makeImagesPlotReady(_image, _image.size())
    plt.imshow(_image)
    plt.axis('off')
    plt.show()


def show_image_list(images):
    """
    Display the given images side by side on a single row, without axes.

    Args:
        images: sequence of (channels, height, width) tensors. Nothing is drawn
            when it is empty.
    """
    n_images = len(images)
    if n_images == 0:
        # plt.subplots cannot create a grid with zero columns
        return
    plot_ready = [makeImagesPlotReady(img, img.size()) for img in images]
    # squeeze=False always returns a 2-D array of axes, including for a single image
    # (where plt.subplots would otherwise return a bare Axes without .flatten()).
    _, axs = plt.subplots(1, n_images, figsize=(3*n_images, 3), squeeze=False)
    for ax, img in zip(axs.flatten(), plot_ready):
        ax.imshow(img)
        ax.axis('off')
    plt.show()

#### Training

In [7]:
import math

def train_one_epoch(model, optmizer, data_loader, device, print_freq=1):
    """
    Train ``model`` for one pass over ``data_loader``.

    Args:
        model: detection model returning a dict of losses when called as
            ``model(images, targets)`` in training mode.
        optmizer: optimizer updating the model parameters.
        data_loader: iterable of ``(images, targets)`` batches.
        device: device the batches are moved to.
        print_freq: the loss is printed every ``print_freq`` iterations.

    Returns:
        The summed loss of the last batch as a float (``nan`` if the loader is empty).
        Exits via ``sys.exit(1)`` when a non-finite loss is encountered.
    """
    model.train()
    loss_value = float("nan")  # defined even when the loader yields no batch

    # Wrap the loader (not the enumerate) so tqdm can read its length.
    for i, (images, targets) in enumerate(tqdm(data_loader)):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        # Stop before the update so a non-finite loss never reaches the weights.
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict)
            sys.exit(1)

        # Clear the gradients of the previous iteration; backward() accumulates into .grad.
        optmizer.zero_grad()
        losses.backward()
        optmizer.step()

        if i % print_freq == 0:
            print(loss_value)

    return loss_value


from lvis import LVISResults
from lvis import LVISEval

def _find_label_to_cat_id(data_loader):
    """Search the loader's (possibly wrapped) dataset for ``cat_id_to_label`` and invert it."""
    dataset = getattr(data_loader, 'dataset', None)
    while dataset is not None:
        cat_id_to_label = getattr(dataset, 'cat_id_to_label', None)
        if cat_id_to_label is not None:
            return {label: cat_id for cat_id, label in cat_id_to_label.items()}
        # wrappers such as Subset keep the wrapped dataset in `.dataset`
        dataset = getattr(dataset, 'dataset', None)
    return None


@torch.inference_mode()
def evaluate(model, data_loader, lvis_gt, device, label_to_cat_id=None, mask_threshold=0.5):
    """
    Run ``model`` over ``data_loader`` and score its detections with ``LVISEval``.

    The per-image model outputs (``boxes``, ``labels``, ``scores``, ``masks``) are
    flattened into one result dict per detection with the keys ``image_id``,
    ``category_id``, ``bbox`` (xywh), ``score`` and ``segmentation`` (RLE).

    Args:
        model: detection model returning, in eval mode, one dict per image.
        data_loader: iterable of ``(images, targets)`` batches; each target provides
            ``'image_id'``.
        lvis_gt: ground-truth annotations passed to ``LVISResults`` / ``LVISEval``.
        device: device the images are moved to.
        label_to_cat_id: mapping from model label to category id. When ``None`` it is
            derived from a ``cat_id_to_label`` attribute found on the loader's dataset;
            if none is found, labels are used as category ids unchanged.
        mask_threshold: predicted mask values at or above this are foreground.

    NOTE(review): detections are expressed in the coordinates of the images fed to
    the model; if the dataset transforms resize images, they will not line up with
    the ground truth — confirm before evaluating with a resize.
    """
    model.eval()
    if label_to_cat_id is None:
        label_to_cat_id = _find_label_to_cat_id(data_loader)

    results = []
    for images, targets in tqdm(data_loader):
        images = [image.to(device) for image in images]
        predictions = model(images)
        for prediction, target in zip(predictions, targets):
            prediction = {k: v.to(torch.device('cpu')) for k, v in prediction.items()}
            image_id = int(target['image_id'])
            boxes_xywh = box_convert(prediction['boxes'], in_fmt='xyxy', out_fmt='xywh').tolist()
            scores = prediction['scores'].tolist()
            labels = prediction['labels'].tolist()
            # masks are (N, 1, H, W) scores; binarize and drop the singleton channel
            binary_masks = (prediction['masks'][:, 0] >= mask_threshold).numpy().astype(np.uint8)
            for box, score, label, binary_mask in zip(boxes_xywh, scores, labels, binary_masks):
                results.append({
                    'image_id': image_id,
                    'category_id': label if label_to_cat_id is None else label_to_cat_id[label],
                    'bbox': box,
                    'score': score,
                    # pycocotools encodes Fortran-ordered uint8 arrays
                    'segmentation': maskUtils.encode(np.asfortranarray(binary_mask)),
                })

    if not results:
        # LVISResults inspects the first result, so it cannot take an empty list
        print("no detections produced, skipping LVIS evaluation")
        return

    lvis_dt = LVISResults(lvis_gt, results)
    lvis_eval = LVISEval(lvis_gt, lvis_dt)
    lvis_eval.run()
    lvis_eval.print_results()

## Pipeline

In [8]:
# Data args
# Category names used to filter the annotations and to derive the number of classes.
CATEGORIES = ['cat', 'dog'] #, 'cow', 'pigeon', 'giraffe', 'bear']
# Size passed to get_transform for the resize step.
IMG_SIZE = None     # if None: no resize

# Learning args
# Batch size of both the train and the validation DataLoader.
BATCH_SIZE = 5
# Number of passes over the training loader.
EPOCHS = 2
# Name handed to torch.device for the model and the batches.
DEVICE = 'cpu'

In [9]:
from lvis import LVIS

# Load both annotation splits, each trimmed down to CATEGORIES.
lvis_gt_train = FilteredLVIS(annotation_path=TRAIN_ANNOT_PATH, cat_names=CATEGORIES)
lvis_gt_val = FilteredLVIS(annotation_path=VAL_ANNOT_PATH, cat_names=CATEGORIES)

category found for ['cat', 'dog']
category NOT found for []

category found for ['cat', 'dog']
category NOT found for []



In [10]:
from torch.utils.data import DataLoader, Subset

# Category ids are resolved once on the train split and handed to both datasets,
# so train and validation build their labels from the same id list.
cat_ids = get_cat_ids(lvis_gt_train.cats, CATEGORIES)
num_classes = 1 + len(cat_ids)  # one extra class for the background

train_dataset = LVISDataset(
    [COCO2017_TRAIN_PATH],
    lvis_gt_train,
    get_transform(IMG_SIZE),
    cat_ids=cat_ids,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=custom_collate_fn,
    shuffle=True,
)
print(f"Size of train dataset: {len(train_dataset)}")

# Validation images are searched in both the train and the val image directories.
val_dataset = LVISDataset(
    [COCO2017_TRAIN_PATH, COCO2017_VAL_PATH],
    lvis_gt_val,
    get_transform(IMG_SIZE),
    cat_ids=cat_ids,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=custom_collate_fn,
)
print(f"Size of validation dataset: {len(val_dataset)}\n")

category found for ['cat', 'dog']
category NOT found for []

will load images from files
Size of train dataset: 3819
will load images from files
Size of validation dataset: 745



In [None]:
device = torch.device(DEVICE)
# Module.to returns the module itself, so building and moving can be chained.
model = get_model_instance_segmentation(num_classes).to(device)

# TODO check (sample code): SGD over the trainable parameters only
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# TODO check (sample code): multiply the learning rate by 0.1 every 3 scheduler steps
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# TODO: evaluation on the validation split is still disabled, see
# evaluate(model, val_loader, lvis_gt_val, device)
for epoch in range(EPOCHS):
    loss_value = train_one_epoch(model, optimizer, train_loader, device)
    lr_scheduler.step()
    print(loss_value)


0it [00:00, ?it/s]

# --- tests ---

In [None]:
# Take the first batch produced by a fresh iterator over the training loader
# and display at most five of its images.
images, targets = next(iter(train_loader))
show_image_list(images[:5])

In [None]:
from lvis import LVIS, LVISVis

# Load the validation annotations with the plain LVIS class (no category filtering).
lvis_gt = LVIS(VAL_ANNOT_PATH)

# Visualizer reading the images from the COCO 2017 validation directory.
lvis_vis = LVISVis(lvis_gt, img_dir = COCO2017_VAL_PATH)

# NOTE(review): 285 is presumably an image id of the validation split — confirm.
lvis_vis.vis_img(285)

In [None]:
# model training output
# A freshly built model has not been switched to eval mode, so calling it with
# targets yields the dict of losses that train_one_epoch sums up.
model = get_model_instance_segmentation(num_classes)
images, targets = next(iter(train_loader))
images = list(images)
targets = [dict(t) for t in targets]
output = model(images, targets)
print(output)

In [None]:
# model inference output
# Switch the model built above to eval mode and feed it two random 3-channel
# images of different sizes; only images are passed, no targets.
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)  # Returns predictions
# Show the prediction for the first of the two images.
print(predictions[0])

In [None]:
# evaluate test

from torch.utils.data import Subset

device = torch.device(DEVICE)

# Evaluate on the first ten validation samples only, to keep the check fast.
val_subset = Subset(val_dataset, indices=range(10))
val_subset_loader = DataLoader(
    val_subset,
    batch_size=BATCH_SIZE,
    collate_fn=custom_collate_fn,
)

# Fresh model with newly created heads (no training from this notebook applied).
model = get_model_instance_segmentation(num_classes).to(device)

evaluate(model, val_subset_loader, lvis_gt_val, device)