In [1]:
import torch
import torchvision
import lvis
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from constants import *

## Dataset


In [2]:
from lvis import LVIS


class FilteredLVIS(LVIS):
    """
    Lighter version of LVIS which drops any unwanted data, filters over specific category ids,
    keeping only annotations of these categories and images containing these annotations.
    For RAM efficiency
    """

    def __init__(self, annotation_path, cat_names):
        super().__init__(annotation_path)
        cat_ids = self._get_cat_ids(cat_names)
        self.trim(cat_ids)

    def trim(self, cat_ids):
        cat_ids = set(cat_ids)
        ann_ids = set(self.get_ann_ids(cat_ids=cat_ids))
        img_ids = set(self._get_img_ids(cat_ids))
        self.anns = {k: v for k, v in self.anns.items() if k in ann_ids}
        self.cats = {k: v for k, v in self.cats.items() if k in cat_ids}
        self.imgs = {k: v for k, v in self.imgs.items() if k in img_ids}
        self.img_ann_map = {
            iid: [ann for ann in anns if ann['id'] in ann_ids]
            for iid, anns in self.img_ann_map.items()
            if iid in img_ids
        }
        self.cat_img_map = {
            cid: [iid for iid in imgs if iid in img_ids]
            for cid, imgs in self.cat_img_map.items()
            if cid in cat_ids
        }
        self.dataset = {}
        self.dataset['annotations'] = list(self.anns.values())
        self.dataset['images'] = list(self.imgs.values())
        self.dataset['categories'] = list(self.cats.values())
        
        self._create_index()

    def _get_img_ids(self, cat_ids):
        return list({
            iid for cat_id in cat_ids
            for iid in self.cat_img_map[cat_id]
        })

    def _get_cat_ids(self, names):
        cat_ids = []
        cat_names = []
        for id, cat in self.cats.items():
            if cat['name'] in names:
                cat_ids.append(id)
                cat_names.append(cat['name'])
        print(f'category found for {[name for name in names if name in cat_names]}')
        print(f'category NOT found for {[name for name in names if name not in cat_names]}\n')
        return list(set(cat_ids))

In [3]:
from torch.utils.data import Dataset
from torchvision.io import decode_image
from torchvision import tv_tensors
from torchvision.ops import box_convert
from pycocotools import mask as maskUtils
from typing import List
import sys


# with help of https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html


class LVISDataset(Dataset):
    def __init__(self, coco2017_dirs: List[str], lvis_gt: lvis.LVIS | FilteredLVIS, transforms=None,
                 cat_ids=None) -> None:
        self.images_dirs = coco2017_dirs
        self.lvis_gt = lvis_gt
        self.transforms = transforms
        self.img_ids = self.lvis_gt.get_img_ids() if cat_ids is None else self._get_img_ids(cat_ids)
        self.cat_ids = self.lvis_gt.get_cat_ids() if cat_ids is None else cat_ids
        self.cat_id_to_label = {cat_id: i + 1 for i, cat_id in enumerate(self.cat_ids)}

    def _get_img_ids(self, cat_ids):
        return list({
            iid for cat_id in cat_ids
            for iid in self.lvis_gt.cat_img_map[cat_id]
        })

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx: int):
        ### get image
        img_id = self.img_ids[idx]
        img = None
        for images_dir in self.images_dirs:  # to handle LVIS val dataset using both coco train and val
            try:
                image_path = os.path.join(images_dir, f'{str(img_id).zfill(12)}.jpg')
                img = decode_image(image_path).float() / 255.0
                break
            except Exception as e:
                # j'ai rajouté le print de l'erreur mais en fait au moment de l'eval ça print quand il trouve pas dans le premier dirdc pê changer
                print(e) 
                pass
        if img is None:
            print(f"image not found")
            sys.exit(1)
        img = tv_tensors.Image(img)
        _, h, w = img.shape

        ### get annotationss
        annot_ids = self.lvis_gt.get_ann_ids(img_ids=[img_id])
        annots = self.lvis_gt.load_anns(annot_ids)
        annots = [annot for annot in annots if annot['category_id'] in self.cat_ids]
        # labels
        labels = torch.tensor([self.cat_id_to_label[annot['category_id']] for annot in annots])
        # area
        areas = torch.tensor([annot['area'] for annot in annots])
        # boxes
        boxes = torch.tensor([annot['bbox'] for annot in annots], dtype=torch.float32)
        boxes_xyxy = box_convert(boxes, in_fmt='xywh', out_fmt='xyxy')
        boxe_tv = tv_tensors.BoundingBoxes(boxes_xyxy, format='XYXY', canvas_size=(h, w))
        # masks
        # segmentations = [annot['segmentation'] for annot in annots]
        # masks = []
        # for seg in segmentations:
        #     rle = maskUtils.merge(maskUtils.frPyObjects(seg, h, w))
        #     mask = maskUtils.decode(rle)
        #     masks.append(torch.from_numpy(mask))
        # mask_tv = tv_tensors.Mask(torch.stack(masks))

        #JE pense plus safe d'utiliser directement la fonction créer par le package LVIS (Victor)
        masks = [self.lvis_gt.ann_to_mask(ann) for ann in annots]
        masks_np = np.stack(masks, axis=0)  # shape: (N, H, W)

        mask_tv = tv_tensors.Mask(torch.from_numpy(masks_np))

        target = {}
        target['image_id'] = img_id
        target['labels'] = labels
        target['area'] = areas
        target['boxes'] = boxe_tv
        target['masks'] = mask_tv
        # print("IMG:", img.shape)
        # print("BOXES:", target["boxes"].shape)
        # print("MASKS:", target["masks"].shape)

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

