In [None]:
import os
from dataclasses import dataclass, field
from typing import List
import random
import copy

import cv2
import torch
import albumentations as A
from matplotlib import pyplot as plt
import segmentation_models_pytorch as smp
import numpy as np
from torch import nn, optim


In [None]:
# Compute device: the GPU when PyTorch can see one, otherwise the CPU.
# CPU_DEVICE is used wherever models/tensors are explicitly moved back to host memory.
CPU_DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
# DEVICE = 'cpu'  # uncomment to force CPU execution
DEVICE

In [None]:
# Widen the notebook container to the full browser width (purely cosmetic).
# `display` lives in `IPython.display`; importing it from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# True: use the tiles cut with a big stride (768 instead of 256 pixels). Faster training, and
# validation/test are random samples taken from all training directories.
IS_BIG_STRIDE = True


if IS_BIG_STRIDE:
    TILES_BASE_DIR = "/media/data/local/corn/processed_stride768"
    SUBDIRECTORIES_TO_PROCESS_TRAIN = [
        "kukurydza_5_ha",
#         "kukurydza_10_ha",  # excluded from training in this run
        "kukurydza_11_ha",
        "kukurydza_13_ha",
        "kukurydza_15_ha",
        "kukurydza_18_ha",
        "kukurydza_25_ha",
        "kukurydza_38_ha",
        "kukurydza_60_ha",
    ]


    # If either VALID or TEST is empty, both are carved out of the TRAIN directories instead
    # (per-directory random split, see get_tile_paths_for_directories_with_split).
    SUBDIRECTORIES_TO_PROCESS_VALID = [
    ]

    SUBDIRECTORIES_TO_PROCESS_TEST = [
    ]
    # Alternative set-up: hold out whole fields for validation (11 ha) and test (10 ha).
#     SUBDIRECTORIES_TO_PROCESS_TRAIN = [
#         "kukurydza_5_ha",
# #         "kukurydza_10_ha",
# #         "kukurydza_11_ha",
#         "kukurydza_13_ha",
#         "kukurydza_15_ha",
#         "kukurydza_18_ha",
#         "kukurydza_25_ha",
#         "kukurydza_38_ha",
#         "kukurydza_60_ha",
#     ]


#     SUBDIRECTORIES_TO_PROCESS_VALID = [
#         "kukurydza_11_ha",
#     ]

#     SUBDIRECTORIES_TO_PROCESS_TEST = [
#         "kukurydza_10_ha",    
#     ]
else:
    # No small-stride configuration is defined in this notebook; fail here with a clear message
    # instead of a NameError on TILES_BASE_DIR / SUBDIRECTORIES_TO_PROCESS_* further down.
    raise NotImplementedError(
        "IS_BIG_STRIDE = False has no dataset configuration; define TILES_BASE_DIR and "
        "SUBDIRECTORIES_TO_PROCESS_TRAIN/VALID/TEST for it first")


# Tiles are stored as UNCROPPED_TILE_SIZE squares; the dataset crops the central
# CROPPED_TILE_SIZE window, dropping CROP_TILE_MARGIN pixels on every side.
UNCROPPED_TILE_SIZE = (512 + 256)  # in pixels
CROPPED_TILE_SIZE = 512
CROP_TILE_MARGIN = (UNCROPPED_TILE_SIZE - CROPPED_TILE_SIZE) // 2

