In [9]:
from data import RooftopDataset
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torch
from torch import optim

In [4]:
# Root directory handed to RooftopDataset as `root` (a relative path).
data_root_path = 'images'

In [5]:
# Sanity check: build the dataset in grayscale mode with no transforms and
# inspect the first sample.
dataset = RooftopDataset(root=data_root_path, grayscale=True, transforms=None)
# Indexing the dataset yields a pair that is unpacked as (image, mask).
img, mask = dataset[0]
# Note the order: the mask size is printed first, then the image size.
print(mask.size(), img.size())

torch.Size([1, 256, 256]) torch.Size([1, 256, 256])


In [7]:
# load a Faster R-CNN (ResNet-50 FPN) model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# replace the classifier with a new one that has
# a user-defined number of classes
num_classes = 2  # 1 class (rooftop) + background
# get the number of input features of the existing box classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new, untrained one sized for num_classes
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [10]:
def collate_fn(batch):
    """Regroup a batch of samples field by field.

    ``batch`` is a sequence of samples, each itself a sequence of fields
    (for example ``(image, mask)``). The result is a tuple with one entry per
    field, each entry being a tuple holding that field from every sample.
    Nothing is stacked, so fields may differ in shape between samples.
    As with ``zip``, samples of unequal length are cut to the shortest one,
    and an empty batch gives an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Split the data into a training set and a two-sample held-out test set.
to_grayscale = True

# Build both datasets BEFORE reading len(dataset): the name is rebound to a
# Subset further down, so measuring it first would make a re-run of this cell
# permute the already-shrunk training subset and silently drop more samples.
dataset = RooftopDataset(root=data_root_path, grayscale=to_grayscale)

# Use the same colour mode for evaluation as for training.
dataset_test = RooftopDataset(
    root=data_root_path, grayscale=to_grayscale, transforms=None
)

# A private, seeded generator keeps the split identical across runs without
# disturbing the global RNG state used elsewhere.
split_seed = 0
indices = torch.randperm(
    len(dataset), generator=torch.Generator().manual_seed(split_seed)
).tolist()

# The last two shuffled indices are held out for testing; the rest train.
dataset = torch.utils.data.Subset(dataset, indices[:-2])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-2:])

data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=True, num_workers=4,
        collate_fn=collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=4,
        collate_fn=collate_fn
)

# Optimise only the parameters that are not frozen.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.001)

In [11]:
def train_one_epoch(model, optimizer, data_loader):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Module that is switched to training mode, called as
            ``model(images, masks)`` for every batch, and must return a
            mapping of loss name to scalar loss tensor.
        optimizer: Optimizer stepping the model's trainable parameters.
        data_loader: Iterable yielding ``(images, masks)`` batches.

    Returns:
        The mean over batches of the summed loss, as a Python float, so the
        caller can log progress; ``None`` if ``data_loader`` yields nothing.
    """
    model.train()

    running_loss = 0.0
    num_batches = 0
    for images, masks in data_loader:
        # NOTE(review): torchvision detection models document their targets
        # as a list of dicts with 'boxes' and 'labels'; confirm that what the
        # dataset yields as the second element has that structure.
        loss_dict = model(images, masks)

        # Total loss is the plain sum of every component the model reports.
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += losses.item()
        num_batches += 1

    return running_loss / num_batches if num_batches else None

In [None]:
# Train for 10 epochs, evaluating on the held-out loader after each one.
num_epochs = 10

for epoch in range(num_epochs):
    # train for one epoch
    train_one_epoch(model, optimizer, data_loader)
    # evaluate on the test dataset
    # NOTE(review): neither `evaluate` nor `device` is defined or imported in
    # the visible cells, so this line would raise NameError on a fresh
    # kernel; confirm where they are meant to come from.
    evaluate(model, data_loader_test, device=device)

print("That's it!")