In [None]:
import os
import torch
import torchvision
from torchvision import transforms
from engine import train_one_epoch, evaluate
import utils
from coco import CocoSubset

# Dataset location. Defaults to the original author's local path but can be
# overridden with the COCO_ROOT environment variable so the notebook runs on
# other machines without editing code.
root = os.environ.get('COCO_ROOT', 'E:/Resource/Dataset/COCO/SubCOCO')
# Path *template* (not a directory) for the annotation files: fill in the split
# name, e.g. annDir.format('train2017') -> <root>/annotations/instances_train2017.json
annDir = os.path.join(root, 'annotations/instances_{}.json')

## Creating Datasets

In [None]:
# Per-split transforms handed to CocoSubset.
#
# NOTE: the original version put a separate `transforms.RandomHorizontalFlip(0.5)`
# in both the image pipeline and the target pipeline. Each instance draws its own
# random number, so the image and its annotations were flipped *independently*:
# roughly half of the training samples ended up with boxes/masks that no longer
# line up with the pixels. (Applying an image transform to the target is also
# questionable: the target's type depends on CocoSubset and is not visible here.)
# The pipelines are therefore kept deterministic.
# TODO: restore flip augmentation through a *joint* (image, target) transform,
# e.g. the torchvision detection-reference `transforms.RandomHorizontalFlip`,
# if CocoSubset can accept one.
img_transform = {
    'train': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
    ]),
}

# Identity pipelines (an empty Compose returns its input unchanged).
target_transform = {
    'train': transforms.Compose([]),
    'val': transforms.Compose([]),
}

coco_train = CocoSubset(os.path.join(root, 'train2017'),
                        annDir.format('train2017'),
                        img_transform=img_transform['train'],
                        target_transform=target_transform['train'])

coco_val = CocoSubset(os.path.join(root, 'val2017'),
                      annDir.format('val2017'),
                      img_transform=img_transform['val'],
                      target_transform=target_transform['val'])

# Report the number of samples explicitly; printing the dataset object itself
# shows its repr, which is not guaranteed to contain a count.
print('Amount of train images:', len(coco_train))
print('Amount of validation images:', len(coco_val))

In [None]:
# Sanity-check the training split before building loaders: how many samples it
# holds, which class implements it, and what a single sample looks like.
n_train_samples = len(coco_train)
print(n_train_samples)
print(type(coco_train))
sample_0 = coco_train[0]
print(sample_0)

In [None]:
# Settings shared by both loaders. Batches are assembled with the project's
# `utils.collate_fn` (presumably the detection-reference one, which keeps
# variable-size targets as tuples rather than stacking them; verify in utils.py).
loader_kwargs = dict(num_workers=1, collate_fn=utils.collate_fn)

# Training batches are shuffled; validation order is fixed.
data_loader_train = torch.utils.data.DataLoader(
    coco_train, batch_size=16, shuffle=True, **loader_kwargs)
data_loader_val = torch.utils.data.DataLoader(
    coco_val, batch_size=8, shuffle=False, **loader_kwargs)

In [None]:
# Number of batches per epoch for the (train, val) loaders. This is more
# informative than the DataLoader object's default repr.
len(data_loader_train), len(data_loader_val)

## Defining Model

In [None]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
# Seed torch's global RNG before the model is built and trained below, so runs
# are repeatable (e.g. the initial weights of the replacement heads).
# TODO(review): consider moving this to the top config cell so it visibly
# precedes every stochastic step in the notebook.
torch.manual_seed(1)

def get_instance_segmentation_model(num_classes, hidden_layer=256, pretrained=True):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads are resized to ``num_classes``.

    The torchvision Mask R-CNN is created, then its box predictor and mask
    predictor are replaced with fresh ones whose output size is ``num_classes``.
    The rest of the network (backbone, RPN, ...) is left as loaded.

    Args:
        num_classes: Number of output classes for the new heads. For torchvision
            detection models this count includes the background class.
        hidden_layer: Channels of the new mask predictor's hidden layer
            (default 256, the value previously hard-coded).
        pretrained: Forwarded to ``maskrcnn_resnet50_fpn``; True (the previous,
            hard-coded behaviour) loads pretrained weights.

    Returns:
        The modified Mask R-CNN model.
    """
    # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
    # it is kept here for compatibility with older torchvision releases.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=pretrained)

    # Replace the box classifier/regressor, keeping its input feature size.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask head, keeping its input channel count.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)
    return model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output classes for the replacement heads.
# NOTE(review): torchvision counts background as a class, so this presumably
# means 4 object categories + background; confirm against CocoSubset.
num_classes = 5

model = get_instance_segmentation_model(num_classes)
model.to(device)

# Optimise only the trainable parameters (those with requires_grad=True).
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Multiply the learning rate by gamma=0.1 every step_size=3 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
num_epochs = 1

# Per epoch: one training pass (logging every 10 iterations), one tick of the
# StepLR schedule, then a pass over the validation loader.
for epoch in range(num_epochs):
    train_one_epoch(
        model,
        optimizer,
        data_loader_train,
        device,
        epoch,
        print_freq=10,
    )
    lr_scheduler.step()
    evaluate(model, data_loader_val, device=device)

In [None]:
# TODO: inspect predictions on a few validation images and/or save a checkpoint
# (e.g. torch.save(model.state_dict(), ...)) here.