In [None]:
import glob
import cv2
import numpy as np
import os

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
import torchvision
import segmentation_models_pytorch as smp

import albumentations as albu
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

### Helper functions

In [None]:
def visualzie(**images):
    """Show the given images side by side in a single row.

    Args:
        **images: mapping of panel name to an image accepted by ``plt.imshow``.
            Underscores in the name are turned into spaces and the result is
            title-cased, e.g. ``ground_truth`` -> ``Ground Truth``.
    """
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        # Join with a space; joining with '' glued the words together
        # ("Groundtruth" instead of "Ground Truth").
        plt.title(" ".join(name.split('_')).title())
        plt.imshow(image)
    plt.show()

# Data augmentation for training
def get_training_augmentation():
    """Build the albumentations pipeline used to augment training samples.

    The pipeline flips, shifts/scales, pads and randomly crops to 256x256,
    then applies noise, perspective, and one-of groups for brightness,
    sharpness/blur and contrast/colour.

    Returns:
        albu.Compose: the composed augmentation pipeline.
    """
    crop_size = 256

    # Geometric transforms; padding guarantees the random crop always fits.
    pipeline = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, border_mode=0),
        albu.PadIfNeeded(min_height=crop_size, min_width=crop_size, always_apply=True, border_mode=0),
        albu.RandomCrop(height=crop_size, width=crop_size, always_apply=True),
        albu.IAAAdditiveGaussianNoise(p=0.2),
        albu.IAAPerspective(p=0.5),
    ]

    # Each group below applies exactly one of its members with probability 0.9.
    brightness_group = albu.OneOf(
        [albu.CLAHE(p=1), albu.RandomBrightness(p=1), albu.RandomGamma(p=1)],
        p=0.9,
    )
    pipeline.append(brightness_group)

    sharpness_group = albu.OneOf(
        [albu.IAASharpen(p=1), albu.Blur(blur_limit=3, p=1), albu.MotionBlur(blur_limit=3, p=1)],
        p=0.9,
    )
    pipeline.append(sharpness_group)

    colour_group = albu.OneOf(
        [albu.RandomContrast(p=1), albu.HueSaturationValue(p=1)],
        p=0.9,
    )
    pipeline.append(colour_group)

    return albu.Compose(pipeline)

