## Todo

- [x] Reproduce the PyTorch Hub DeepLabV3 Colab example
- [x] Code for cutting the data (data prep)
- [x] Code for creating masks
- [x] Code for visualising masks
- [ ] Build dataset class
- [ ] Replace the DeepLab heads with new ones for our classes
- [ ] Trainer function
- [ ] Run on some small data

In [None]:
import glob
import os
from collections import defaultdict

In [None]:
# Clone the TorchVision repo to reuse the helper scripts from
# references/segmentation (utils, transforms, train, coco_utils).
# NOTE(review): this clones the default branch, while the model below is loaded
# from the v0.6.0 hub tag; the reference scripts' APIs (e.g. the signatures of
# train_one_epoch / evaluate) may differ between versions — consider checking
# out a matching tag.
!git clone https://github.com/pytorch/vision.git

!cp vision/references/segmentation/utils.py .
!cp vision/references/segmentation/transforms.py .
!cp vision/references/segmentation/train.py .
!cp vision/references/segmentation/coco_utils.py .

In [None]:
# Cython is installed first, presumably as a build dependency of pycocotools.
!pip install cython
# Install pycocotools from the official cocoapi repository (presumably needed
# by the copied coco_utils.py).

!pip3 install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'

## 1. Using the pretrained DeepLabV3 model as-is

In [None]:
# Automatically reload edited modules (e.g. the copied reference scripts)
# before each cell runs.
%load_ext autoreload
%autoreload 2

import torch
from PIL import Image


# DeepLabV3 with a ResNet-101 backbone and pretrained weights, pinned to the
# torchvision v0.6.0 tag via torch.hub.
model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
# Inference mode: BatchNorm uses running statistics, dropout is disabled.
model.eval()

In [None]:
import urllib.request

# Download a sample image from the PyTorch Hub repository for a quick test of
# the pretrained model.
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
# `urllib.URLopener` only existed in Python 2. In Python 3 the helper lives in
# `urllib.request`, which must be imported explicitly (`import urllib` alone
# does not guarantee the submodule is loaded). No bare `except:` — a failed
# download should surface as an error rather than be silently retried.
urllib.request.urlretrieve(url, filename)

In [None]:
# Uncomment to run the pretrained model on one of the coral images instead:
#filename = "../data/images_val/02_2017_0803_132452_045.jpg"

# `transforms` has not been imported at this point of the notebook; use
# torchvision's single-image transforms (the copied reference transforms.py
# works on (image, target) pairs, as CoralDataset uses it).
from torchvision import transforms

print("Opening")
# convert("RGB") guards against greyscale/RGBA files, which would not match
# the 3-channel normalisation below.
input_image = Image.open(filename).convert("RGB")
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
print("Pre-processing")
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

print("Predicting")
with torch.no_grad():
    # 'out' holds the main head's per-class scores; [0] drops the batch dim.
    output = model(input_batch)['out'][0]
# Per-pixel class index = channel with the highest score.
output_predictions = output.argmax(0)

In [None]:
# The 21 labels of the pretrained DeepLabV3 model: background plus the 20
# Pascal VOC categories. Index i is the class id produced by argmax above.
classes = ['__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
            'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
             'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

In [None]:
# Create a colour palette: one RGB colour per class index, obtained by
# multiplying the index by large per-channel constants modulo 255.
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")

# Plot the semantic segmentation predictions of the 21 classes, one colour per
# class; the label map is resized to the input image's size for display.
r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
r.putpalette(colors)

import matplotlib.pyplot as plt
plt.imshow(r)
# plt.show()  # only needed outside Jupyter

## 2. Code to crop images

This code crops each image into a grid of equal rectangles — by default 4 columns × 3 rows, i.e. `crops=(4, 3)`. Skip this section if you have already done it.

In [None]:
# Scratch: re-open the sample image for inspection.
input_image = Image.open(filename)

In [None]:
# Folder with the full-size validation images.
folder = '../data/images_val/'

In [None]:
# Scratch: lower-case .jpg files only (prepare_crops also globs .JPG/.png/.PNG).
images = glob.glob(os.path.join(folder, '*.jpg'))

In [None]:
# ImageDraw is used by prepare_masks to rasterise the annotation polygons.
from PIL import ImageDraw

In [None]:
# Scratch: inspect the raw pixel data of the sample image.
input_image.getdata()

In [None]:
import tqdm

def prepare_crops(folder, cropped_folder='data/cropped',
                 crops=(4, 3)):
    """Cut every image in ``folder`` into a regular grid of crops.

    Files matching ``*.jpg``, ``*.JPG``, ``*.png`` and ``*.PNG`` are read from
    ``folder``. Each is split into ``crops[0]`` columns by ``crops[1]`` rows of
    equal size. Crop sizes use integer division, so the right and bottom edges
    lose up to ``crops[0] - 1`` and ``crops[1] - 1`` pixels. Crop ``(i, j)``
    (column ``i``, row ``j``) is saved as ``<stem>_cropped_<i><j><ext>`` in
    ``cropped_folder``, keeping the source file's extension.

    Args:
        folder: Directory containing the source images.
        cropped_folder: Output directory; created if missing.
        crops: ``(n_columns, n_rows)`` of the crop grid.

    Note:
        ``<i><j>`` is not zero-padded, so for grids with more than 10 crops
        per axis the names become hard to parse and can collide.
    """
    patterns = ('*.jpg', '*.JPG', '*.png', '*.PNG')
    # On case-insensitive file systems '*.jpg' and '*.JPG' match the same
    # files; de-duplicate so each image is cropped only once.
    images = sorted({
        path
        for pattern in patterns
        for path in glob.glob(os.path.join(folder, pattern))
    })

    os.makedirs(cropped_folder, exist_ok=True)

    n_cols, n_rows = crops[0], crops[1]
    for image_path in tqdm.tqdm(images):
        # splitext handles extensions of any length (e.g. '.jpeg').
        stem, ext = os.path.splitext(os.path.basename(image_path))

        # Context manager closes the file handle once all crops are saved.
        with Image.open(image_path) as im:
            width, height = im.size
            crop_x = width // n_cols
            crop_y = height // n_rows

            for i in range(n_cols):
                for j in range(n_rows):
                    box = (crop_x * i, crop_y * j, crop_x * (i + 1), crop_y * (j + 1))
                    im.crop(box).save(
                        os.path.join(cropped_folder, f'{stem}_cropped_{i}{j}{ext}')
                    )

In [None]:
# Coral substrate class names as they appear in the annotation file.
# Label 0 is kept for background, so the classes are numbered from 1.
classes = [
    'c_hard_coral_branching',
    'c_hard_coral_submassive',
    'c_hard_coral_boulder',
    'c_hard_coral_encrusting',
    'c_hard_coral_table',
    'c_hard_coral_foliose',
    'c_hard_coral_mushroom',
    'c_soft_coral',
    'c_soft_coral_gorgonian',
    'c_sponge',
    'c_sponge_barrel',
    'c_fire_coral_millepora',
    'c_algae_macro_or_leaves',
]

# class name -> integer label written into the masks (1..len(classes))
name_to_id = dict(zip(classes, range(1, len(classes) + 1)))

In [None]:
# Display the class-name -> label-id mapping (ids start at 1; 0 is background).
name_to_id

In [None]:
def prepare_masks(images_folder='../data/images_val/',
                  masks_folder='data/masks',
                  annotations_file='../data/imageCLEFcoral2020_GT.csv',
                  cropped_folder='data/cropped'):
    """Rasterise the polygon annotations into one label-mask PNG per image.

    Each non-blank line of ``annotations_file`` is whitespace-separated:
    image id (without extension), a field ignored here, the substrate class
    name, another ignored field, then a flat ``x1 y1 x2 y2 ...`` list of
    polygon vertices. The image id resolves to ``<id>.JPG`` in
    ``images_folder``, falling back to ``<id>.jpg``.

    Each polygon is filled with ``name_to_id[substrate]`` on a single-channel
    (``'L'``) image the size of the source image. Pixels outside every polygon
    stay 0 (background). Where polygons overlap, later lines overwrite earlier
    ones. Only images listed in the annotation file get a mask.

    Args:
        images_folder: Folder holding the source images.
        masks_folder: Output folder; created if missing. Masks are saved as
            ``<image stem>.png``.
        annotations_file: Path to the annotation file described above.
        cropped_folder: Unused; kept for backward compatibility.

    Raises:
        KeyError: If a substrate name is missing from ``name_to_id``.
    """
    image_to_annotations = defaultdict(list)

    os.makedirs(masks_folder, exist_ok=True)

    with open(annotations_file, 'r') as f:
        for line in f:
            # split() (not split(' ')) tolerates repeated spaces and the
            # trailing newline, and yields [] for blank lines.
            line_split = line.split()
            if not line_split:
                continue

            image_path = os.path.join(images_folder, line_split[0] + '.JPG')
            if not os.path.exists(image_path):
                image_path = os.path.join(images_folder, line_split[0] + '.jpg')

            substrate = line_split[2]
            coords = [int(x) for x in line_split[4:]]
            # Coordinates come as a flat x, y, x, y, ... list; pair them up.
            polygon = list(zip(coords[::2], coords[1::2]))

            image_to_annotations[image_path].append((substrate, polygon))

    for image, annotations in tqdm.tqdm(image_to_annotations.items(),
                                        total=len(image_to_annotations)):
        stem = os.path.splitext(os.path.basename(image))[0]
        # Only the size is needed; close the file handle right away.
        with Image.open(image) as src:
            im_size = src.size
        # Mode 'L' -> 8-bit single-channel PNG, initialised to 0 (background).
        poly = Image.new('L', size=im_size)
        pdraw = ImageDraw.Draw(poly)
        for substrate, polygon in annotations:
            pdraw.polygon(polygon, fill=name_to_id[substrate])

        poly.save(os.path.join(masks_folder, stem + '.png'))

In [None]:
%%time

# Build full-size masks from the annotations, then cut images and masks with
# the same default grid so that crops correspond by file stem.
# NOTE: images without annotations still get image crops but no mask crops.
prepare_masks()
prepare_crops('../data/images_val', cropped_folder='data/cropped/images')
prepare_crops('data/masks', cropped_folder='data/cropped/masks')

In [None]:
def visualise_mask(image_key='02_2017_0803_132446_043', masks_folder='data/masks'):
    """Open a label mask and attach a colour palette for display.

    Args:
        image_key: File stem of the mask; ``<image_key>.png`` is read from
            ``masks_folder``.
        masks_folder: Folder containing the mask PNGs.

    Returns:
        The opened PIL mask with the palette attached: label ``k`` is drawn
        in the ``k``-th colour below (label 0 is black).
    """
    mask = Image.open(os.path.join(masks_folder, image_key + '.png'))

    # One RGB triple per label id, 0 through 13.
    label_rgb = (
        (0, 0, 0),
        (245, 185, 95),
        (50, 50, 50),
        (65, 50, 230),
        (73, 74, 74),
        (78, 252, 5),
        (186, 153, 255),
        (200, 103, 5),
        (198, 5, 252),
        (84, 194, 27),
        (20, 145, 245),
        (16, 133, 16),
        (190, 234, 98),
        (255, 233, 72),
    )

    # PIL wants a flat [r0, g0, b0, r1, g1, b1, ...] sequence.
    flat_palette = []
    for rgb in label_rgb:
        flat_palette.extend(rgb)
    mask.putpalette(flat_palette)

    return mask

In [None]:
# Show the coloured full-size mask for one annotated image.
visualise_mask('02_2017_0803_132446_043')

In [None]:
# Scratch: open one full-size mask directly (raw label values, no palette).
# prepare_masks writes masks to data/masks, so read them from there.
image_key='02_2017_0803_132446_043'
masks_folder='data/masks'
mask = Image.open(os.path.join(masks_folder, image_key + '.png'))

## 3. Create dataset class

In [None]:
import os
import numpy as np
import torch
import torch.utils.data
from PIL import Image


class CoralDataset(torch.utils.data.Dataset):
    """Semantic-segmentation dataset of (image, mask) pairs from two folders.

    An image and a mask are paired when their file names share a stem (the
    name without extension), e.g. ``x_cropped_00.jpg`` and
    ``x_cropped_00.png``. Files with no counterpart in the other folder are
    skipped. Pairing by sorted position would instead let one missing mask
    shift every later pair out of alignment. Sub-directories and hidden files
    (e.g. ``.ipynb_checkpoints``, ``.DS_Store``) are ignored.

    Args:
        images_folder: Folder with the input images.
        masks_folder: Folder with the label masks.
        transforms: Optional callable, applied as ``transforms(img, target)``
            and expected to return the transformed pair.
        n_images: If given, keep only the first ``n_images`` pairs (in sorted
            image-name order); ``None`` keeps all of them.
    """

    def __init__(self, images_folder='data/cropped/images',
                 masks_folder='data/cropped/masks',
                 transforms=None,
                 n_images=None):
        self.images_folder = images_folder
        self.masks_folder = masks_folder
        self.transforms = transforms

        # Index the masks by stem so each image can find its own mask.
        mask_by_stem = {
            os.path.splitext(name)[0]: name
            for name in self._list_files(masks_folder)
        }

        # Parallel lists: self.imgs[k] and self.masks[k] always belong together.
        self.imgs = []
        self.masks = []
        for img_name in self._list_files(images_folder):
            mask_name = mask_by_stem.get(os.path.splitext(img_name)[0])
            if mask_name is not None:
                self.imgs.append(img_name)
                self.masks.append(mask_name)

        self.imgs = self.imgs[:n_images]
        self.masks = self.masks[:n_images]

    @staticmethod
    def _list_files(folder):
        """Return the sorted visible regular files in ``folder`` (no dirs, no dotfiles)."""
        return sorted(
            name for name in os.listdir(folder)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(folder, name))
        )

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(img, target)`` pair.

        ``img`` is loaded as RGB. ``target`` is the mask as stored on disk.
        Both go through ``self.transforms`` when it is set.
        """
        # load image and mask
        img_path = os.path.join(self.images_folder, self.imgs[idx])
        mask_path = os.path.join(self.masks_folder, self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        target = Image.open(mask_path)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.imgs)

## 4. Define model

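Replace the pretrained heads (the main `classifier` and the `aux_classifier`) with freshly initialised ones that predict our 14 classes (background + 13 coral classes).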

In [None]:
# Load a fresh pretrained DeepLabV3-ResNet101 to fine-tune on the coral data.
model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)

In [None]:
# Print the architecture to locate the `classifier` and `aux_classifier` heads.
model

In [None]:
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation.fcn import FCNHead

from torch import nn

In [None]:
# Get the input channel counts of the existing heads (from their first
# convolution) so that the replacement heads fit the backbone's feature maps.

in_channels_head = model.classifier[0].convs[0][0].in_channels
in_channels_aux = model.aux_classifier[0].in_channels

In [None]:
# Replace both heads with freshly initialised ones that output 14 labels:
# background (0) + the 13 coral classes (ids 1..13). FCNHead's `channels`
# argument is its number of output channels.
model.classifier = DeepLabHead(in_channels=in_channels_head, num_classes=14)
model.aux_classifier = FCNHead(in_channels=in_channels_aux, channels=14)

## 5. Train model

In [None]:
import utils
from train import train_one_epoch, get_transform, criterion, evaluate

In [None]:
# Tiny smoke-test datasets (at most 5 pairs each) from the default cropped folders.
# NOTE(review): both read the same folders, so evaluation runs on the training
# images; use a separate split for a meaningful score.
dataset = CoralDataset(transforms=get_transform(train=True), n_images=5)
dataset_test = CoralDataset(transforms=get_transform(train=False), n_images=5)

In [None]:
# Shuffle the training set each epoch; iterate the test set in a fixed order.
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)

In [None]:
# drop_last=True skips an incomplete final training batch; with n_images=5 and
# batch_size=2 it would hold a single sample, which BatchNorm layers cannot
# normalise in training mode. collate_fn comes from the copied reference utils.py.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2,
    sampler=train_sampler, 
    collate_fn=utils.collate_fn, drop_last=True
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1,
    sampler=test_sampler,
    collate_fn=utils.collate_fn
)

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

num_classes = 14
epochs = 10
lr = 0.005

# move model to the right device
model.to(device)

# construct an optimizer: backbone and main head use `lr`, the auxiliary head
# uses a 10x larger learning rate
params_to_optimize = [
    {"params": [p for p in model.backbone.parameters() if p.requires_grad]},
    {"params": [p for p in model.classifier.parameters() if p.requires_grad]},
    {"params": [p for p in model.aux_classifier.parameters() if p.requires_grad], "lr": lr * 10}
]
optimizer = torch.optim.SGD(params_to_optimize, lr=lr, momentum=0.9, weight_decay=0.0005)

# "Poly" learning-rate schedule: the factor decays from 1 to 0 as
# (1 - step / total_iters) ** 0.9, where total_iters = batches per epoch * epochs
# (so it is presumably stepped once per batch by train_one_epoch).
# The base is clamped at 0: stepping past the horizon would otherwise raise a
# negative float to a fractional power, which yields a complex number in Python.
total_iters = len(data_loader) * epochs
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: max(0.0, 1 - step / total_iters) ** 0.9
)

In [None]:
# Train for `epochs` epochs, the horizon the poly LR schedule is sized for;
# stopping earlier leaves the learning rate far from its decayed value.
for epoch in range(epochs):
    train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch=epoch, print_freq=2)
    # Confusion-matrix based metrics on the test loader after every epoch.
    confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)

    print(confmat)

## 6. Visualise model predictions

In [None]:
from torchvision import transforms

filename = "data/cropped/images/02_2017_0803_132446_043_cropped_00.jpg"
print("Opening")
# convert("RGB") matches CoralDataset, which also feeds RGB images to the model.
input_image = Image.open(filename).convert("RGB")
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
print("Pre-processing")
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

# Force inference mode regardless of the state the training helpers left the
# model in: in training mode BatchNorm uses batch statistics and can fail
# outright on a batch of one.
model.eval()

print("Predicting")
with torch.no_grad():
    # 'out' holds the main head's per-class scores; [0] drops the batch dim.
    output = model(input_batch)['out'][0]
# Per-pixel label id (0 = background, 1..13 = coral classes).
output_predictions = output.argmax(0)