In [None]:
@dataclass
class TilesPaths:
    """Index-aligned image/mask tile paths: ``img_paths[k]`` is paired with ``mask_paths[k]``.

    All methods keep both attributes as lists so that slicing, ``+`` and further
    shuffles compose regardless of the order they are called in.
    """
    img_paths: List = field(default_factory=list)
    mask_paths: List = field(default_factory=list)

    def split_into_train_valid_test(self, percentage_for_train=0.8):
        """Split the tiles contiguously (no shuffling) into train / valid / test.

        The first ``percentage_for_train`` fraction goes to train; the remainder is
        halved between validation and test.

        Returns:
            (train, valid, test) as new TilesPaths instances.
        """
        n = len(self.img_paths)
        test_percentage = (1 - percentage_for_train) / 2
        train_end = int(n * percentage_for_train)
        valid_end = int(n * (1 - test_percentage))

        def _part(start, stop):
            # list(...) so a tuple input never leaks into the result
            return TilesPaths(img_paths=list(self.img_paths[start:stop]),
                              mask_paths=list(self.mask_paths[start:stop]))

        return _part(0, train_end), _part(train_end, valid_end), _part(valid_end, n)

    def __add__(self, other):
        """Return a new TilesPaths with ``other``'s tiles appended; neither operand is modified."""
        new = copy.deepcopy(self)
        # Convert both sides: list + tuple (or the reverse) would otherwise raise TypeError.
        new.img_paths = list(self.img_paths) + list(other.img_paths)
        new.mask_paths = list(self.mask_paths) + list(other.mask_paths)
        return new

    def shuffle(self):
        """Shuffle in place with the global ``random`` state, keeping image/mask pairs together.

        Raises:
            ValueError: if the image and mask lists have different lengths (zip would
                otherwise silently drop the unpaired tail).
        """
        if len(self.img_paths) != len(self.mask_paths):
            raise ValueError(
                f"img_paths ({len(self.img_paths)}) and mask_paths ({len(self.mask_paths)}) differ in length")
        pairs = list(zip(self.img_paths, self.mask_paths))
        random.shuffle(pairs)
        # Rebuild as lists (zip(*pairs) would give tuples, and fails outright when empty).
        self.img_paths = [img for img, _ in pairs]
        self.mask_paths = [mask for _, mask in pairs]

    
def get_tile_paths_for_directories_with_split(dir_names):
    """Split each directory's tiles into train/valid/test and merge the parts across directories.

    Every directory is split on its own (default 80/10/10 via
    ``TilesPaths.split_into_train_valid_test``), so each one contributes to all three
    subsets when it has enough tiles. The merged subsets are shuffled before returning.

    Returns:
        (train, valid, test) TilesPaths.
    """
    merged = [TilesPaths(), TilesPaths(), TilesPaths()]
    for dir_name in dir_names:
        parts = get_tile_paths_for_directories(dir_names=[dir_name]).split_into_train_valid_test()
        merged = [accumulated + part for accumulated, part in zip(merged, parts)]

    for subset in merged:
        subset.shuffle()

    return tuple(merged)
    
        
def get_tile_paths_for_directories(dir_names, shuffle=True) -> TilesPaths:
    """Collect paired ``<prefix>_img.png`` / ``<prefix>_mask.png`` tiles from TILES_BASE_DIR sub-directories.

    Args:
        dir_names: sub-directory names (relative to TILES_BASE_DIR) to scan.
        shuffle: shuffle the tile order within each directory using the global ``random`` state.

    Returns:
        TilesPaths with index-aligned image/mask paths, grouped by directory in ``dir_names`` order.

    Raises:
        Exception: if a tile has an image but no mask, or a mask but no image.
    """
    img_suffix = '_img.png'
    mask_suffix = '_mask.png'
    tile_paths = TilesPaths()
    for dir_name in dir_names:
        dir_path = os.path.join(TILES_BASE_DIR, dir_name)
        file_names = [f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]

        # Match the exact suffixes the paths are rebuilt with below. A substring test such as
        # `'img' in f` also catches unrelated files and then yields non-existent *_img.png paths.
        img_files_prefixes = {f[:-len(img_suffix)] for f in file_names if f.endswith(img_suffix)}
        mask_files_prefixes = {f[:-len(mask_suffix)] for f in file_names if f.endswith(mask_suffix)}
        missing_files_prefixes = img_files_prefixes ^ mask_files_prefixes

        if missing_files_prefixes:
            raise Exception(f"Some files don't have correponding pair in mask/image: {missing_files_prefixes} in {dir_path}")

        # Sort before the optional shuffle: set iteration order of str depends on PYTHONHASHSEED,
        # so without it the order (and any seeded shuffle/split built on it) changes between runs.
        common_files_prefixes = sorted(img_files_prefixes & mask_files_prefixes)
        if shuffle:
            random.shuffle(common_files_prefixes)
        for file_prefix in common_files_prefixes:
            tile_paths.img_paths.append(os.path.join(dir_path, file_prefix + img_suffix))
            tile_paths.mask_paths.append(os.path.join(dir_path, file_prefix + mask_suffix))
    return tile_paths


