In [None]:
import os
import random
from glob import glob
from tqdm import tqdm

import cv2
import monai
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations.core.transforms_interface import DualTransform

In [None]:
# Use the first GPU when present, otherwise fall back to CPU (instead of a
# hard-coded 'cuda:0' string that breaks on CPU-only machines).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

config = {
    'epochs': 30,
    # Training patch size (both multiples of 32, which smp encoder-decoders expect).
    'width': 960,
    'height': 544,
    'batch_size': 4,
    'num_workers': 6,
    'result_dir': 'result',  # NOTE(review): not read anywhere; weights are saved under ./weights
    'lr': 0.001,
    'seed': 1919,
}

In [None]:
# Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and the current
# CUDA device) from the single configured seed.
seed = config['seed']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [None]:
def dice_score(prediction: np.ndarray, ground_truth: np.ndarray, smooth: float = 1e-7) -> float:
    """Dice coefficient 2|A∩B| / (|A| + |B|) between two 0/1 masks.

    Args:
        prediction: 0/1 array, e.g. a thresholded model output.
        ground_truth: 0/1 array broadcastable against ``prediction``.
        smooth: added to numerator and denominator, so two empty masks score
            1.0 instead of dividing by zero.

    Returns:
        The Dice score as a Python float (1.0 = perfect overlap).
    """
    intersection = np.sum(prediction * ground_truth)
    total = np.sum(prediction) + np.sum(ground_truth)
    return float((2.0 * intersection + smooth) / (total + smooth))

In [None]:
# Library of object crops pasted by OBA. Sorted because glob order is
# filesystem-dependent; otherwise the seeded random index picks different files on different machines.
paths = sorted(glob('./oba/images/*.png'))

