# Import packages

In [None]:
#!pip install -U segmentation-models-pytorch albumentations --user 
#!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import torch
import albumentations as albu

import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils

import numpy as np
import cv2
import matplotlib.pyplot as plt
import shutil
import glob
import pandas as pd
import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Set and Initialize File Paths

In [None]:
def rebuild_dir(target_path):
    """Recreate ``target_path`` as an empty directory.

    If the path already exists it is removed together with everything
    inside it; afterwards a fresh directory (including missing parents)
    is created.
    """
    already_there = os.path.exists(target_path)
    if already_there:
        # Drop stale contents so every run starts from a clean folder.
        shutil.rmtree(target_path)
    os.makedirs(target_path)

In [None]:
# Everything lives under <current working directory>/Dataset.
workspace_path = os.getcwd()
dataset_path = os.path.join(workspace_path, 'Dataset')

# Folder that holds the source image/label directories read further below.
auto_labeling_path = os.path.join(dataset_path, 'label_source')

# Cache area with one sub-folder for labels and one for images.
cache_path = os.path.join(dataset_path, 'cache')
cache_label_path = os.path.join(cache_path, 'label')
cache_img_path = os.path.join(cache_path, 'img')

# Parent first, then its children, so each cache folder starts out empty.
for _cache_dir in (cache_path, cache_label_path, cache_img_path):
    rebuild_dir(_cache_dir)

In [None]:
# Available slides (each has a "_Cropped" image folder and an
# "_autolabeling" label folder):
#   slide-2021-08-20T14-39-34-R1-S3_Wholeslide_Default_Extended
#   slide-2021-11-02T20-24-15-R1-S1_Wholeslide_Default_Extended
#   slide-2021-11-02T21-16-24-R1-S5_Wholeslide_Default_Extended
#
# Slide used for this run; both folder names share the same prefix.
_slide_id = "slide-2021-11-02T21-16-24-R1-S5_Wholeslide_Default_Extended"
original_imgs_name = _slide_id + "_Cropped"
original_labels_name = _slide_id + "_autolabeling"

In [None]:
# Full paths of the source image folder and the source label folder.
original_imgs = os.path.join(auto_labeling_path, original_imgs_name)
original_labels = os.path.join(auto_labeling_path, original_labels_name)

# Sanity check: print how many entries each folder contains
# (images first, labels second).
for _source_dir in (original_imgs, original_labels):
    print(len(os.listdir(_source_dir)))

In [None]:
# Root of the train/validation split created by this notebook.
DATA_DIR = os.path.join(dataset_path, 'Autolabeling_dataset')

# Image and label folders for the training and validation splits.
x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'val')
y_valid_dir = os.path.join(DATA_DIR, 'val_labels')

# The split folders are wiped and recreated on every run.
for _split_dir in (x_train_dir, y_train_dir, x_valid_dir, y_valid_dir):
    rebuild_dir(_split_dir)

# The checkpoint folder, in contrast, is kept between runs and only
# created when it does not exist yet.
ckpt_path = os.path.join(DATA_DIR, 'ckpt')
_ckpt_missing = not os.path.exists(ckpt_path)
if _ckpt_missing:
    os.makedirs(ckpt_path, exist_ok=True)

# Preprocess

## Split into training and validation datasets

In [None]:
def permutation_train_test_split(data,label , valid_size=0.2, shuffle=True, random_state=1004):
    """Split paired samples into a training part and a validation part.

    Args:
        data: Sequence of samples; converted to a NumPy array.
        label: Sequence of labels aligned index-by-index with ``data``;
            converted to a NumPy array.
        valid_size: Fraction reserved for validation. The validation count
            is ``int(len(data) * valid_size)``; the rest is training data.
        shuffle: If true, the same random permutation is applied to
            ``data`` and ``label`` before the split.
        random_state: Seed used for the permutation. It is passed to
            ``np.random.seed``, so the global NumPy RNG is reseeded as a
            side effect.

    Returns:
        Tuple ``(x_train, y_train, x_valid, y_valid)`` of NumPy arrays.

    Raises:
        ValueError: If ``data`` and ``label`` differ in length.
    """
    # Accept plain lists as well: the fancy indexing below needs arrays.
    data = np.asarray(data)
    label = np.asarray(label)

    data_len = len(data)
    if len(label) != data_len:
        raise ValueError(
            f'data and label must have the same length, got {data_len} and {len(label)}'
        )
    print(f'전체 데이터수 : {data_len}')
    valid_num = int(data_len * valid_size)
    train_num = data_len - valid_num

    if shuffle:
        # Honor the caller's seed (it used to be hard-coded to 100, which
        # silently ignored ``random_state``).
        np.random.seed(random_state)
        shuffled = np.random.permutation(data_len)
        data = data[shuffled]
        label = label[shuffled]

    # The first ``train_num`` samples are training data, the rest validation.
    x_train, x_valid = data[:train_num], data[train_num:]
    y_train, y_valid = label[:train_num], label[train_num:]

    return x_train, y_train, x_valid, y_valid

