In [None]:
import os

import cv2
import torch
import albumentations as A
from matplotlib import pyplot as plt
import segmentation_models_pytorch as smp
import numpy as np
from torch import nn, optim


In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
# DEVICE = 'cpu'  # uncomment to force CPU execution
DEVICE

In [None]:
# Stretch the notebook container to the full browser width so figures are not cramped.
# `display` and `HTML` come from the public `IPython.display` module: importing them
# from `IPython.core.display` is deprecated and no longer works on recent IPython releases.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Directory that contains the tile sub-directories listed in SUBDIRECTORIES_TO_PROCESS.
TILES_BASE_DIR = "/media/data/local/corn"
SUBDIRECTORIES_TO_PROCESS = ["kukurydza_5_ha"]

# Tile geometry, in pixels. Tiles are stored at the uncropped size; the augmentation
# pipeline crops CROP_TILE_MARGIN pixels from every side to get the cropped size.
CROPPED_TILE_SIZE = 512
UNCROPPED_TILE_SIZE = CROPPED_TILE_SIZE + 256
CROP_TILE_MARGIN = (UNCROPPED_TILE_SIZE - CROPPED_TILE_SIZE) // 2

In [None]:
# Build two parallel lists with the paths of every image tile and its mask tile.
# A tile's prefix is its file name up to the last underscore; files are paired as
# `<prefix>_img.png` / `<prefix>_mask.png`. An exception is raised if any prefix
# has only an image or only a mask.
tiles_img_paths = []
tiles_mask_paths = []


for dir_name in SUBDIRECTORIES_TO_PROCESS:
    dir_path = os.path.join(TILES_BASE_DIR, dir_name)
    file_names = [f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]

    mask_files_prefixes = {f[:f.rfind('_')] for f in file_names if 'mask' in f}
    img_files_prefixes = {f[:f.rfind('_')] for f in file_names if 'img' in f}
    common_files_prefixes = mask_files_prefixes.intersection(img_files_prefixes)
    all_files_prefixes = mask_files_prefixes.union(img_files_prefixes)
    missing_files_prefixes = all_files_prefixes - common_files_prefixes

    if missing_files_prefixes:
        raise Exception(f"Some files don't have correponding pair in mask/image: {missing_files_prefixes}")

    # Iterate in sorted order: the iteration order of a set of strings changes between
    # kernel restarts (hash randomisation), which would silently change which tiles end
    # up in the train and validation parts of the positional split done further down.
    for file_prefix in sorted(common_files_prefixes):
        img_file_name = file_prefix + '_img.png'
        mask_file_name = file_prefix + '_mask.png'
        tiles_img_paths.append(os.path.join(dir_path, img_file_name))
        tiles_mask_paths.append(os.path.join(dir_path, mask_file_name))

print(f'Number of tiles = {len(tiles_img_paths)}')

In [None]:
# img = cv2.imread(tiles_img_paths[433])
# mask = cv2.imread(tiles_mask_paths[433], cv2.IMREAD_GRAYSCALE)
# plt.imshow(img)
# plt.show()
# plt.imshow(mask)


In [None]:
# transform = A.Compose([
#     A.HorizontalFlip(p=0.5),
#     A.VerticalFlip(p=0.5),
#     A.RandomScale(scale_limit=0.15),  # above scale 0.16 images are too small
#     A.Rotate(limit=90),  # degrees
#     A.Crop(x_min=CROP_TILE_MARGIN, y_min=CROP_TILE_MARGIN, x_max=UNCROPPED_TILE_SIZE-CROP_TILE_MARGIN, y_max=UNCROPPED_TILE_SIZE-CROP_TILE_MARGIN),
# ])


# transformed = transform(image=img, mask=mask)
# image_transformed, mask_transformed = transformed['image'], transformed['mask']


# unique, counts = np.unique(mask, return_counts=True)
# dict(zip(unique, counts))

In [None]:
# img.shape
# img.transpose(2, 0, 1).shape

In [None]:
# Gray levels that identify the segmentation classes in the mask images. The position
# of a value in this list is the channel index of that class in the one-hot mask.
SEGMENTATION_CLASS_VALUES = [
    0,
    255,
    127,
]
NUMBER_OF_SEGMENTATION_CLASSES = len(SEGMENTATION_CLASS_VALUES)