class OBA(DualTransform):
    """Object-based augmentation: paste a random object crop onto image and mask.

    Objects come from the module-level ``paths`` list. Each object's mask is
    found by replacing ``'images'`` with ``'masks'`` in its path. The same
    parameters from ``get_params`` are used for the image (3-D) and the mask
    (2-D), so the object lands at the same place in both.
    """

    @staticmethod
    def _load_object(n):
        """Return object ``n`` as (RGB image, grayscale mask), or None if unavailable."""
        if n is None or not 0 <= n < len(paths):
            return None
        obj_image = cv2.imread(paths[n])
        obj_mask = cv2.imread(paths[n].replace('images', 'masks'), cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if obj_image is None or obj_mask is None or obj_image.shape[:2] != obj_mask.shape:
            return None
        return cv2.cvtColor(obj_image, cv2.COLOR_BGR2RGB), obj_mask

    def apply(self, img, n, x, y, **params):
        """Paste object ``n`` with its top-left corner at (``x``, ``y``), clamped to fit.

        ``img`` is either the RGB image (3-D) or the binary mask (2-D). If the
        object cannot be loaded or is larger than ``img``, ``img`` is returned
        unchanged. The image and mask calls load the same files and make the same
        decision, so they stay consistent. The input array is not modified.
        """
        obj = self._load_object(n)
        if obj is None:
            return img
        obj_image, obj_mask = obj

        h, w = obj_mask.shape
        img_h, img_w = img.shape[:2]
        x = min(x, img_w - w)
        y = min(y, img_h - h)
        if x < 0 or y < 0:
            # The object is larger than the image; a negative slice start would
            # address the wrong region.
            return img

        out = img.copy()
        region = out[y:y + h, x:x + w]  # basic slice -> view into `out`
        on_object = obj_mask != 0
        if out.ndim == 3:
            # Copy only the object's pixels; the rest of the box keeps the original image.
            region[on_object] = obj_image[on_object]
        else:
            # Mark only object pixels as foreground. The rest of the bounding box keeps
            # its existing labels, matching the image where those pixels are untouched.
            # (Assigning the whole obj_mask box used to erase existing foreground.)
            region[on_object] = 1
            out[out != 0] = 1
        return out

    def get_params(self):
        """Sample an object index and a requested top-left paste corner."""
        return {
            # None when the object library is empty, making apply() a no-op
            # instead of np.random.randint(0) raising ValueError.
            'n': np.random.randint(len(paths)) if paths else None,
            'x': np.random.randint(50, 700),
            'y': np.random.randint(50, 300),
        }

In [None]:
class CustomDataset(Dataset):
    """Segmentation dataset yielding (image, mask) pairs with the mask binarized to 0/1.

    The mask for each image is found by path convention: the ``JPEGImages_pos``
    directory is swapped for ``SegmentationClass`` and the extension replaced
    by ``.png``.

    Args:
        image_paths: paths to the input images.
        label_paths: stored only; masks are located from ``image_paths``
            (see ``_mask_path``), not read from this list.
        transform: callable taking ``image=`` and ``mask=`` keyword arguments and
            returning a dict with ``'image'`` and ``'mask'`` keys.
    """

    def __init__(self, image_paths, label_paths, transform):
        super().__init__()
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _mask_path(image_path):
        """Mask path for ``image_path``: swap the directory and force a ``.png`` extension."""
        # splitext only touches the extension. The old str.replace('jpg', 'png')
        # also rewrote any 'jpg' elsewhere in the path and ignored '.JPG'.
        stem, _ = os.path.splitext(image_path.replace('JPEGImages_pos', 'SegmentationClass'))
        return stem + '.png'

    @staticmethod
    def _read_rgb(path, ignore_orientation=False):
        """Read ``path`` as an RGB array, raising FileNotFoundError if it cannot be decoded."""
        flags = cv2.IMREAD_COLOR
        if ignore_orientation:
            flags |= cv2.IMREAD_IGNORE_ORIENTATION
        image = cv2.imread(path, flags)
        if image is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError(f'could not read image: {path}')
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return the transformed image and mask for sample ``idx``."""
        image_path = self.image_paths[idx]
        mask_path = self._mask_path(image_path)

        image = self._read_rgb(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'could not read mask: {mask_path}')

        if image.shape[:2] != mask.shape:
            # By default cv2.imread rotates images according to their EXIF orientation
            # tag, which can transpose an image relative to its mask. Re-read without
            # that rotation. (The previous mask.reshape() kept the pixel count but
            # scrambled the pixel layout, silently corrupting the labels.)
            image = self._read_rgb(image_path, ignore_orientation=True)
            if image.shape[:2] != mask.shape:
                raise ValueError(
                    f'image {image_path} has shape {image.shape[:2]} '
                    f'but mask {mask_path} has shape {mask.shape}'
                )

        mask[mask != 0] = 1
        augmented = self.transform(image=image, mask=mask)
        return augmented['image'], augmented['mask']

In [None]:
def _multiscale_crop():
    """Build a OneOf that picks one crop/rescale recipe yielding a config-sized patch.

    Shared by the train and test pipelines (previously duplicated verbatim).
    Every call returns fresh transform instances.
    """
    height, width = config['height'], config['width']
    return A.OneOf([
        # Crop a 1.9x / 1.5x larger window, then resize down to the patch size.
        A.Compose([
            A.RandomCrop(int(height * 1.9), int(width * 1.9)),
            A.Resize(height, width),
        ]),
        A.Compose([
            A.RandomCrop(int(height * 1.5), int(width * 1.5)),
            A.Resize(height, width),
        ]),
        # Rescale so the longest side equals the patch width, pad, then crop.
        A.Compose([
            A.LongestMaxSize(max_size=width),
            A.PadIfNeeded(min_height=height, min_width=width, border_mode=cv2.BORDER_CONSTANT),
            A.RandomCrop(height=height, width=width),
        ]),
        # Rescale so the longest side is 1800 px, pad, then crop.
        A.Compose([
            A.LongestMaxSize(max_size=1800),
            A.PadIfNeeded(min_height=height, min_width=1800, border_mode=cv2.BORDER_CONSTANT),
            A.RandomCrop(height, width, p=1),
        ]),
        # Crop at the image's native resolution.
        A.RandomCrop(height, width, p=1),
    ], p=1)


train_transform = A.Compose([
    A.Rotate(limit=90, p=0.5),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    _multiscale_crop(),
    # Three independent object pastes per sample.
    OBA(p=1),
    OBA(p=1),
    OBA(p=1),
    A.Normalize(),
    ToTensorV2(),
])

# NOTE(review): evaluation also uses random crops, so val loss/dice are measured
# on a different random view each epoch. Consider a deterministic crop or
# full-image tiling if epoch-to-epoch comparisons matter.
test_transform = A.Compose([
    _multiscale_crop(),
    A.Normalize(),
    ToTensorV2(),
])

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Each glob is sorted so file order is stable across filesystems.
image_train_paths, label_train_paths = (
    sorted(glob('./FLC2019/trainval/JPEGImages_pos/*.jpg')),
    sorted(glob('./FLC2019/trainval/SegmentationClass/*.png')),
)

image_test_paths, label_test_paths = (
    sorted(glob('./FLC2019/test/JPEGImages_pos/*.jpg')),
    sorted(glob('./FLC2019/test/SegmentationClass/*.png')),
)

In [None]:
# Sanity check that each glob matched files. CustomDataset derives mask paths
# from image paths, so the label counts are informational only.
for line in (
    f'train image counts : {len(image_train_paths)}',
    f'train true label counts : {len(label_train_paths)}',
    f'test image counts : {len(image_test_paths)}',
    f'test true label counts : {len(label_test_paths)}',
):
    print(line)

In [None]:
train_dataset = CustomDataset(image_train_paths, label_train_paths, transform=train_transform)
test_dataset = CustomDataset(image_test_paths, label_test_paths, transform=test_transform)

# Shared loader settings; only shuffling differs between the train and test loaders.
loader_kwargs = {
    'batch_size': config['batch_size'],
    'num_workers': config['num_workers'],
}
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# U-Net++ with an ImageNet-pretrained EfficientNet-B0 encoder and a single
# output channel of logits (hence BCEWithLogitsLoss below).
unetpp_kwargs = dict(
    encoder_name="efficientnet-b0",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
)
model = smp.UnetPlusPlus(**unetpp_kwargs).to(device)

# Alternatives that used to sit here as commented-out code: smp.DeepLabV3Plus
# with the same arguments, and monai.losses.DiceLoss() as the loss.
# NOTE(review): DiceLoss would need sigmoid=True, since the model outputs logits.
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=config['lr'])

In [None]:
def run_epoch(loader, train):
    """Run one pass over ``loader`` and return (mean batch loss, mean batch Dice).

    With ``train=True`` the model is in train mode and the optimizer steps after
    every batch. Otherwise the model is in eval mode with autograd disabled.
    """
    model.train(train)
    losses, scores = [], []
    with torch.set_grad_enabled(train):
        for images, labels in tqdm(loader):
            images = images.to(device)
            # Put labels on the model's device as float. The old
            # labels.type(torch.cuda.FloatTensor) crashed when `device` fell back
            # to CPU. unsqueeze(1) matches the model's single output channel.
            labels = labels.to(device, dtype=torch.float32).unsqueeze(1)

            if train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

            seg = (torch.sigmoid(outputs) > 0.5).to(torch.uint8).cpu().numpy()
            scores.append(dice_score(seg, labels.cpu().numpy()))
            losses.append(loss.item())
    return np.mean(losses), np.mean(scores)


for epoch in range(config['epochs']):
    train_loss, train_dice = run_epoch(train_dataloader, train=True)
    val_loss, val_dice = run_epoch(test_dataloader, train=False)
    print(f'[{epoch}/{config["epochs"]-1}], train_loss: {train_loss}, val_loss: {val_loss}, train_dice: {train_dice}, val_dice: {val_dice}')

In [None]:
# torch.save does not create missing directories. Ensure ./weights exists so the
# trained model is not lost at the end of a full training run.
os.makedirs('./weights', exist_ok=True)
torch.save(model.state_dict(), './weights/unetpp_4leaf.pth')

In [None]:
# One test image, cropped to a single training-sized patch at a fixed origin.
infer_image_path = './FLC2019/test/1_000007.jpg'
crop_top, crop_left = 950, 2300  # (row, col) of the patch's top-left corner

image = cv2.imread(infer_image_path)
if image is None:  # cv2.imread returns None instead of raising
    raise FileNotFoundError(f'could not read image: {infer_image_path}')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

image = image[crop_top:crop_top + config['height'], crop_left:crop_left + config['width']]
# Slicing past the border silently yields a smaller patch; fail loudly instead.
if image.shape[:2] != (config['height'], config['width']):
    raise ValueError(f'crop at ({crop_top}, {crop_left}) is only {image.shape[:2]}; image too small')

In [None]:
# Inference-only preprocessing: the image above is already cropped to the
# training patch size, so only normalize and convert to a tensor.
# NOTE(review): this rebinds `test_transform`. test_dataset keeps its reference
# to the earlier cropping pipeline, but re-running the dataset cell after this
# one would silently evaluate on full-size, uncropped images.
test_transform = A.Compose([A.Normalize(), ToTensorV2()])

In [None]:
# Normalize + tensorize, add a leading batch dimension, move to the model's device.
image_ = test_transform(image=image)['image'].unsqueeze(0).to(device)

In [None]:
# eval() switches layers such as dropout/BatchNorm to inference behaviour;
# no_grad() skips building an autograd graph that prediction does not need.
model.eval()
with torch.no_grad():
    outputs = model(image_)

In [None]:
# Per-pixel foreground probability, thresholded at 0.5 into a 0/1 uint8 mask.
seg_prob = np.squeeze(torch.sigmoid(outputs).detach().cpu().numpy())
seg = np.where(seg_prob > 0.5, 1, 0).astype(np.uint8)

In [None]:
# Side-by-side view of the input patch and the thresholded prediction.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(image)
axes[0].set_title('Input crop')
axes[1].imshow(seg)
axes[1].set_title('Predicted mask (prob > 0.5)')
for ax in axes:
    ax.axis('off')
plt.show()