In [4]:
print(COCO2017_TRAIN_PATH)
print(COCO2017_VAL_PATH)
print(COCO2017_TEST_PATH)



../coco2017/train2017/
../coco2017/val2017/
../coco2017/test2017/


## Model

In [5]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


# with help of https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html


# TODO check (sample code)
def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )

    return model


## Utils

#### Data

In [6]:
from torchvision.transforms import v2 as T
from typing import Tuple


def get_transform(img_size: Tuple[int, int] | None = None):
    transforms = []
    if img_size is not None:
        transforms.append(T.Resize(img_size))
    transforms.append(T.ToDtype(torch.float32, scale=True))
    transforms.append(T.ToPureTensor())
    return T.Compose(transforms)


def custom_collate_fn(batch):
    """
    Custom collate function to prevent stacking of images with different shapes
    """
    return tuple(zip(*batch))


def get_cat_ids(cats, names):
    """
    Get categories ids with specific names
    """
    cat_ids = []
    cat_names = []
    for id, cat in cats.items():
        if cat['name'] in names:
            cat_ids.append(id)
            cat_names.append(cat['name'])
    print(f'category found for {[name for name in names if name in cat_names]}')
    print(f'category NOT found for {[name for name in names if name not in cat_names]}\n')
    return cat_ids

#### Plot


In [7]:
def makeImagesPlotReady(msk, shape):
    return (np.rot90(np.rot90(np.rot90((msk.reshape(shape[0], shape[1], shape[2])).T))))


def show_image(image):
    _image = image if image.ndim == 3 else image[0]
    _, w, h = _image.size()
    plt.figure(figsize=(5 * (w / h), 5 * (h / w)))
    _image = makeImagesPlotReady(_image, _image.size())
    plt.imshow(_image)
    plt.axis('off')
    plt.show()


def show_image_list(images):
    n_images = len(images)
    images = [makeImagesPlotReady(img, img.size()) for img in images]
    _, axs = plt.subplots(1, n_images, figsize=(3 * n_images, 3))
    axs = axs.flatten()
    for i in range(n_images):
        axs[i].imshow(images[i])
        axs[i].axis('off')
    plt.show()

#### Training

In [8]:
import math


def train_one_epoch(model, optimizer, data_loader, device, print_freq=1):
    print("debug 30")
    model.train()
    print("debug 31")
    for i, (images, targets) in tqdm(enumerate(data_loader)):
        optimizer.zero_grad()
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if i % print_freq == 0:
            print(loss_value)

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict)
            sys.exit(1)
    print("debug 32")

    return loss_value


