In [49]:
from pycocotools.coco import COCO
import numpy as np
import random
import os
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from skimage.transform import resize
import segmentation_models_pytorch as smp
from sklearn.model_selection import train_test_split
%matplotlib inline

In [44]:
# Path (relative to the working directory) of the COCO train2017 instance-annotation JSON.
annFile = 'coco/annotations_trainval2017/annotations/instances_train2017.json'

In [45]:
# Parse the annotation file and build the in-memory lookup index (slow: see the timing printed below).
coco = COCO(annotation_file=annFile)

loading annotations into memory...
Done (t=74.29s)
creating index...
index created!


In [46]:
# Every category id present in the annotation file, with the matching metadata records.
catIDs = coco.getCatIds()
cats = coco.loadCats(ids=catIDs)

# Ids of all images that carry at least one annotation of category 44.
# NOTE(review): 44 is presumably the "bottle" class - confirm with coco.loadCats(44).
img_Ids = coco.getImgIds(catIds=[44])

In [48]:
# Number of images selected for the chosen category.
len(img_Ids)

8501

In [25]:
from torchvision import datasets, models, transforms
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

In [52]:
class Dataset(torch.utils.data.Dataset):
    """Image / mask pairs for a single COCO category.

    Each sample is a normalized RGB image tensor together with a mask that
    marks every annotated instance of ``cat_id`` in that image. Both are
    resized to ``size``.

    Args:
        img_id: Sequence of COCO image ids served by this dataset.
        coco: Loaded COCO index object that holds the annotations.
        img_dir: Directory containing the image files named in the annotations.
        cat_id: COCO category id whose instances are drawn into the mask.
        size: ``(height, width)`` that images and masks are resized to.
            NOTE(review): some segmentation_models_pytorch releases require
            both sides to be divisible by the encoder output stride (16);
            300 is not - confirm against the installed version.
        augment: When True, apply a light colour jitter to the image. Pass
            False for evaluation data so metrics are deterministic.
    """

    def __init__(self, img_id, coco, img_dir='coco/train2017', cat_id=44,
                 size=(300, 300), augment=True):
        """Store the sample ids and build the image preprocessing pipeline."""
        self.img_id = img_id
        self.coco = coco
        self.img_dir = img_dir
        self.cat_id = cat_id
        self.size = tuple(size)

        steps = [transforms.Resize(self.size)]
        if augment:
            steps.append(transforms.ColorJitter(hue=.05, saturation=.05))
        steps += [
            transforms.ToTensor(),
            # ImageNet channel mean / std, matching the pretrained encoder weights.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.transform_img = transforms.Compose(steps)

    def __len__(self):
        """Return the number of samples (one per image id)."""
        return len(self.img_id)

    def _build_mask(self, img, anns):
        """Return the union of all annotation masks as a float32 (height, width) array.

        An empty ``anns`` list yields an all-zero mask instead of raising.
        """
        mask = np.zeros((img['height'], img['width']), dtype=np.float32)
        for ann in anns:
            # Element-wise maximum == logical OR for {0, 1} masks.
            mask = np.maximum(mask, self.coco.annToMask(ann))
        return mask

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            tuple: ``(image, mask)`` where ``image`` is a float tensor of shape
            ``(3, H, W)`` and ``mask`` is a float32 tensor of shape ``(H, W)``,
            with ``(H, W) == size``.
        """
        img = self.coco.loadImgs(self.img_id[index])[0]
        ann_ids = self.coco.getAnnIds(imgIds=img['id'], catIds=self.cat_id)
        anns = self.coco.loadAnns(ann_ids)
        img_file = img['file_name']

        # Force three channels: grayscale files would otherwise break the
        # 3-channel Normalize step. The context manager closes the file handle.
        with Image.open(f'{self.img_dir}/{img_file}') as pil_img:
            image = self.transform_img(pil_img.convert('RGB'))

        mask = resize(self._build_mask(img, anns), self.size)
        # float32 so the target dtype matches the network output in the loss.
        mask = torch.from_numpy(mask.astype(np.float32))
        return image, mask


In [57]:
# Hold out 15% of the image ids for validation. A fixed random_state makes the
# split identical on every "Restart & Run All", so runs are comparable.
train_imgIds, val_imgIds = train_test_split(img_Ids, test_size=0.15, random_state=42)

In [58]:
# Wrap each id split in a Dataset backed by the shared COCO index.
train_dataset = Dataset(img_id=train_imgIds, coco=coco)
val_dataset = Dataset(img_id=val_imgIds, coco=coco)

In [59]:
# One worker process per CPU core; os.cpu_count() may return None, in which
# case fall back to loading in the main process.
num_workers = os.cpu_count() or 0

# Training batches are reshuffled every epoch; the incomplete last batch is dropped.
train_loader = DataLoader(train_dataset, batch_size=6, shuffle=True,
                          num_workers=num_workers, drop_last=True)
# Validation keeps a fixed order and evaluates every sample, including the final partial batch.
val_loader = DataLoader(val_dataset, batch_size=6, shuffle=False,
                        num_workers=num_workers, drop_last=False)

In [60]:
# Sanity check: show the image and mask batch shapes of the first training batch only.
for x, y in train_loader:
    print(x.shape, y.shape, sep="\n")
    break

torch.Size([6, 3, 300, 300])
torch.Size([6, 300, 300])


In [61]:
# Sanity check: show the image and mask batch shapes of the first validation batch only.
for x, y in val_loader:
    print(x.shape, y.shape, sep="\n")
    break

torch.Size([6, 3, 300, 300])
torch.Size([6, 300, 300])


In [62]:
class SegmentationNet(nn.Module):
    """Thin wrapper around a DeepLabV3+ model with an EfficientNet-B3 encoder.

    The wrapped model takes 3-channel input and is configured with a single
    output class; the encoder starts from ImageNet weights.
    """

    def __init__(self):
        super().__init__()
        self.segnet = smp.DeepLabV3Plus(
            encoder_name="efficientnet-b3",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
        )

    def forward(self, x):
        """Run ``x`` through the wrapped segmentation model and return its output."""
        return self.segnet(x)

In [67]:
# The model must live on the same device as the batches fed to it; use the GPU
# when one is present. Moving it here, before any optimizer is built, keeps the
# optimizer pointed at the final parameter tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = SegmentationNet().to(device)

In [68]:
# Training hyper-parameters: number of passes over the training set, a
# mean-squared-error objective between predicted and target masks, and an
# Adam optimizer over every parameter of `net`.
EPOCHS = 10
criterion = nn.MSELoss(reduction="mean")
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3)

In [None]:
# Train for EPOCHS epochs, validate after each one, and checkpoint the weights
# whenever the summed validation loss improves.

# Model and batches must share a device. Module.to() moves parameters in place,
# so this is a no-op if the model is already there and the existing optimizer
# keeps referencing the same parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

best_loss = float("inf")
for epoch in range(EPOCHS):
    # ---- training pass ----
    net.train()
    train_loss = 0.0
    print("***Train***")
    for image, mask in train_loader:
        output_mask = net(image.to(device))
        # The network emits (B, 1, H, W) while masks are (B, H, W); drop the
        # channel axis so the loss does not silently broadcast to (B, B, H, W).
        if output_mask.dim() == mask.dim() + 1:
            output_mask = output_mask.squeeze(1)
        # Match the target dtype to the prediction.
        mask = mask.to(device=device, dtype=output_mask.dtype)
        loss = criterion(output_mask, mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() detaches the scalar so each batch's autograd graph can be freed.
        train_loss += loss.item()
    print(epoch, train_loss)

    # ---- validation pass ----
    net.eval()
    val_loss = 0.0
    print("***Validation***")
    with torch.no_grad():
        for image, mask in val_loader:
            output_mask = net(image.to(device))
            if output_mask.dim() == mask.dim() + 1:
                output_mask = output_mask.squeeze(1)
            mask = mask.to(device=device, dtype=output_mask.dtype)
            val_loss += criterion(output_mask, mask).item()
    print(epoch, val_loss)

    # Keep only the best checkpoint; remember the new best so later, worse
    # epochs cannot overwrite it.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(net.state_dict(), "model.pth")