# Convert to tensor layout
def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Args:
        x: array with three axes; axis 2 becomes axis 0 (HWC -> CHW).
        **kwargs: accepted and ignored.

    Returns:
        A float32 array with axes ordered (2, 0, 1) relative to the input.
    """
    channels_first = np.transpose(x, (2, 0, 1))
    return channels_first.astype('float32')

# Preprocessing
def get_preproessing(preprocessing_fn):
    """Build the preprocessing pipeline.

    Args:
        preprocessing_fn: callable applied to the image only.

    Returns:
        albu.Compose: applies ``preprocessing_fn`` to the image, then
        ``to_tensor`` to both image and mask.
    """
    normalize_image = albu.Lambda(image=preprocessing_fn)
    to_channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_image, to_channels_first])

def crop_to_square(image):
    """Crop the centred square whose side is the image's shorter dimension.

    Args:
        image: object exposing ``size``, ``width``, ``height`` and
            ``crop(box)`` (e.g. a PIL image).

    Returns:
        The result of ``image.crop`` for the centred square box
        ``(left, upper, right, bottom)``.
    """
    side = min(image.size)
    width, height = image.width, image.height
    box = (
        (width - side) // 2,   # left
        (height - side) // 2,  # upper
        (width + side) // 2,   # right
        (height + side) // 2,  # bottom
    )
    return image.crop(box)

### Dataset class

In [None]:
# Old version

# class Dataset(BaseDataset):
#     CLASSES=[ 'sugarcane', 'weed']
    
#     def __init__(self, images_dir, masks_dir, classes, augmentation=None, preprocessing=None):
#         self.ids = os.listdir(images_dir)
#         self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
#         self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
#         # self.class_values =  [classes.index(cls.lower()) for cls in classes]
#         self.class_values =  classes
#         self.augmentation = augmentation
#         self.preprocessing = preprocessing

#     def __getitem__(self, i):

#         image = Image.open(self.images_fps[i])
#         image = crop_to_square(image)
#         image = image.resize((128, 128), Image.ANTIALIAS)
#         image = np.asarray(image)

#         masks = Image.open(self.masks_fps[i])
#         masks = crop_to_square(masks)
#         masks = masks.resize((128, 128), Image.ANTIALIAS)
#         masks = np.asarray(masks)

#         masks = np.where(masks == 255, 21, masks)

#         cls_idx = [self.CLASSES.index(cls) for cls in self.class_values]
#         masks = [(masks == idx) for idx in cls_idx]
#         mask = np.stack(masks, axis=-1).astype("float32")

#         # # 画像データ
#         # image = cv2.imread(self.images_fps[i])
#         # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

#         # # マスク画像データ
#         # mask = cv2.imread(self.masks_fps[i], 0)
#         # masks = [(mask == v) for v in self.class_values]
#         # mask = np.stack(masks, axis=-1).astype('float32')

#         # データの拡張
#         if self.augmentation:
#             sample = self.augmentation(image=image, mask=mask)
#             image, mask = sample['image'], sample['mask']
            
#         if self.preprocessing:
#             sample = self.preprocessing(image=image, mask=mask)
#             image, mask = sample['image'], sample['mask']

#         return image, mask

#     def __len__(self):
#         return len(self.ids)


In [None]:
class Dataset(BaseDataset):
    """Sugarcane/weed segmentation dataset.

    Reads an image and its mask from disk, optionally applies augmentation
    and preprocessing, and returns the normalized image tensor together with
    the mask as an integer tensor.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder; every mask must
            have the same file name as its image
        classes (list[str] | None): class names (case-insensitive) to resolve
            against ``CLASSES``; ``None`` selects every class
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    Raises:
        ValueError: if ``classes`` contains a name that is not in ``CLASSES``.
        FileNotFoundError: (from ``__getitem__``) if an image or mask cannot
            be read.
    """

    # Per-channel statistics passed to T.Normalize in __getitem__
    # (the standard ImageNet values).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    CLASSES = ['sugarcane', 'weed']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices are reproducible between runs.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # convert str names to class values on masks; None means "all classes"
        # (previously the None default raised TypeError when iterated).
        selected = self.CLASSES if classes is None else classes
        self.class_values = []
        for cls in selected:
            name = cls.lower()
            if name not in self.CLASSES:
                raise ValueError(f"Unknown class {cls!r}; expected one of {self.CLASSES}")
            self.class_values.append(self.CLASSES.index(name))
        # Note: class_values is stored for callers; __getitem__ does not use it.

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i`` as torch tensors."""
        # read data; cv2.imread returns None instead of raising on failure
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise FileNotFoundError(f"Could not read image: {self.images_fps[i]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Flag 0 loads the mask as a single-channel grayscale image.
        # NOTE(review): for palette-mode PNG masks this presumably yields gray
        # intensities rather than palette indices - confirm that the mask
        # values equal the class indices.
        mask = cv2.imread(self.masks_fps[i], 0)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {self.masks_fps[i]}")

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # Convert the image to a tensor and normalize it with mean/std.
        t = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])
        image = t(image)
        mask = torch.from_numpy(mask).long()

        return image, mask

    def __len__(self):
        """Number of samples (files found in ``images_dir``)."""
        return len(self.ids)

In [None]:
# Model configuration.
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['sugarcane', 'weed']
# Activation name handed to the smp model constructor.
ACTIVATION = 'sigmoid'

# Fall back to the CPU when no GPU is available instead of failing on
# the first `.to('cuda')` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Label used in the checkpoint file name.
decoder = 'DeepLabV3'

# model = smp.Unet(
#     encoder_name=ENCODER,
#     encoder_weights=ENCODER_WEIGHTS,
#     encoder_depth=4,
#     decoder_channels=(128, 64, 32, 16),
#     classes=len(CLASSES),
#     activation=ACTIVATION
# )