# Seed Python's `random` right before building the path lists. The per-directory and per-subset
# shuffles all draw from it, so a fixed seed keeps the train/valid/test partition identical across
# re-runs (provided the directory listing order is itself deterministic). Without this, a saved
# model re-evaluated in a new session could be "tested" on tiles it was trained on.
SPLIT_RANDOM_SEED = 42
random.seed(SPLIT_RANDOM_SEED)

if SUBDIRECTORIES_TO_PROCESS_VALID and SUBDIRECTORIES_TO_PROCESS_TEST:
    # Dedicated directories are configured for both validation and test.
    tile_paths_train = get_tile_paths_for_directories(SUBDIRECTORIES_TO_PROCESS_TRAIN)
    tile_paths_valid = get_tile_paths_for_directories(SUBDIRECTORIES_TO_PROCESS_VALID)
    tile_paths_test = get_tile_paths_for_directories(SUBDIRECTORIES_TO_PROCESS_TEST)
else:
    # At least one of them is empty: carve validation and test out of every training directory.
    tile_paths_train, tile_paths_valid, tile_paths_test = get_tile_paths_for_directories_with_split(SUBDIRECTORIES_TO_PROCESS_TRAIN)


print(f'Number of tiles train = {len(tile_paths_train.img_paths)}')
print(f'Number of tiles validation = {len(tile_paths_valid.img_paths)}')
print(f'Number of tiles test = {len(tile_paths_test.img_paths)}')


In [None]:
# Grayscale mask pixel value for each model output channel (list order == channel order).
# Later cells treat channel 0 as healthy field and channel 1 as damage; the meaning of 127
# (channel 2) is not shown here - TODO confirm.
SEGMENTATION_CLASS_VALUES = [
    0,
    255,
    127,
]
NUMBER_OF_SEGMENTATION_CLASSES = len(SEGMENTATION_CLASS_VALUES)

In [None]:
class CornFieldDamageDataset(torch.utils.data.Dataset):
    """Tile dataset returning ``(image, one_hot_mask)`` as float32 numpy arrays.

    * image: CHW, RGB channel order, scaled to [0, 1].
    * one_hot_mask: one channel per entry of SEGMENTATION_CLASS_VALUES; channel k is 1.0
      where the grayscale mask pixel equals SEGMENTATION_CLASS_VALUES[k], else 0.0.

    Both outputs are cropped to the window [CROP_TILE_MARGIN, UNCROPPED_TILE_SIZE - CROP_TILE_MARGIN)
    on each axis. With ``augment=True`` the flips / scale / rotation run before that crop.

    Args:
        img_file_paths: image tile paths.
        mask_file_paths: mask tile paths, index-aligned with ``img_file_paths``.
        augment: apply random augmentation (training) instead of the crop only.

    Raises:
        ValueError: if the two path lists differ in length.
    """

    def __init__(self, img_file_paths, mask_file_paths, augment=True):
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if len(img_file_paths) != len(mask_file_paths):
            raise ValueError(
                f"Got {len(img_file_paths)} image paths but {len(mask_file_paths)} mask paths")
        self.img_file_paths = img_file_paths
        self.mask_file_paths = mask_file_paths
        if augment:
            self._img_and_mask_transform = self._get_img_and_mask_augmentation_tranform()  # augmentation transform
        else:
            self._img_and_mask_transform = self._get_img_and_mask_crop_tranform()  # crop only transform

    def __len__(self):
        return len(self.mask_file_paths)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.img_file_paths[idx]
        mask_path = self.mask_file_paths[idx]

        # cv2.imread does not raise on a missing/corrupt file - it returns None. Fail with the path
        # instead of a cryptic cvtColor/albumentations error further down.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image tile: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask tile: {mask_path}")

        transformed = self._img_and_mask_transform(image=image, mask=mask)
        image, mask = transformed['image'], transformed['mask']

        # One-hot encode the mask; a pixel matching none of the class values is all-zero.
        mask_stacked = np.stack([(mask == v) for v in SEGMENTATION_CLASS_VALUES], axis=0).astype('float32')

        image = (image.astype('float') / 255).transpose(2, 0, 1)  # HWC -> CHW, [0, 1]
        return image.astype('float32'), mask_stacked

    @staticmethod
    def _get_tile_crop():
        """Crop keeping [CROP_TILE_MARGIN, UNCROPPED_TILE_SIZE - CROP_TILE_MARGIN) on both axes."""
        far_edge = UNCROPPED_TILE_SIZE - CROP_TILE_MARGIN
        return A.Crop(x_min=CROP_TILE_MARGIN, y_min=CROP_TILE_MARGIN, x_max=far_edge, y_max=far_edge)

    def _get_img_and_mask_augmentation_tranform(self):
        """Random flips, +-15% scale and +-90 deg rotation, followed by the fixed tile crop."""
        transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # With a 768 px tile, scale_limit above ~0.16 shrinks the tile below the 640 px crop edge.
            A.RandomScale(scale_limit=0.15),
            A.Rotate(limit=90),  # degrees
            # TODO normalize instead divide by 255?
            self._get_tile_crop(),
            # TODO ToTensorV2 instead of manual stacking and transposition?
        ])
        # TODO - color, contrast, gamma, randomShadow, rain
        return transform

    def _get_img_and_mask_crop_tranform(self):
        """Deterministic transform: only the fixed tile crop (used for validation/test)."""
        return A.Compose([
            self._get_tile_crop(),
        ])