from lvis import LVISResults
from lvis import LVISEval
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# TODO not working properly
@torch.inference_mode()
def evaluate(model, data_loader, lvis_gt, device):
    print("debug 20")
    model.eval()
    outputs = []
    print("debug 21")
    global_idx = 0
    for images, targets in tqdm(data_loader):
        images = list(image.to(device) for image in images)
        output = model(images)
        
        output = [{k: v.to(torch.device(device)) for k, v in t.items()} for t in output]

        for i, t in enumerate(output):
            t["image_id"] = int(targets[i]["image_id"])
            t["id"] = global_idx
            global_idx += 1
        outputs.extend(output)
        # outputs["category_id"]=int(outputs.pop("labels"))
    print("debug 22")
    print(f"type output {type(outputs)}")
    print(f"len outputs {len(outputs)}")
    print(f"len output[0]{len(outputs[0])}")
    print(f"outputs[0].keys() {outputs[0].keys()}")
    print("test 1")
    

    coco_gt = COCO()
    coco_gt.dataset = {
    "info": {},
    "images": lvis_gt.dataset['images'],
    "annotations": lvis_gt.dataset['annotations'],
    "categories": lvis_gt.dataset['categories']
}
    coco_dt = coco_gt.loadRes(outputs)
    
    # évaluation
    # coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')  # ou 'segm' si tu veux les masques
    # coco_eval.evaluate()
    # coco_eval.accumulate()
    # coco_eval.summarize()
    # 
    lvis_dt = LVISResults(lvis_gt, outputs)
    print("test 2")
    lvis_eval = LVISEval(lvis_gt, lvis_dt)
    print("debug 23")
    lvis_eval.run()
    lvis_eval.print_results()

## Pipeline

In [9]:
# Data args
CATEGORIES = ['cat', 'dog']  #, 'cow', 'pigeon', 'giraffe', 'bear']

#We absolutely need a resize or use a btach of one
IMG_SIZE = (256, 256)  # if None: no resize

# Learning args
BATCH_SIZE = 5
EPOCHS = 1
DEVICE = 'cpu'

In [10]:
from lvis import LVIS

lvis_gt_train = FilteredLVIS(TRAIN_ANNOT_PATH, CATEGORIES)
lvis_gt_val = FilteredLVIS(VAL_ANNOT_PATH, CATEGORIES)

category found for ['cat', 'dog']
category NOT found for []

category found for ['cat', 'dog']
category NOT found for []



In [11]:
from torch.utils.data import DataLoader, Subset

cat_ids = get_cat_ids(lvis_gt_train.cats, CATEGORIES)
num_classes = len(cat_ids) + 1  # +1 for background

train_dataset = LVISDataset([COCO2017_TRAIN_PATH], lvis_gt_train, get_transform(IMG_SIZE), cat_ids=cat_ids)
subset_train_size = list(range(10))
subset_train_dataset = Subset(train_dataset, subset_train_size)

train_loader = DataLoader(subset_train_dataset,
                          batch_size=BATCH_SIZE,
                          collate_fn=custom_collate_fn,
                          shuffle=True,
                          # num_workers=4
                          )
print(f"Size of train dataset: {len(train_dataset)}")
print(f"Size of subset train dataset: {len(subset_train_dataset)}")

val_dataset = LVISDataset([COCO2017_VAL_PATH, COCO2017_TRAIN_PATH], lvis_gt_val, get_transform(IMG_SIZE),
                          cat_ids=cat_ids)
subset_val_size = list(range(15))
subset_val_dataset = Subset(val_dataset, subset_val_size)
val_loader = DataLoader(subset_val_dataset,
                        batch_size=BATCH_SIZE,
                        collate_fn=custom_collate_fn,
                        # num_workers=4
                        )
print(f"Size of validation dataset: {len(val_dataset)}\n")
print(f"Size of subset validation dataset: {len(subset_val_dataset)}\n")

category found for ['cat', 'dog']
category NOT found for []

Size of train dataset: 3819
Size of subset train dataset: 10
Size of validation dataset: 745

Size of subset validation dataset: 15



