In [1]:
import os
import torch
import torchvision
from torchvision import transforms
from engine import train_one_epoch, evaluate
import utils
from coco import CocoSubset

root = 'E:/Resource/Dataset/COCO/SubCOCO'
annDir = os.path.join(root, 'annotations/instances_{}.json')
# coco = COCO(annDir.format('train2017'))

loading annotations into memory...
Done (t=1.85s)
creating index...
index created!
loading annotations into memory...
Done (t=0.07s)
creating index...
index created!
Amount of train images:
Dataset CocoSubset
    Number of datapoints: 19759
    Root location: E:/Resource/Dataset/COCO/SubCOCO\train2017
Amount of validation images:
Dataset CocoSubset
    Number of datapoints: 870
    Root location: E:/Resource/Dataset/COCO/SubCOCO\val2017
19759
<class 'coco.CocoSubset'>
<PIL.Image.Image image mode=RGB size=480x640 at 0x26036A33108> {'labels': tensor([3, 8]), 'masks': tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 

## Creating Datasets

In [2]:
img_transform = {
    'train': transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(0.5)
    ]),
    'val':transforms.Compose([
        transforms.ToTensor()
    ])}

target_transform = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(0.5)
    ]),
    'val':transforms.Compose([])}

coco_train = CocoSubset(os.path.join(root, 'train2017'),
                        annDir.format('train2017'),
                        img_transform=img_transform['train'],
                        target_transform=target_transform['train'])

coco_val = CocoSubset(os.path.join(root, 'val2017'),
                      annDir.format('val2017'),
                      img_transform=img_transform['val'],
                      target_transform=target_transform['val'])
print('Amount of train images:')
print(coco_train)
print('Amount of validation images:')
print(coco_val)

loading annotations into memory...
Done (t=2.04s)
creating index...
index created!
loading annotations into memory...
Done (t=0.08s)
creating index...
index created!
Amount of train images:
Dataset CocoSubset
    Number of datapoints: 19759
    Root location: E:/Resource/Dataset/COCO/SubCOCO\train2017
Amount of validation images:
Dataset CocoSubset
    Number of datapoints: 870
    Root location: E:/Resource/Dataset/COCO/SubCOCO\val2017


In [3]:
print(len(coco_train))
print(type(coco_train))
print(coco_train[0])

19759
<class 'coco.CocoSubset'>
(tensor([[[1.0000, 1.0000, 1.0000,  ..., 0.5412, 0.5176, 0.5216],
         [1.0000, 1.0000, 1.0000,  ..., 0.5294, 0.5176, 0.5451],
         [1.0000, 1.0000, 1.0000,  ..., 0.5216, 0.5098, 0.5373],
         ...,
         [0.5490, 0.5686, 0.6549,  ..., 0.4157, 0.4902, 0.4196],
         [0.5647, 0.6471, 0.7020,  ..., 0.4275, 0.4196, 0.4392],
         [0.6314, 0.6824, 0.6157,  ..., 0.4078, 0.4549, 0.4784]],

        [[1.0000, 1.0000, 1.0000,  ..., 0.5176, 0.5333, 0.4941],
         [1.0000, 1.0000, 1.0000,  ..., 0.5059, 0.5216, 0.5098],
         [1.0000, 1.0000, 1.0000,  ..., 0.5059, 0.5098, 0.5020],
         ...,
         [0.5059, 0.5647, 0.6039,  ..., 0.3569, 0.4235, 0.2902],
         [0.5098, 0.6392, 0.6510,  ..., 0.3529, 0.4196, 0.2980],
         [0.6078, 0.6745, 0.5843,  ..., 0.3490, 0.4667, 0.3059]],

        [[1.0000, 1.0000, 1.0000,  ..., 0.5255, 0.5373, 0.5216],
         [1.0000, 1.0000, 1.0000,  ..., 0.5137, 0.5294, 0.5294],
         [1.0000, 1.0000,

In [4]:
data_loader_train = torch.utils.data.DataLoader(coco_train,
                                         batch_size=4,
                                         shuffle=True,
                                         num_workers=1,
                                         collate_fn=utils.collate_fn)
data_loader_val = torch.utils.data.DataLoader(coco_val,
                                              batch_size=4,
                                              shuffle=False,
                                              num_workers=1,
                                              collate_fn=utils.collate_fn)

In [5]:
data_loader_train

<torch.utils.data.dataloader.DataLoader at 0x2604b686248>

## Defining Model

In [6]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
torch.manual_seed(1)

def get_instance_segmentation_model(num_classes):
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                      hidden_layer,
                                                      num_classes)
    return model

In [9]:
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device = torch.device('cpu')

num_classes = 5

model = get_instance_segmentation_model(num_classes)
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, 
                            lr=0.005,
                            momentum=0.9,
                            weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                              step_size=3,
                                              gamma=0.1)

In [10]:
num_epochs = 1

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=10)
    lr_scheduler.step()
    evaluate(model, data_loader_val, device=device)

IndexError: Target 8 is out of bounds.