BATCH_SIZE = 6

train_dataset = CornFieldDamageDataset(img_file_paths=tile_paths_train.img_paths, mask_file_paths=tile_paths_train.mask_paths)
valid_dataset = CornFieldDamageDataset(img_file_paths=tile_paths_valid.img_paths, mask_file_paths=tile_paths_valid.mask_paths, augment=False)
test_dataset = CornFieldDamageDataset(img_file_paths=tile_paths_test.img_paths, mask_file_paths=tile_paths_test.mask_paths, augment=False)

# Only training data needs re-shuffling every epoch. The valid/test path lists were already
# shuffled when they were built. A fixed loader order keeps the batches identical from epoch to
# epoch, so validation scores (used to pick the best model) are comparable between epochs.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
train_batches_per_epoch = len(train_loader)  # number of training batches per epoch
train_batches_per_epoch

In [None]:
# image, mask = train_dataset[222] # get some sample
# plt.imshow(mask[0, :, :])
# plt.show()
# plt.imshow(mask[1, :, :])
# plt.show()
# plt.imshow(mask[2, :, :])
# plt.show()
# plt.imshow(image.transpose(1, 2, 0))

In [None]:
# U-Net with a ResNet-34 encoder initialised from ImageNet weights. Input: 3-channel RGB tiles.
# Output: one channel per segmentation class. 'softmax2d' applies a softmax across the class
# channels at every pixel, so the model already returns per-class probabilities.
unet_kwargs = dict(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=NUMBER_OF_SEGMENTATION_CLASSES,
    activation='softmax2d',
)
model = smp.Unet(**unet_kwargs)

print(model)


In [None]:
# Dice loss on the (already softmaxed) model outputs. A considered alternative was
# nn.CrossEntropyLoss, where class imbalance is usually handled with per-class loss weights.
loss = smp.utils.losses.DiceLoss()

n_classes = 3
overall_iou = smp.utils.metrics.IoU(threshold=0.5, name='IoU')
# IoU of one class at a time: every other channel is ignored.
per_class_iou = [
    smp.utils.metrics.IoU(
        threshold=0.5,
        ignore_channels=[other for other in range(n_classes) if other != class_idx],
        name=f'IoU-{class_idx}',
    )
    for class_idx in range(n_classes)
]
# Remaining metrics skip channel 2.
channel_2_excluded = [
    metric_cls(threshold=0.5, ignore_channels=[2])
    for metric_cls in (
        smp.utils.metrics.Fscore,
        smp.utils.metrics.Accuracy,
        smp.utils.metrics.Recall,
        smp.utils.metrics.Precision,
    )
]
metrics = [overall_iou] + per_class_iou + channel_2_excluded


# Single parameter group over the whole model (alternatives tried: SGD with momentum, lr=1e-4;
# a note says 3e-6 took ~80 epochs on the 768-stride tiles).
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=(0.000005 if IS_BIG_STRIDE else 0.00001),
)


train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# Give the per-class IoU metrics distinct log keys (metrics[1..3]).
for epoch_runner in (valid_epoch, train_epoch):
    for class_idx in range(n_classes):
        epoch_runner.metrics[1 + class_idx].__name__ = f"IoU_Class{class_idx}"

In [None]:
# Training-loop state: best validation IoU seen so far, the model snapshot that achieved it,
# and the per-epoch train/valid logs. Re-running this cell resets the history.
max_score, best_model = 0, None
train_logs_vec, valid_logs_vec = [], []