# model = smp.UnetPlusPlus(
#     encoder_name=ENCODER,
#     encoder_weights=ENCODER_WEIGHTS,
#     encoder_depth=4,
#     decoder_channels=(128, 64, 32, 16),
#     classes=len(CLASSES),
#     activation=ACTIVATION
# )

# model = smp.PSPNet(
#     encoder_name=ENCODER,
#     encoder_weights=ENCODER_WEIGHTS,
#     classes=len(CLASSES),
#     activation=ACTIVATION
# )

# model = smp.DeepLabV3Plus(
#     encoder_name=ENCODER,
#     encoder_weights=ENCODER_WEIGHTS,
#     classes=len(CLASSES),
#     activation=ACTIVATION
# )

# Build the DeepLabV3 segmentation model from the configuration above and
# move it to the selected device.
model_kwargs = dict(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
model = smp.DeepLabV3(**model_kwargs).to(device)

In [None]:
# Number of trainable parameters. Displaying `model.parameters()` directly
# only shows a generator repr, which carries no information.
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
train_dir = "train"
val_dir = "val"

# Create <split>/images and <split>/masks. makedirs(exist_ok=True) also
# repairs a split folder that exists but lacks one of its subfolders, which
# the previous "only if the parent is missing" check skipped.
for split_dir in (train_dir, val_dir):
    for sub_dir in ('images', 'masks'):
        os.makedirs(os.path.join(split_dir, sub_dir), exist_ok=True)

In [None]:
train_dir = './train'
val_dir = './val'

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)


train_dataset = Dataset(
    os.path.join(train_dir, 'images'),
    os.path.join(train_dir, 'masks'),
    augmentation=get_training_augmentation(),
    # preprocessing=get_preproessing(preprocessing_fn),
    classes=CLASSES
)

# Validation must be deterministic, otherwise the score used for model
# selection changes from epoch to epoch for reasons unrelated to the model.
# Pad to and centre-crop 256x256 (the crop size of the training pipeline) so
# samples can still be batched.
valid_augmentation = albu.Compose([
    albu.PadIfNeeded(min_height=256, min_width=256, border_mode=0),
    albu.CenterCrop(height=256, width=256),
])

valid_dataset = Dataset(
    os.path.join(val_dir, 'images'),
    os.path.join(val_dir, 'masks'),
    augmentation=valid_augmentation,
    # preprocessing=get_preproessing(preprocessing_fn),
    classes=CLASSES
)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0, drop_last=True)
# Keep the last (possibly smaller) validation batch so every sample is scored.
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=0, drop_last=False)

In [None]:
dataset = Dataset(os.path.join(train_dir, 'images'), os.path.join(train_dir, 'masks'), classes=[ 'sugarcane', 'weed'])

image, mask = dataset[5]

print(image.shape)
print(mask.shape)

visualzie(image=image.permute(1, 2, 0), mask=mask)

### Training

In [None]:
# For now, plug in a simple metric/loss pair (IoU and Dice).

class _IndexMaskAdapter(nn.Module):
    """Let an smp utils loss/metric consume integer index masks.

    The dataset yields masks of shape (N, H, W) holding class indices, while
    the model outputs (N, C, H, W). The smp utils loss/metric multiply
    prediction and target element-wise, so passing the index mask directly
    either fails or - with a batch size equal to the number of classes -
    silently broadcasts the batch axis of the target against the channel axis
    of the prediction. This adapter expands the index mask to one channel per
    class before delegating.

    Args:
        fn: the wrapped smp loss or metric (must expose ``__name__``).
        class_values: mask value for each output channel, in channel order.
    """

    def __init__(self, fn, class_values):
        super().__init__()
        self.fn = fn
        self.class_values = list(class_values)
        # Keep the wrapped object's name so log keys such as 'iou_score'
        # stay unchanged.
        self._name = fn.__name__

    @property
    def __name__(self):
        return self._name

    def forward(self, y_pr, y_gt):
        # Only expand index masks; targets that already have a channel axis
        # are passed through untouched.
        if y_gt.dim() == y_pr.dim() - 1:
            y_gt = torch.stack([y_gt == v for v in self.class_values], dim=1).to(y_pr.dtype)
        return self.fn(y_pr, y_gt)