# Split the files into training and validation sets (8:2).
# glob returns paths in arbitrary order, so both lists are sorted to keep
# the i-th image aligned with the i-th label before the shared permutation
# is applied. (Assumes an image and its label share the same file name,
# as the Dataset class below also expects - verify for new data.)
X_path=np.array(sorted(glob.glob(original_imgs+"/*.png")))
Y_path=np.array(sorted(glob.glob(original_labels+"/*.png")))
x_train, y_train, x_valid, y_valid=permutation_train_test_split(X_path,Y_path,valid_size=0.2,shuffle=True,random_state=1004)

print('훈련 데이터 수 = img : ',len(x_train),', label : ',len(y_train))
print('검증 데이터 수 = img : ',len(x_valid),', label : ',len(y_valid))

In [None]:
# Copy every file of each split into its destination folder
# (copy2 also preserves file metadata).
for _src_files, _dst_dir in (
    (x_train, x_train_dir),
    (y_train, y_train_dir),
    (x_valid, x_valid_dir),
    (y_valid, y_valid_dir),
):
    for _src_file in _src_files:
        shutil.copy2(_src_file, _dst_dir)

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Every keyword argument becomes one panel. The panel title is the
    keyword name with underscores replaced by spaces, in title case.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        # Hide the axis ticks; only the picture matters here.
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name])
    plt.show()

## Define dataset