In [None]:
number_of_epochs = 55 if IS_BIG_STRIDE else 4
epoch_to_decrease_learning_rate = 40 if IS_BIG_STRIDE else 2

# Train, validate after every epoch, keep the best model by validation IoU,
# and halve the learning rate once after `epoch_to_decrease_learning_rate`.
for i in range(number_of_epochs):
    print(f'\nEpoch: {i}')
    train_logs_vec.append(train_epoch.run(train_loader))
    valid_logs = valid_epoch.run(valid_loader)
    valid_logs_vec.append(valid_logs)

    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        # Copy while the model sits on the CPU so the snapshot lives in host memory,
        # then move the training model back.
        model.to(CPU_DEVICE)
        best_model = copy.deepcopy(model)
        model.to(DEVICE)

    if i == epoch_to_decrease_learning_rate:
        # The optimizer's parameter group(s) cover the whole model, not just the decoder.
        for param_group in optimizer.param_groups:
            param_group['lr'] /= 2
        print(f"Decrease learning rate to {optimizer.param_groups[0]['lr']}")

In [None]:
# One figure per logged quantity: train vs. validation value for every epoch.
for metric_name in valid_logs_vec[0]:
    fig, ax = plt.subplots()
    ax.plot([epoch_logs[metric_name] for epoch_logs in train_logs_vec])
    ax.plot([epoch_logs[metric_name] for epoch_logs in valid_logs_vec])

    ax.legend(['train', 'valid'])
    ax.set_xlabel('epoch')
    ax.set_ylabel(metric_name)
    ax.grid()
    # ax.set_yscale('log')  # uncomment for log-scale curves
    plt.show()

In [None]:
latest_valid_logs = valid_logs_vec[-1]  # metrics from the most recent validation epoch
latest_valid_logs

In [None]:
# Evaluate the best checkpoint (by validation IoU) on the held-out test tiles.
# best_model stays None when no epoch ever beat the initial max_score (e.g. IoU stayed 0);
# fall back to the current model instead of crashing on None.to(...).
if best_model is not None:
    model = best_model
else:
    print("No best model recorded; evaluating the current model")
model.to(DEVICE)

test_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

test_epoch.run(test_loader)

In [None]:
# Save CPU weights next to the tiles they were trained on. The file name records that
# kukurydza_10_ha was left out of training; update it when the training directories change.
MODEL_FILE_NAME = 'model_cpu__trained_without_10ha'
model = model.to(CPU_DEVICE)
torch.save(model.state_dict(), os.path.join(TILES_BASE_DIR, MODEL_FILE_NAME))

In [None]:
# For a few batches, plot one row per tile: RGB input, ground-truth damage (channel 1),
# predicted damage, their difference, and predicted healthy field (channel 0).
NUMBER_OF_BATCHES_TO_PLOT = 8  # increase to get more images
batches_to_plot = test_loader
# batches_to_plot = train_loader

model.eval()  # explicit, rather than relying on an earlier cell having left the model in eval mode
model_device = next(model.parameters()).device  # feed inputs wherever the model currently lives

columns = 5
# zip with range stops after NUMBER_OF_BATCHES_TO_PLOT batches, or earlier if the loader is
# exhausted (a bare next() would raise StopIteration on a small test set).
for _, (img_batch, mask_batch) in zip(range(NUMBER_OF_BATCHES_TO_PLOT), batches_to_plot):
    with torch.no_grad():
        model_output = model(img_batch.to(model_device)).cpu()

    rows = len(img_batch)
    fig = plt.figure(figsize=(columns * 4, rows * 4))

    for row in range(rows):
        first_cell = 1 + row * columns
        true_damage = mask_batch[row][1]
        predicted_damage = model_output[row][1]

        fig.add_subplot(rows, columns, first_cell)
        plt.imshow(img_batch[row].numpy().transpose([1, 2, 0]))  # CHW -> HWC
        plt.axis('off')
        plt.title('img')

        fig.add_subplot(rows, columns, first_cell + 1)
        plt.imshow(true_damage.numpy())
        plt.axis('off')
        plt.title('original damage mask')

        fig.add_subplot(rows, columns, first_cell + 2)
        plt.imshow(predicted_damage)
        plt.axis('off')
        plt.title('prediction damage')

        fig.add_subplot(rows, columns, first_cell + 3)
        cax = plt.imshow(predicted_damage - true_damage, vmin=-1.1, vmax=1.1)
        plt.title('damage diff (predict-gt)')
        plt.axis('off')
        cbar = fig.colorbar(cax, ticks=[-1, 0, 1])
        cbar.ax.set_yticklabels(['false negative', 'true', 'false positive'])

        fig.add_subplot(rows, columns, first_cell + 4)
        plt.imshow(model_output[row][0])
        plt.title('prediction healthy field')
        plt.axis('off')

    plt.show()