_class_values = range(len(CLASSES))

metrics = [_IndexMaskAdapter(smp.utils.metrics.IoU(threshold=0.5), _class_values)]
loss = _IndexMaskAdapter(smp.utils.losses.DiceLoss(), _class_values)
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.001),
])


In [None]:
# Epoch runners: both share the model, loss, metrics and device; only the
# training runner receives the optimizer.
epoch_kwargs = dict(loss=loss, metrics=metrics, device=device, verbose=True)

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **epoch_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **epoch_kwargs)

In [None]:
max_score = 0

# Early-stopping settings. The early-stopping logic below is currently
# disabled (commented out), so these values are not used.
patience = 5
early_stop_counter = 0

epoch = 50

# torch.save does not create missing folders; make sure the checkpoint
# directory exists before training starts rather than failing after the
# first epoch.
checkpoint_path = f'./model/{decoder}_{ENCODER}.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for i in range(epoch):
    print(f'\nEpoch: {i}')
    train_logs = train_epoch.run(train_loader)
    val_logs = valid_epoch.run(valid_loader)

    # Keep the checkpoint with the best validation IoU seen so far.
    if max_score < val_logs['iou_score']:
        max_score = val_logs['iou_score']
        torch.save(model, checkpoint_path)
    #     print('Model saved!')
    #     early_stop_counter = 0
    # else:
    #     early_stop_counter += 1
    #     print(f"not improve for {early_stop_counter} Epoch")
    #     if early_stop_counter==patience:
    #         print(f"early stop. Max Score {max_score}")
    #         break

    # One-off learning-rate drop. The optimizer has a single parameter group
    # covering the whole model, so this is not decoder-specific.
    if i == 30:
        optimizer.param_groups[0]['lr'] = 1e-4
        print('Decrease learning rate to 1e-4!')


In [None]:
# Best validation IoU reached during training, rounded to two decimals.
print("max_score: {:.2f}".format(max_score))

### Check model

In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset as BaseDataset
from torchvision import transforms as T
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable

from PIL import Image
import cv2
import albumentations as albu

import time
import os
from tqdm.notebook import tqdm

from torch.utils.data import Dataset as BaseDataset
import segmentation_models_pytorch as smp
import glob

In [None]:
# Class names in mask-index order (index 0 = sugarcane, index 1 = weed).
CLASSES = ['sugarcane', 'weed']