In [None]:
class CornFieldDamageDataset(torch.utils.data.Dataset):
    """Dataset of (image, one-hot mask) tile pairs loaded from disk with augmentation.

    Each item is a tuple ``(image, mask)`` of ``float32`` numpy arrays:

    * ``image`` -- RGB tile scaled to ``[0, 1]``, channel-first ``(3, H, W)``.
    * ``mask`` -- one binary plane per entry of ``SEGMENTATION_CLASS_VALUES``,
      channel-first ``(num_classes, H, W)``.

    NOTE(review): the random augmentation pipeline is applied to every dataset
    instance, including the one used for validation -- confirm this is intended.
    """

    def __init__(self, img_file_paths, mask_file_paths):
        """Store the parallel path lists and build the augmentation pipeline.

        Args:
            img_file_paths: paths of the image tiles.
            mask_file_paths: paths of the mask tiles, in the same order.
        """
        self.img_file_paths = img_file_paths
        self.mask_file_paths = mask_file_paths
        assert(len(self.img_file_paths) == len(mask_file_paths))
        self._img_and_mask_transform = self._get_img_and_mask_tranform()  # augmentation transform

    def __len__(self):
        """Return the number of tile pairs in this dataset."""
        return len(self.mask_file_paths)

    def __getitem__(self, idx):
        """Load, augment and encode the tile pair at position ``idx``.

        Raises:
            FileNotFoundError: if the image or the mask file cannot be read.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Read from this instance's own path lists. Indexing the notebook-level lists
        # instead would make every dataset (e.g. the validation split) serve the first
        # tiles of the full list, regardless of the paths it was constructed with.
        img_path = self.img_file_paths[idx]
        mask_path = self.mask_file_paths[idx]

        image = cv2.imread(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image tile: {img_path}")
        if mask is None:
            raise FileNotFoundError(f"Could not read mask tile: {mask_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # not really needed I guess

        # Image and mask go through the pipeline together so they stay aligned.
        transformed = self._img_and_mask_transform(image=image, mask=mask)
        image, mask = transformed['image'], transformed['mask']

        # One binary plane per class value, stacked channel-first.
        masks = [(mask == v) for v in SEGMENTATION_CLASS_VALUES]
        mask_stacked = np.stack(masks, axis=0).astype('float')

        # Scale to [0, 1] and move channels first: (H, W, C) -> (C, H, W).
        image = image.astype('float')
        image /= 255
        image = image.transpose(2, 0, 1)

        return image.astype('float32'), mask_stacked.astype('float32')

    def _get_img_and_mask_tranform(self):
        """Build the augmentation pipeline: random flips, scale and rotation, then a fixed crop."""
        # Declare an augmentation pipeline
        transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomScale(scale_limit=0.15),  # above scale 0.16 images are too small
            A.Rotate(limit=90),  # degrees
            # Fixed window that removes CROP_TILE_MARGIN pixels from each side of an uncropped tile.
            A.Crop(x_min=CROP_TILE_MARGIN, y_min=CROP_TILE_MARGIN, x_max=UNCROPPED_TILE_SIZE-CROP_TILE_MARGIN, y_max=UNCROPPED_TILE_SIZE-CROP_TILE_MARGIN),
        ])
        # TODO - color, contrast, gamma, randomShadow, rain
        return transform


# Positional 90/10 split: the first 90% of the tile list is used for training,
# the remaining tiles for validation.
dataset_split_point = int(0.9 * len(tiles_img_paths))
train_dataset = CornFieldDamageDataset(
    img_file_paths=tiles_img_paths[:dataset_split_point],
    mask_file_paths=tiles_mask_paths[:dataset_split_point],
)
valid_dataset = CornFieldDamageDataset(
    img_file_paths=tiles_img_paths[dataset_split_point:],
    mask_file_paths=tiles_mask_paths[dataset_split_point:],
)

# Batches of 8 tiles; only the training loader shuffles its samples.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=8,
    shuffle=False,
)

In [None]:
# Preview one (randomly augmented) training sample: every class plane of the mask,
# followed by the image itself.
image, mask = train_dataset[222] # get some sample
# The mask is channel-first (class, height, width), so iterating over it yields one
# full class plane at a time. Slicing `mask[:, :, k]` would instead pick a single
# pixel column out of every class plane.
for class_mask in mask:
    plt.imshow(class_mask)
    plt.show()
# The image is channel-first as well; matplotlib expects (height, width, channel).
plt.imshow(image.transpose(1, 2, 0))

In [None]:
# U-Net with a ResNet-34 encoder initialised from ImageNet weights.
# Input: 3-channel (RGB) images. Output: one channel per segmentation class,
# passed through a 'softmax2d' activation.
# Other encoders can be selected by name, e.g. mobilenet_v2 or efficientnet-b7.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=NUMBER_OF_SEGMENTATION_CLASSES,
    activation='softmax2d',
)

print(model)


In [None]:
# Training objective (Dice loss) and the metric reported next to it (IoU at threshold 0.5).
# criterion = nn.CrossEntropyLoss()
loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# Adam over all model parameters, as a single parameter group with lr = 1e-4.
# optimizer = optim.SGD(model_fnn.parameters(), lr=0.0001, momentum=0.9)
optimizer = torch.optim.Adam([
    {'params': model.parameters(), 'lr': 0.0001},
])

# Epoch runners: each `run(loader)` call iterates once over the given dataloader.
# The training runner is given the optimizer; the validation runner is not.
train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)
valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

In [None]:
# for xb, yb in valid_loader: 
#     print(xb.shape, yb.shape)

In [None]:
# Train for 40 epochs, running one training pass and one validation pass per epoch
# and recording the IoU score and Dice loss of both for plotting afterwards.
max_score = 0
train_scores, valid_scores = [], []
train_dice_losses, valid_dice_losses = [], []

for i in range(40):
    print(f'\nEpoch: {i}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    train_scores.append(train_logs['iou_score'])
    train_dice_losses.append(train_logs['dice_loss'])
    valid_scores.append(valid_logs['iou_score'])
    valid_dice_losses.append(valid_logs['dice_loss'])

    # Remember the best validation IoU seen so far (checkpointing is currently disabled).
    if valid_logs['iou_score'] > max_score:
        max_score = valid_logs['iou_score']
#         torch.save(model, './best_model.pth')
#         print('Model saved!')

    # After epoch 25, lower the learning rate of the single parameter group.
    if i == 25:
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease decoder learning rate to 1e-5!')

In [None]:
# Plot the per-epoch Dice loss of the training and validation passes on a log scale.
# (Swap in the two commented lines to plot the IoU scores instead.)
# plt.plot(train_scores)
# plt.plot(valid_scores)

plt.plot(train_dice_losses)
plt.plot(valid_dice_losses)

plt.legend(['train', 'valid'])
plt.xlabel('epoch')
# The curves are Dice losses, so label the axis accordingly (it used to say 'iou').
plt.ylabel('dice loss')
plt.grid()
ax = plt.gca()
ax.set_yscale('log')

In [None]:
# Move the trained model to the CPU for the visual inspection below.
# `Module.to` moves the module in place and returns it, so `model_cpu` is the same
# object as `model`, not a copy. `.eval()` (which also returns the module) makes the
# inference mode explicit instead of depending on the state the last epoch left behind.
model_cpu = model.to('cpu').eval()

In [None]:
# Iterator over the validation batches; each `next(vi)` call yields the following batch.
# Re-run this cell to start again from the first batch.
vi = iter(valid_loader)


In [None]:
# Pull the next validation batch from the iterator.
# NOTE(review): the prediction cell below calls `next(vi)` again before using the
# batch, so the batch fetched here is skipped -- presumably this cell is only run to
# step past batches; confirm.
img_batch, mask_batch = next(vi)


In [None]:
# Take the next validation batch and predict it with the CPU model, without gradients.
img_batch, mask_batch = next(vi)

with torch.no_grad():
    # model_output = model(img_batch.to(DEVICE))
    model_output = model_cpu(img_batch)

# For every sample in the batch show the input image, then channel 1 of the
# target mask and channel 1 of the model output.
for i, tile in enumerate(img_batch):
    plt.imshow(tile.numpy().transpose([1, 2, 0]))
    plt.show()

    print('ground through')
    plt.imshow(mask_batch[i, 1].numpy())
    plt.show()

    print('prediction')
    plt.imshow(model_output[i, 1])
    plt.show()

    print("="*30)