In [None]:
class Dataset(BaseDataset):
    """Dataset of images and their segmentation masks.

    Every entry of ``images_dir`` is treated as one sample; its mask is the
    file with the same name in ``masks_dir``.

    Args:
        images_dir: Folder containing the images.
        masks_dir: Folder containing the masks (same file names as images).
        classes: Iterable of class names to extract from the mask; each
            name must appear in ``CLASSES``. Required in practice.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping
            with ``'image'`` and ``'mask'`` entries.
        preprocessing: Optional callable with the same calling convention,
            applied after the augmentation.

    Raises:
        TypeError: If ``classes`` is None.
        ValueError: If a class name is not in ``CLASSES``.
    """

    # The position in this list is the mask pixel value of the class:
    # values 0..254 are placeholders ('x'), value 255 is 'cell'.
    CLASSES = ['x'] * 255 + ['cell']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        if classes is None:
            # Fail with a clear message instead of "'NoneType' object is
            # not iterable".
            raise TypeError("classes must be an iterable of class names, e.g. ['cell']")

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair for sample ``i``.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        # read data; cv2.imread returns None instead of raising on failure
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise FileNotFoundError(f'Could not read image: {self.images_fps[i]}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.masks_fps[i], 0)  # flag 0: load as grayscale
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {self.masks_fps[i]}')

        # one binary layer per requested class, stacked along the last axis
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of samples (entries found in ``images_dir``)."""
        return len(self.ids)

In [None]:
dataset = Dataset(x_train_dir, y_train_dir, classes=['cell'])

# Preview one training sample. Index 75 is used when it exists; smaller
# datasets fall back to their last sample instead of raising IndexError.
sample_index = min(75, len(dataset) - 1)
image, mask = dataset[sample_index]
visualize(
    image=image,
    cell_mask=mask.squeeze(),  # the only extracted class is 'cell'
)

## Augment dataset

In [None]:
def get_training_augmentation():
    """Build the augmentation pipeline used for training samples.

    The pipeline consists of a horizontal flip (p=0.5) followed by three
    ``OneOf`` groups. Each group is applied with probability 0.9 and then
    runs exactly one of its members:

    * CLAHE, brightness change, or gamma change
    * blur or motion blur (kernel limit 3)
    * contrast change or hue/saturation/value shift

    ``RandomBrightness`` and ``RandomContrast`` are deprecated in
    albumentations in favor of ``RandomBrightnessContrast``; the calls
    below switch off the other component so each keeps its old meaning
    (limit 0.2 was the default of the deprecated transforms).

    Returns:
        An ``albu.Compose`` wrapping the transforms above.
    """
    train_transform = [
        albu.HorizontalFlip(p=0.5),

        albu.OneOf(
            [
                albu.CLAHE(p=1),
                # brightness only (replaces RandomBrightness)
                albu.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                # contrast only (replaces RandomContrast)
                albu.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Build the validation/test pipeline: pad to at least 384 x 480.

    Both target sizes are multiples of 32.
    NOTE(review): images already larger than 384 x 480 are presumably
    passed through unpadded, so their size is not forced to a multiple
    of 32 - confirm against the actual image sizes.
    """
    pad_to_minimum = albu.PadIfNeeded(min_height=384, min_width=480)
    return albu.Compose([pad_to_minimum])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied after augmentation.

    Args:
        preprocessing_fn: Callable applied to the image only.

    Returns:
        An ``albu.Compose`` that first runs ``preprocessing_fn`` on the
        image and then ``to_tensor`` on both image and mask.
    """
    normalize_image = albu.Lambda(image=preprocessing_fn)
    channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_image, channels_first])

In [None]:
#### Visualize the resulting augmented images and masks

augmented_dataset = Dataset(
    images_dir=x_train_dir,
    masks_dir=y_train_dir,
    classes=['cell'],
    augmentation=get_training_augmentation(),
)

# show training samples 50..54, each with freshly drawn random transforms
for i in range(50, 55):
    image, mask = augmented_dataset[i]
    # drop the trailing single-class axis so the mask plots as a 2-D image
    visualize(image=image, mask=mask.squeeze(-1))

# Train model

## Define encoder, activation function, model, and hyperparameters

In [None]:
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['cell']
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL = 'Unet'
EPOCH = 50

# Supported architectures: value of MODEL -> class name in
# segmentation_models_pytorch. Names are resolved lazily so that only the
# selected architecture has to exist in the installed version.
_MODEL_CLASS_NAMES = {
    'Unet': 'Unet',
    'Unet++': 'UnetPlusPlus',
    'FPN': 'FPN',
}
if MODEL not in _MODEL_CLASS_NAMES:
    # Stop here: continuing would only fail later with `model` undefined.
    raise ValueError(
        'No model! MODEL={!r} is not one of {}'.format(MODEL, sorted(_MODEL_CLASS_NAMES))
    )

# create segmentation model with pretrained encoder
model = getattr(smp, _MODEL_CLASS_NAMES[MODEL])(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
# input normalization matching the chosen encoder and its pretrained weights
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Training objective: Dice loss.
loss = smp.utils.losses.DiceLoss()

# Reported metric: IoU, with predictions binarized at 0.5.
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# Adam with one parameter group holding all model parameters (lr = 1e-4).
_all_params_group = {'params': model.parameters(), 'lr': 0.0001}
optimizer = torch.optim.Adam([_all_params_group])

## Set dataloader

In [None]:
# Training samples: random augmentation followed by encoder preprocessing.
train_dataset = Dataset(
    images_dir=x_train_dir,
    masks_dir=y_train_dir,
    classes=CLASSES,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Validation samples: padding only, followed by the same preprocessing.
valid_dataset = Dataset(
    images_dir=x_valid_dir,
    masks_dir=y_valid_dir,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Shuffled batches of 8 for training; validation runs one sample at a time
# in a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, num_workers=12)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=1, shuffle=False, num_workers=4)

In [None]:
# Create the epoch runners: each one iterates once over a dataloader.
# Both share the model, loss, metrics and device; only the training runner
# also receives the optimizer.
_runner_kwargs = dict(
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_runner_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **_runner_kwargs)

## Train

In [None]:
# Start below any reachable IoU so that the first epoch always writes a
# checkpoint. With a start value of 0 and a strict comparison, a run whose
# validation IoU stays at 0 would save nothing, and the loading step later
# on would fail or pick up a stale file from an earlier run.
max_score = float('-inf')
train_logs_list,valid_logs_list = [], []

# The best model so far is always written to the same file.
best_ckpt_file = os.path.join(ckpt_path,'{}_model.pth'.format(original_labels_name))

for i in range(EPOCH):

    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    train_logs_list.append(train_logs)
    valid_logs_list.append(valid_logs)

    # keep the checkpoint with the best validation IoU seen so far
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, best_ckpt_file)  # stores the whole model object
        print('Model saved!')

    # one-off learning-rate drop; the single parameter group holds all
    # model parameters, not only the decoder
    if i == 25:
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease learning rate to 1e-5!')

# Predict and test

## Load best model and dataset

In [None]:
# Load the best checkpoint written by the training loop.
# NOTE(review): the file is a fully pickled model; torch.load unpickles it,
# so only load checkpoints you created yourself. Recent PyTorch releases
# may additionally need weights_only=False for such files - confirm
# against the installed version.
_best_model_file = '{}_model.pth'.format(original_labels_name)
best_model = torch.load(os.path.join(ckpt_path, _best_model_file))

In [None]:
# Output folders for this slide, all below the checkpoint folder.
sample_path = os.path.join(ckpt_path, original_labels_name)
sample_prediction = os.path.join(sample_path, 'prediction')
sample_compare = os.path.join(sample_path, 'compare(original-label-img)')

# Parent first, then its children: each folder is wiped and recreated.
for _output_dir in (sample_path, sample_prediction, sample_compare):
    rebuild_dir(_output_dir)

## Predict and test

In [None]:
# Create the test dataset directly from the source folders.
# NOTE(review): the training and validation files were copied from these
# same folders, so this "test" set overlaps with both splits.
test_dataset = Dataset(
    images_dir=original_imgs,
    masks_dir=original_labels,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# DataLoader defaults: one sample per batch, no shuffling.
test_dataloader = DataLoader(dataset=test_dataset)

In [None]:
# Evaluate the best model on the test set.
test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
)

logs = test_epoch.run(test_dataloader)
print("Evaluation on Test Data: ")
# Report the results of this run (`logs`), not the last validation epoch
# of the training loop (`valid_logs`).
print(f"Mean IoU Score: {logs['iou_score']:.4f}")
print(f"Mean Dice Loss: {logs['dice_loss']:.4f}")

In [None]:
# Same source folders as the test dataset, but without augmentation or
# preprocessing, so the images can be displayed as they are.
test_dataset_vis = Dataset(
    images_dir=original_imgs,
    masks_dir=original_labels,
    classes=CLASSES,
)

In [None]:
# For every test sample: predict the mask and save/show a three-panel
# figure (original image, label, prediction).
for n in range(len(test_dataset)):

    # untransformed image for display, preprocessed sample for the model
    image_vis = test_dataset_vis[n][0].astype('uint8')
    image, gt_mask = test_dataset[n]

    gt_mask = gt_mask.squeeze()

    # add the batch axis, run the model, and round the output to 0/1
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = (pr_mask.squeeze().cpu().numpy().round())

    panels = (
        ('original image', image_vis),
        ('label', cv2.cvtColor(gt_mask,cv2.COLOR_GRAY2RGB)),
        ('prediction', cv2.cvtColor(pr_mask,cv2.COLOR_GRAY2RGB)),
    )

    fig = plt.figure(figsize=(21,7), facecolor='white')
    for column, (panel_title, picture) in enumerate(panels, start=1):
        plt.subplot(1,3,column)
        plt.title(panel_title,fontsize=25,pad=10)
        plt.axis('off')
        plt.imshow(picture)
    plt.savefig(os.path.join(sample_compare,'sample_compare_%05d.png' % n))
    # Saving the bare prediction into `sample_prediction` is disabled:
#     cv2.imwrite(os.path.join(sample_prediction,'%05d.png' % n),cv2.cvtColor(pr_mask,cv2.COLOR_GRAY2RGB))

    plt.show()
    # Release the figure explicitly so that figures do not pile up in
    # memory when this runs outside an inline notebook backend.
    plt.close(fig)

# Show and save results

In [None]:
# One row per epoch, one column per logged quantity.
train_logs_df, valid_logs_df = (
    pd.DataFrame(_epoch_logs) for _epoch_logs in (train_logs_list, valid_logs_list)
)
# Transposed view of the training logs as the cell output.
train_logs_df.transpose()

In [None]:
# IoU per epoch for training and validation; saved next to the samples.
plt.figure(figsize=(20, 8))
for _logs_df, _curve_label in ((train_logs_df, 'Train'), (valid_logs_df, 'Valid')):
    _epochs = _logs_df.index.tolist()
    _scores = _logs_df['iou_score'].tolist()
    plt.plot(_epochs, _scores, lw=3, label=_curve_label)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('IoU Score', fontsize=20)
plt.title('IoU Score Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
_iou_plot_file = os.path.join(sample_path, 'iou_score_plot.png')
plt.savefig(_iou_plot_file)
plt.show()

In [None]:
# Dice loss per epoch for training and validation; saved next to the samples.
plt.figure(figsize=(20, 8))
for _logs_df, _curve_label in ((train_logs_df, 'Train'), (valid_logs_df, 'Valid')):
    _epochs = _logs_df.index.tolist()
    _losses = _logs_df['dice_loss'].tolist()
    plt.plot(_epochs, _losses, lw=3, label=_curve_label)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('Dice Loss', fontsize=20)
plt.title('Dice Loss Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
_dice_plot_file = os.path.join(sample_path, 'dice_loss_plot.png')
plt.savefig(_dice_plot_file)
plt.show()

In [None]:
# Completion message; prints the full path of the label folder.
print(f'{original_labels} done')

In [None]:
# Last five epochs of the training logs.
train_logs_df.tail(n=5)

In [None]:
# Last five epochs of the validation logs.
valid_logs_df.tail(n=5)

In [None]:
# Write both log tables as CSV files into the sample folder; the file name
# encodes the split, the model type and the number of epochs.
for _split_name, _logs_df in (('train', train_logs_df), ('valid', valid_logs_df)):
    _csv_name = '{}_{}_{}epoch.csv'.format(_split_name, MODEL, EPOCH)
    _logs_df.to_csv(os.path.join(sample_path, _csv_name))