In [12]:
device = torch.device(DEVICE)

model = get_model_instance_segmentation(num_classes)
print("debug 7")
model.to(device)
print("debug 8")

# TODO check (sample code)
params = [p for p in model.parameters() if p.requires_grad]
print("debug 9")
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)
print("debug 10")
# TODO check (sample code)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)
print("debug 11")
# TODO
for epoch in range(EPOCHS):
    loss_value = train_one_epoch(model, optimizer, train_loader, device)

    # lr_scheduler.step()
    print(loss_value)


debug 7
debug 8
debug 9
debug 10
debug 11
debug 30
debug 31


1it [00:49, 49.03s/it]

4.283474445343018


2it [01:27, 43.61s/it]

4.417148113250732
debug 32
4.417148113250732





In [13]:

        
    for i in lvis_gt_val.dataset['annotations'] :
        print(i.keys())
     

dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])
dict_keys(['area', 'id', 'segmentation', 'image_id', 'bbox', 'category_id'])

In [14]:
evaluate(model, val_loader, lvis_gt_val, device)

debug 20
debug 21


  0%|          | 0/3 [00:00<?, ?it/s]

[Errno 2] No such file or directory: '../coco2017/val2017/000000055296.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000073729.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000047105.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000129026.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000413704.jpg'


 33%|███▎      | 1/3 [00:12<00:25, 12.68s/it]

[Errno 2] No such file or directory: '../coco2017/val2017/000000417802.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000229387.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000272396.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000452619.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000421903.jpg'


 67%|██████▋   | 2/3 [00:23<00:11, 11.59s/it]

[Errno 2] No such file or directory: '../coco2017/val2017/000000026645.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000495637.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000028698.jpg'
[Errno 2] No such file or directory: '../coco2017/val2017/000000276507.jpg'


100%|██████████| 3/3 [00:35<00:00, 11.68s/it]


debug 22
type output <class 'list'>
len outputs 15
len output[0]6
outputs[0].keys() dict_keys(['boxes', 'labels', 'scores', 'masks', 'image_id', 'id'])
test 1
Loading and preparing results...


AssertionError: Results do not correspond to current coco set

# --- tests ---

In [None]:
images, targets = next(iter(train_loader))
show_image_list(images[:5])

In [None]:
from lvis import LVIS, LVISVis

lvis_gt = LVIS(VAL_ANNOT_PATH)

lvis_vis = LVISVis(lvis_gt, img_dir=COCO2017_VAL_PATH)

lvis_vis.vis_img(285)

In [None]:
# model training output
model = get_model_instance_segmentation(num_classes)
print("debug 1")
images, targets = next(iter(train_loader))
print("debug 2")
images = list(image for image in images)
print("debug 3")
targets = [{k: v for k, v in t.items()} for t in targets]
print("debug 4")
output = model(images, targets)  # Returns losses and detections
print("debug 5")
print(output)

In [None]:
# model inference output
@torch.inference_mode()
def test(model) :
    model.eval()
    print("debug 1")
    x = [torch.rand(3, 300, 400).to(device), torch.rand(3, 500, 400).to(device)]
    print("debug 2")
    predictions = model(x)  # Returns predictions
    print("debug 3")
    print(predictions[0])

test(model)

In [None]:
# evaluate test

from torch.utils.data import Subset

device = torch.device(DEVICE)
val_subset = Subset(val_dataset, indices=range(10))
val_subset_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, collate_fn=custom_collate_fn)
model = get_model_instance_segmentation(num_classes)
model.to(device)
for images, targets in val_loader:
    print(type(targets), targets)
    break
    
print("collate_fn =", val_loader.collate_fn)



evaluate(model, val_loader, lvis_gt_val, device)

In [None]:
import cv2

img = cv2.imread("../coco2017/train2017/000000055296.jpg")
if img is None : 
    print("erreur imgage pas found")
cv2.imshow("test", img)
cv2.waitKey(0)
cv2.destroyAllWindows()


In [None]:
import json

with open("data/lvis_v1_val.json", "r") as f:
    data = json.load(f)


category_fields = list(data["categories"][0].keys())

print(category_fields)