In [None]:
device = 'cpu'  
# device = DEVICE

PREDICTION_THRESHOLD = 0.5  # a pixel is assigned to a class when its probability exceeds this

model = model.to(device)
model.eval()  # explicit, rather than relying on an earlier cell having left the model in eval mode
number_of_batches = len(test_loader)


def _ratio_or_nan(numerator, denominator):
    """numerator / denominator, or NaN when the denominator is zero (e.g. no damage predicted)."""
    return numerator / denominator if denominator else float('nan')


healthy_field_ground_truth_pix = 0
damage_ground_truth_pix = 0

healthy_field_predicted_pix = 0
damage_field_predicted_pix = 0

damage_prediction_true_positives_pix = 0

healthy_intersection_pix = 0
healthy_union_pix = 0

damage_intersection_pix = 0
damage_union_pix = 0


for batch_idx, (img_batch, mask_batch) in enumerate(test_loader):
    print(f'Batch {batch_idx} / {number_of_batches}')
    with torch.no_grad():
        model_output = model(img_batch.to(device)).to(CPU_DEVICE)

    # Boolean (N, H, W) masks for the whole batch. Pixel counts are additive, so counting per
    # batch gives the same totals as counting sample by sample. Ground truth is one-hot 0/1.
    ground_truth_healthy_field = mask_batch[:, 0].numpy() > 0.5
    ground_truth_damage = mask_batch[:, 1].numpy() > 0.5
    predicted_healthy_field = model_output[:, 0].numpy() > PREDICTION_THRESHOLD
    predicted_damage = model_output[:, 1].numpy() > PREDICTION_THRESHOLD

    healthy_field_ground_truth_pix += np.count_nonzero(ground_truth_healthy_field)
    damage_ground_truth_pix += np.count_nonzero(ground_truth_damage)

    healthy_field_predicted_pix += np.count_nonzero(predicted_healthy_field)
    damage_field_predicted_pix += np.count_nonzero(predicted_damage)

    common_damage_pix = np.count_nonzero(ground_truth_damage & predicted_damage)
    damage_prediction_true_positives_pix += common_damage_pix
    damage_intersection_pix += common_damage_pix
    healthy_intersection_pix += np.count_nonzero(ground_truth_healthy_field & predicted_healthy_field)
    damage_union_pix += np.count_nonzero(ground_truth_damage | predicted_damage)
    healthy_union_pix += np.count_nonzero(ground_truth_healthy_field | predicted_healthy_field)


# Shares below are relative to healthy + damage pixels only (class channel 2 is not counted).
total_ground_truth_pix = healthy_field_ground_truth_pix + damage_ground_truth_pix
total_predicted_pix = healthy_field_predicted_pix + damage_field_predicted_pix

iou_damage = _ratio_or_nan(damage_intersection_pix, damage_union_pix)
iou_healthy = _ratio_or_nan(healthy_intersection_pix, healthy_union_pix)

print(f'healthy_field_ground_truth = {_ratio_or_nan(healthy_field_ground_truth_pix, total_ground_truth_pix) * 100:.2f} %')
print(f'damage_ground_truth = {_ratio_or_nan(damage_ground_truth_pix, total_ground_truth_pix) * 100:.2f} %')

print(f'healthy_field_predicted = {_ratio_or_nan(healthy_field_predicted_pix, total_predicted_pix) * 100:.2f} %')
print(f'damage_field_predicted = {_ratio_or_nan(damage_field_predicted_pix, total_predicted_pix) * 100:.2f} %')

print(f'damage_prediction_true_positives/damage_field_predicted = {_ratio_or_nan(damage_prediction_true_positives_pix, damage_field_predicted_pix) * 100:.2f} %')

print(f'iou_damage = {iou_damage:.3f}')
print(f'iou_healthy = {iou_healthy:.3f}')