In [None]:
class testDataset(BaseDataset):
    """Sugarcane/weed dataset for evaluation.

    Reads an image and its mask from disk and optionally applies augmentation
    and preprocessing. Unlike the training dataset, the image is returned
    without tensor conversion or normalization; only the mask is converted to
    an integer tensor.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder; every mask must
            have the same file name as its image
        classes (list[str] | None): class names (case-insensitive) to resolve
            against ``CLASSES``; ``None`` selects every class
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    Raises:
        ValueError: if ``classes`` contains a name that is not in ``CLASSES``.
        FileNotFoundError: (from ``__getitem__``) if an image or mask cannot
            be read.
    """

    CLASSES = ['sugarcane', 'weed']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices are reproducible between runs.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # convert str names to class values on masks; None means "all classes"
        # (previously the None default raised TypeError when iterated).
        selected = self.CLASSES if classes is None else classes
        self.class_values = []
        for cls in selected:
            name = cls.lower()
            if name not in self.CLASSES:
                raise ValueError(f"Unknown class {cls!r}; expected one of {self.CLASSES}")
            self.class_values.append(self.CLASSES.index(name))
        # Note: class_values is stored for callers; __getitem__ does not use it.

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return ``(image, mask)``: the RGB image as loaded and the mask tensor."""
        # read data; cv2.imread returns None instead of raising on failure
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise FileNotFoundError(f"Could not read image: {self.images_fps[i]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Flag 0 loads the mask as a single-channel grayscale image.
        # NOTE(review): for palette-mode PNG masks this presumably yields gray
        # intensities rather than palette indices - confirm that the mask
        # values equal the class indices.
        mask = cv2.imread(self.masks_fps[i], 0)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {self.masks_fps[i]}")

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # The image is intentionally left un-normalized here; the prediction
        # helpers apply ToTensor/Normalize themselves.
        mask = torch.from_numpy(mask).long()

        return image, mask

    def __len__(self):
        """Number of samples (files found in ``images_dir``)."""
        return len(self.ids)

In [None]:
# Evaluation dataset over the validation split, without augmentation.
test_dataset = testDataset(
    # The comma between the path parts was missing, so the call only worked
    # through implicit string-literal concatenation.
    os.path.join('./val/', 'images'),
    os.path.join('./val/', 'masks'),
    # augmentation=get_training_augmentation(),
    # preprocessing=get_preproessing(preprocessing_fn),
    classes=CLASSES
)

test_dataloader = DataLoader(test_dataset)

In [None]:
# Pick one validation sample and read the colour palette from its mask; the
# palette is reused to colourize ground-truth and predicted masks.
val_files = sorted(glob.glob('./val/images/*.png'))
mask_files = sorted(glob.glob('./val/masks/*.png'))
if not val_files or not mask_files:
    raise FileNotFoundError("No PNG files found under ./val/images or ./val/masks")

# Sample 9 is preferred, but fall back to the last file when the split holds
# fewer than ten images instead of raising IndexError.
f = val_files[min(9, len(val_files) - 1)]

# Close the file handle once the palette has been read.
with Image.open(mask_files[min(9, len(mask_files) - 1)]) as palette_image:
    PALETTE = palette_image.getpalette()

if PALETTE is None:
    raise ValueError("The selected mask has no palette; palette-mode ('P') PNG masks are required")

In [None]:
# Load the checkpoint written by the training loop. map_location keeps the
# load working when the checkpoint was saved on a different device.
# NOTE(review): recent PyTorch releases presumably default torch.load to
# weights_only=True, which would reject a fully pickled model - confirm
# against the installed version.
best_model = torch.load("./model/DeepLabV3_resnet34.pth", map_location=device)

In [None]:
def pixel_accuracy(output, mask):
    """Fraction of positions where the predicted class equals the mask.

    Args:
        output: model output with class scores along dim 1.
        mask: tensor of class indices compared element-wise with the
            arg-max of ``output``.

    Returns:
        float: number of matching positions divided by the total number.
    """
    with torch.no_grad():
        predicted = F.softmax(output, dim=1).argmax(dim=1)
        hits = (predicted == mask).int()
        return float(hits.sum()) / float(hits.numel())

In [None]:
def mIoU(pred_mask, mask, smooth=1e-10, n_classes=len(CLASSES)):
    """Mean intersection-over-union across classes present in ``mask``.

    Args:
        pred_mask: model output with class scores along dim 1.
        mask: tensor of ground-truth class indices.
        smooth: small constant added to intersection and union.
        n_classes: number of class indices (0 .. n_classes - 1) to evaluate.

    Returns:
        The NaN-ignoring mean of the per-class IoU values; a class that does
        not occur in ``mask`` contributes NaN and is therefore skipped.
    """
    with torch.no_grad():
        predicted = torch.argmax(F.softmax(pred_mask, dim=1), dim=1).contiguous().view(-1)
        target = mask.contiguous().view(-1)

        def _class_iou(class_id):
            # IoU of one class, or NaN when the class is absent from the target.
            is_predicted = predicted == class_id
            is_target = target == class_id
            if is_target.long().sum().item() == 0:
                return np.nan
            overlap = torch.logical_and(is_predicted, is_target).sum().float().item()
            combined = torch.logical_or(is_predicted, is_target).sum().float().item()
            return (overlap + smooth) / (combined + smooth)

        return np.nanmean([_class_iou(class_id) for class_id in range(n_classes)])

In [None]:
def predict_image_mask_miou(model, image, mask, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    """Predict a class map for one image and score it with mIoU.

    Args:
        model: segmentation model; switched to eval mode and moved to ``device``.
        image: input accepted by ``T.ToTensor``.
        mask: ground-truth tensor of class indices.
        mean, std: per-channel normalization statistics.

    Returns:
        tuple: (predicted class map on the CPU without the batch axis, mIoU score).
    """
    model.eval()
    normalize = T.Compose([T.ToTensor(), T.Normalize(mean, std)])

    tensor_image = normalize(image)
    model.to(device)
    # Add a batch axis of size one to both the input and the target.
    batch = tensor_image.to(device).unsqueeze(0)
    target = mask.to(device).unsqueeze(0)

    with torch.no_grad():
        scores = model(batch)
        score = mIoU(scores, target)
        masked = torch.argmax(scores, dim=1).cpu().squeeze(0)

    return masked, score

def predict_image_mask_pixel(model, image, mask, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    """Predict a class map for one image and score it with pixel accuracy.

    Args:
        model: segmentation model; switched to eval mode and moved to ``device``.
        image: input accepted by ``T.ToTensor``.
        mask: ground-truth tensor of class indices.
        mean, std: per-channel normalization statistics.

    Returns:
        tuple: (predicted class map on the CPU without the batch axis,
        pixel accuracy as a float).
    """
    model.eval()
    t = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
    image = t(image)
    model.to(device)
    image = image.to(device)
    mask = mask.to(device)

    with torch.no_grad():

        # Add a batch axis of size one to both the input and the target.
        image = image.unsqueeze(0)
        mask = mask.unsqueeze(0)

        output = model(image)
        acc = pixel_accuracy(output, mask)
        # The class axis is dim 1 (as in pixel_accuracy and
        # predict_image_mask_miou); dim=3 took the arg-max along the last
        # spatial axis and returned a meaningless map.
        masked = torch.argmax(output, dim=1)
        masked = masked.cpu().squeeze(0)

    return masked, acc

def miou_score(model, test_set):
    """Compute the mIoU of every sample in ``test_set``.

    Args:
        model: segmentation model passed to ``predict_image_mask_miou``.
        test_set: indexable collection yielding ``(image, mask)`` pairs.

    Returns:
        list: one mIoU score per sample, in dataset order.
    """
    def _score(index):
        image, target = test_set[index]
        _, score = predict_image_mask_miou(model, image, target)
        return score

    return [_score(index) for index in tqdm(range(len(test_set)))]

In [None]:
# Per-sample mIoU of the loaded model over the evaluation dataset.
mob_miou = miou_score(model=best_model, test_set=test_dataset)

In [None]:
# Average the per-sample scores into a single figure for the test set.
mean_miou = np.mean(mob_miou)
print("test set miou", mean_miou)

In [None]:
def pixel_acc_mean(model, test_set):
    """Compute the pixel accuracy of every sample in ``test_set``.

    Despite the name, this returns the per-sample values; the caller
    averages them.

    Args:
        model: segmentation model passed to ``predict_image_mask_pixel``.
        test_set: indexable collection yielding ``(image, mask)`` pairs.

    Returns:
        list: one accuracy value per sample, in dataset order.
    """
    def _accuracy(index):
        image, target = test_set[index]
        _, acc = predict_image_mask_pixel(model, image, target)
        return acc

    return [_accuracy(index) for index in tqdm(range(len(test_set)))]

def pixel_acc(model, image, mask):
    """Return only the pixel accuracy of ``model`` on one ``(image, mask)`` pair."""
    _, accuracy_value = predict_image_mask_pixel(model, image, mask)
    return accuracy_value

# Per-sample pixel accuracy over the evaluation dataset, reported as a
# percentage of the mean.
mob_acc = pixel_acc_mean(model=best_model, test_set=test_dataset)
mean_accuracy_percent = np.mean(mob_acc) * 100
print(f"Test Set Pixel accuracy {mean_accuracy_percent:.2f} %")

In [None]:
# Show picture / ground truth / prediction for every evaluation sample.
# Iterate over the dataset itself rather than re-listing the folder, so the
# loop bound always matches what the dataset can index.
for i in range(len(test_dataloader.dataset)):

    image2, mask2 = test_dataloader.dataset[i]
    pred_mask2, score2 = predict_image_mask_miou(best_model, image2, mask2)
    accuracy = pixel_acc(best_model, image2, mask2)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 10))
    ax1.imshow(image2)
    ax1.set_title('Picture')

    # Colourize the index masks with the palette read from the dataset.
    mask = Image.fromarray(np.uint8(mask2), mode="P")
    mask.putpalette(PALETTE)

    ax2.imshow(mask)
    ax2.set_title('Ground truth')
    ax2.set_axis_off()

    pred_mask2 = Image.fromarray(np.uint8(pred_mask2), mode="P")
    pred_mask2.putpalette(PALETTE)

    ax3.imshow(pred_mask2)
    # Label the panel with the class of the model actually loaded; the title
    # previously hard-coded 'UNet-MobileNet', which is not the model used here.
    ax3.set_title(f'{type(best_model).__name__} | mIoU {score2:.3f} | acc {accuracy * 100 :.2f}%')
    ax3.set_axis_off()

    # Render and release each figure so a large split does not accumulate
    # open figures in memory.
    plt.show()
    plt.close(fig)

In [None]:
def check_prediction(n):
    """Plot input image, ground-truth mask and predicted mask for sample ``n``.

    The sample is taken from ``valid_loader.dataset`` and the prediction is
    produced by ``best_model``.

    Args:
        n (int): index of the sample in the validation dataset.
    """
    dataset = valid_loader.dataset
    img, mask = dataset[n]

    fig, ax = plt.subplots(1, 3, tight_layout=True)

    # The dataset returns a normalized tensor; undo the normalization so the
    # picture is displayed with its real colours instead of being clipped.
    mean = torch.tensor(dataset.mean).view(-1, 1, 1)
    std = torch.tensor(dataset.std).view(-1, 1, 1)
    ax[0].imshow((img * std + mean).clamp(0, 1).permute(1, 2, 0).numpy())

    # Colourize the index mask with the palette read from the dataset.
    mask = Image.fromarray(np.uint8(mask), mode="P")
    mask.putpalette(PALETTE)

    ax[1].imshow(mask)

    # img is already a tensor: add the batch axis directly instead of
    # re-wrapping it with torch.tensor(), and skip gradient tracking.
    best_model.eval()
    with torch.no_grad():
        y = best_model(img.unsqueeze(0).to(device))
    y = y[0].cpu().numpy()
    # Class with the highest score at every position.
    y = np.argmax(y, axis=0)

    predict_class_img = Image.fromarray(np.uint8(y), mode="P")
    predict_class_img.putpalette(PALETTE)
    ax[2].imshow(predict_class_img)

    plt.show()

In [None]:
# NOTE(review): this checkpoint name differs from the one the training loop
# in this notebook writes (DeepLabV3_resnet34.pth) - confirm the file exists.
best_model = torch.load("./model/unet_plus_plus_resnet34.pth", map_location=device)
best_model.eval()

# Use the mask paths of the validation dataset itself. check_prediction()
# indexes valid_loader.dataset, so the indices collected below must follow the
# dataset's order; listing the masks folder separately does not guarantee that.
paths = valid_loader.dataset.masks_fps

# Group validation samples by which classes their mask contains.
idx_dict = {"both": [], "sugarcane": [], "weed": []}

# Append the sample index to the group matching the classes found in the mask.
for i, mask_path in enumerate(paths):

    # Close each file after reading it.
    with Image.open(mask_path) as mask_image:
        unique_class = np.unique(np.asarray(mask_image))

    has_sugarcane = 0 in unique_class
    has_weed = 1 in unique_class

    if has_sugarcane and has_weed:
        idx_dict["both"].append(i)

    elif has_sugarcane:
        idx_dict["sugarcane"].append(i)

    elif has_weed:
        idx_dict["weed"].append(i)

In [None]:
# Run the check for each label group and inspect the results
# (at most the first three samples of each group).
for label, idx_list in idx_dict.items():
    print("="*30 , label, "="*30)
    for idx in idx_list[:3]:
        check_prediction(idx)