In [None]:
# Install required libs
!pip install -U segmentation-models-pytorch albumentations --user
# !pip uninstall -y segmentation-models-pytorch

In [None]:
import os
# Expose only GPU 0 to CUDA. This is set before the remaining imports so that
# libraries loaded afterwards see the value.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import matplotlib.pyplot as plt
import linzhutils as lu
from segdataset import Dataset
from torch.utils.data import DataLoader

# Root folder of the dataset; the image / mask sub-folders used in the
# following cells are joined onto this path.
DATA_DIR = '../data/hel2019/'

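**If you are running this notebook for the first time after downloading the data, please uncomment the following code (if it is commented out) and run it to create the validation split. If it is not the first time, you do not need to run it again; the cell checks whether validation data already exists before splitting.**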

In [None]:
if len(lu.getAllFileList(os.path.join(DATA_DIR,'val_images')))>0 and len(lu.getAllFileList(os.path.join(DATA_DIR,'val_bin_masks')))>0:
    print('Validation data already exists.')
else:
    lu.splitMultiData([os.path.join(DATA_DIR,'images'), os.path.join(DATA_DIR,'bin_masks')],
                    [os.path.join(DATA_DIR,'val_images'), os.path.join(DATA_DIR,'val_bin_masks')], 0.8)
!pwd

In [None]:
# Resolve the training and validation image / binary-mask folders under DATA_DIR.
img_dir, anno_dir, val_img_dir, val_anno_dir = (
    os.path.join(DATA_DIR, folder)
    for folder in ('images', 'bin_masks', 'val_images', 'val_bin_masks')
)

In [None]:
import albumentations as albu

def get_training_augmentation():
    """Build the augmentation pipeline applied to training samples.

    The pipeline flips horizontally, shifts/scales, pads to at least 512x512,
    takes a random 512x512 crop, and then applies noise, perspective, colour
    and sharpening transforms with the probabilities given below.

    Returns:
        albumentations.Compose: the composed training transform.
    """
    train_transform = [

        albu.HorizontalFlip(p=0.5),

        # rotate_limit=0 disables rotation; border_mode=0 is cv2.BORDER_CONSTANT.
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),

        # Pad first so that the 512x512 crop below always fits.
        # p=1 replaces the deprecated always_apply=True and behaves the same.
        albu.PadIfNeeded(min_height=512, min_width=512, p=1, border_mode=0),
        albu.RandomCrop(height=512, width=512, p=1),

        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.5),

        # With probability 0.9, apply exactly one of the colour transforms.
        albu.OneOf(
            [
                albu.CLAHE(p=1),
                albu.RandomBrightnessContrast(p=1),
                albu.RandomGamma(p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),

        # With probability 0.5, sharpen (the blur alternatives are disabled).
        albu.OneOf(
            [
                albu.Sharpen(p=1),
                # albu.Blur(blur_limit=2, p=1),
                # albu.MotionBlur(blur_limit=2, p=1),
            ],
            p=0.5,
        ),
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Build the validation transform: pad samples up to at least 512x512.

    Returns:
        albumentations.Compose: transform containing a single PadIfNeeded step.
    """
    return albu.Compose([albu.PadIfNeeded(min_height=512, min_width=512)])


def to_tensor(x, **kwargs):
    """Move the last axis of ``x`` to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Compose the final preprocessing steps.

    Args:
        preprocessing_fn (callable): data normalization function applied to
            the image (can be specific to each pretrained neural network).

    Returns:
        albumentations.Compose: applies ``preprocessing_fn`` to the image,
        then ``to_tensor`` to both image and mask.
    """
    normalize = albu.Lambda(image=preprocessing_fn)
    to_channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize, to_channels_first])

In [None]:
import torch
import numpy as np
import segmentation_models_pytorch as smp

# Encoder backbone and the pretrained weights it is initialised with.
ENCODER = 'timm-efficientnet-b8'
ENCODER_WEIGHTS = 'imagenet'
# One foreground class, hence one output channel.
CLASSES = ['kussi']
# 'sigmoid' here; could be None for logits or 'softmax2d' for multiclass segmentation.
ACTIVATION = 'sigmoid'
DEVICE = 'cpu' if not torch.cuda.is_available() else 'cuda'

# Normalization function matching the chosen encoder and weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# Segmentation model with a pretrained encoder.
# MODEL_NAME = 'FPN'
model = smp.PSPNet(encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=len(CLASSES), activation=ACTIVATION)

In [None]:
# Training data: augmented samples, shuffled, in batches of 8.
train_dataset = Dataset(
    img_dir,
    anno_dir,
    classes=CLASSES,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8, num_workers=1)

# Validation data: padding only, fixed order, in batches of 4.
valid_dataset = Dataset(
    val_img_dir,
    val_anno_dir,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)
valid_loader = DataLoader(valid_dataset, shuffle=False, batch_size=4, num_workers=1)

In [None]:
# Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index

from segmentation_models_pytorch import utils as smpu

# The model already ends in a sigmoid (ACTIVATION = 'sigmoid'), so the loss
# receives probabilities, not logits. DiceLoss defaults to from_logits=True,
# which would apply a second sigmoid to the predictions; disable it.
loss = smp.losses.DiceLoss('binary', from_logits=False)
# smp's epoch runners report the loss under loss.__name__; label it as the
# Dice loss it actually is.
loss.__name__ = 'dice_loss'
metrics = [
    # Predictions are binarised at 0.5 before IoU is computed.
    smpu.metrics.IoU(threshold=0.5),
]

# A single parameter group holding every model parameter.
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001),
])

# create epoch runners
# it is a simple loop of iterating over dataloader`s samples

train_epoch = smpu.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

# The validation runner takes no optimizer.
valid_epoch = smpu.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

In [None]:
# Train the model for NUM_EPOCHS epochs, saving the model whenever the
# validation IoU improves and lowering the learning rate once part-way through.
import wandb
import random
# wandb.init(project='forest_patches', entity='linzh', name=f'{MODEL_NAME}-{int(random.random()*1000)}')
# wandb.watch(model, log='all')

NUM_EPOCHS = 25                       # total number of training epochs
LR_DROP_EPOCH = 16                    # 1-based epoch after which the learning rate is lowered
REDUCED_LR = 1e-5                     # learning rate used from then on
BEST_MODEL_PATH = './best_model.pth'  # where the best-scoring model is written

max_score = 0

# Per-epoch log dictionaries returned by the epoch runners.
train_log_list = []
valid_log_list = []

for epoch in range(1, NUM_EPOCHS + 1):
    print('\nEpoch: {}'.format(epoch))
    train_logs = train_epoch.run(train_loader)
    train_log_list.append(train_logs)
    print(train_logs)
    valid_logs = valid_epoch.run(valid_loader)
    valid_log_list.append(valid_logs)
    print(valid_logs)

    # Keep only the model with the best validation IoU seen so far.
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, BEST_MODEL_PATH)
        print('Model saved!')
    if epoch == LR_DROP_EPOCH:
        # The optimizer's parameter groups cover the whole model, so this
        # lowers the learning rate for every parameter, not just a decoder.
        for param_group in optimizer.param_groups:
            param_group['lr'] = REDUCED_LR
        print(f'Decrease learning rate to {REDUCED_LR}!')
print(f'Max IOU: {max_score}')

In [None]:
# Report the best validation IoU reached during training.
print('Max IOU: {}'.format(max_score))