<a href="https://colab.research.google.com/github/Vlad-Ozik/rust_segmentation/blob/main/rust_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install segmentation_models_pytorch (imported later as `smp`).
# NOTE(review): no version is pinned, so a fresh run installs whatever release is current.
!pip install segmentation-models-pytorch

In [None]:
# Install/upgrade albumentations straight from its git repository, bypassing pip's cache.
# NOTE(review): this tracks the repository head rather than a pinned release.
!pip install -U git+https://github.com/albu/albumentations --no-cache-dir

In [None]:
# Mount Google Drive at /content/gdrive (Colab only).
from google.colab import drive

drive.mount('/content/gdrive')

In [None]:
import os
from typing import Union, List
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt

In [None]:
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword argument becomes one panel; the keyword name, with
    underscores turned into spaces and title-cased, is used as its title.
    """
    n = len(images)
    plt.figure(figsize=(16, 5))
    for position, (name, image) in enumerate(images.items(), start=1):
        ax = plt.subplot(1, n, position)
        # Hide tick marks: the panels show pictures, not data axes
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(name.replace('_', ' ').title())
        ax.imshow(image)
    plt.show()

# **Prepare data**

# Create new folders for split images

In [None]:
# Root folder of the rust dataset on the mounted Drive.
# The path is relative, so it resolves against the current working directory.
# The trailing slash matters: later code builds paths with DATA_DIR+'image/'.
DATA_DIR = 'gdrive/MyDrive/test_task/rust_dataset/'

In [None]:
# Image ("x") and annotation ("y") folders of every subset, all located under DATA_DIR
x_train_dir, y_train_dir = (os.path.join(DATA_DIR, sub) for sub in ('train', 'trainannot'))
x_valid_dir, y_valid_dir = (os.path.join(DATA_DIR, sub) for sub in ('val', 'valannot'))
x_test_dir, y_test_dir = (os.path.join(DATA_DIR, sub) for sub in ('test', 'testannot'))

In [None]:
# Create any subset folder that does not exist yet.
# os.makedirs is used instead of shelling out to `mkdir` so that paths with
# spaces work, failures raise instead of being silently ignored, and a
# partially created folder set is completed on re-run.
missing_dirs = [
    directory
    for directory in (x_train_dir, y_train_dir,
                      x_valid_dir, y_valid_dir,
                      x_test_dir, y_test_dir)
    if not os.path.isdir(directory)
]
if missing_dirs:
    print('Prepare folders...')
    for directory in missing_dirs:
        os.makedirs(directory, exist_ok=True)
    print('Done!')

# Splitting

In [None]:
def split(img: np.ndarray, mask: np.ndarray, split_HW: List[int]) -> tuple:
    """Cut an image and its mask into non-overlapping tiles, keeping annotated ones.

    The image is walked row by row in steps of ``split_HW``; any remainder at
    the right/bottom edge that does not fill a whole tile is dropped. A tile is
    kept only when its mask crop contains the value 255.

    Args:
        img: image array whose first two axes are (height, width).
        mask: mask array indexed the same way as ``img``.
        split_HW: ``[tile_height, tile_width]`` in pixels.

    Returns:
        ``(images_crop, masks_crop)``: two lists of equal length holding the
        kept image tiles and their matching mask tiles, in row-major order.
    """
    tile_h, tile_w = split_HW[0], split_HW[1]
    img_height, img_width = img.shape[:2]
    images_crop = []
    masks_crop = []
    for row in range(img_height // tile_h):
        # Offsets are derived from the loop indices *before* cropping, so each
        # iteration looks at its own tile (not the previous iteration's).
        y = row * tile_h
        for col in range(img_width // tile_w):
            x = col * tile_w
            mask_crop = mask[y:y + tile_h, x:x + tile_w]
            # keep the tile only if the mask marks something inside it
            if mask_crop.max() == 255:
                images_crop.append(img[y:y + tile_h, x:x + tile_w])
                masks_crop.append(mask_crop)
    return images_crop, masks_crop

In [None]:
# Tile every source image and distribute the tiles over the subset folders.
# Source images are expected as "<name>.JPG" in DATA_DIR/image with a matching
# "<name>.png" mask in DATA_DIR/mask.
TILE_HW = [600, 800]  # tile height and width in pixels
folders = [['train', 'trainannot'],
           ['val', 'valannot'],
           ['test', 'testannot']]

try:
    # splitext keeps names that contain extra dots intact; sorting makes the
    # subset assignment identical on every run (os.listdir order is arbitrary).
    all_names = sorted(os.path.splitext(name)[0]
                       for name in os.listdir(os.path.join(DATA_DIR, 'image')))

    # split images 80% train, 5% val, 15% test (same order as `folders`)
    train_end = int(.8 * len(all_names))
    val_end = int(.85 * len(all_names))
    image_names = [all_names[:train_end],
                   all_names[train_end:val_end],
                   all_names[val_end:]]

    for (images_folder, masks_folder), imgs_name in zip(folders, image_names):
        for im_name in imgs_name:
            print(im_name)
            image_path = os.path.join(DATA_DIR, 'image', im_name + ".JPG")
            mask_path = os.path.join(DATA_DIR, 'mask', im_name + ".png")

            # cv2.imread signals a missing/unreadable file by returning None,
            # so turn that into the exception handled below.
            img = cv2.imread(image_path)
            if img is None:
                raise FileNotFoundError(image_path)
            full_mask = cv2.imread(mask_path)
            if full_mask is None:
                raise FileNotFoundError(mask_path)

            images, masks = split(img, full_mask, TILE_HW)
            for j, (image_crop, mask_crop) in enumerate(zip(images, masks)):
                cv2.imwrite(os.path.join(DATA_DIR, images_folder, im_name + f"_{j}.png"), image_crop)
                cv2.imwrite(os.path.join(DATA_DIR, masks_folder, im_name + f"_{j}_mask.png"), mask_crop)

except FileNotFoundError as e:
    print(f"Image Not Found! {e}")
else:
    # Only report success when every image was actually processed
    print("Done!")

# DataLoader and augmentations

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

In [None]:
class Dataset(BaseDataset):
    """Rust segmentation dataset. Read images, apply augmentation and preprocessing transformations.

    Every file found in ``images_dir`` is one sample. Its mask is read from
    ``masks_dir`` under the same file name with ``.png`` replaced by
    ``_mask.png``. Masks are binarised to {0, 1} and returned with a trailing
    channel axis.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    Raises:
        FileNotFoundError: from ``__getitem__`` when an image or its mask
            cannot be read.
    """

    def __init__(
            self,
            images_dir,
            masks_dir,
            augmentation=None,
            preprocessing=None,
    ):
        # Sorted so that a given index refers to the same file on every run
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    @staticmethod
    def _mask_path(path):
        """Return the mask file path for ``path``; only the file name is rewritten."""
        directory, filename = os.path.split(path)
        return os.path.join(directory, filename.replace('.png', '_mask.png'))

    def __getitem__(self, i):

        # read data (cv2.imread returns None instead of raising on failure)
        image_path = self.images_fps[i]
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_path = self._mask_path(self.masks_fps[i])
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Binarise the mask: blur, threshold everything above 3 to 255, then
        # open/close with a 5x5 kernel.
        # NOTE(review): presumably this removes speckles and fills small holes
        # in the annotation - confirm against the raw masks.
        gb_mask = cv2.GaussianBlur(mask, (3, 3), 7)
        thrash = cv2.threshold(gb_mask, 3, 255, cv2.THRESH_BINARY)[1]
        kernel = np.ones((5, 5), np.uint8)
        thrash = cv2.morphologyEx(thrash, cv2.MORPH_OPEN, kernel)
        thrash = cv2.morphologyEx(thrash, cv2.MORPH_CLOSE, kernel)
        thrash[thrash > 0] = 1
        # add a channel axis: (H, W) -> (H, W, 1)
        mask = thrash[:, :, np.newaxis].astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.ids)

In [None]:
# Preview one raw validation sample (no augmentation, no preprocessing)
dataset = Dataset(x_valid_dir, y_valid_dir)

image, mask = dataset[1] # get some sample

# The keyword name becomes the panel title, so label the mask as rust
visualize(
    image=image,
    rust_mask=mask.squeeze(),
)

In [None]:
import albumentations as albu


In [None]:
def get_training_augmentation():
    """Build the random augmentation pipeline used for training samples.

    The pipeline flips, shifts/scales, pads to at least 320x320, takes a
    320x320 crop that contains mask pixels when any exist, then applies
    noise, perspective and colour/sharpness perturbations.

    The imgaug-backed ``IAA*`` transforms and ``RandomBrightness`` /
    ``RandomContrast`` are not available in current albumentations releases,
    so their native replacements are used instead. ``always_apply=True`` is
    replaced by the equivalent ``p=1``.

    Return:
        transform: albumentations.Compose
    """
    train_transform = [

        albu.HorizontalFlip(p=0.5),

        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),

        albu.PadIfNeeded(min_height=320, min_width=320, p=1, border_mode=0),
        albu.CropNonEmptyMaskIfExists(height=320, width=320, p=1),

        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.5),

        albu.OneOf(
            [
                albu.CLAHE(p=1),
                # brightness only (contrast left untouched)
                albu.RandomBrightnessContrast(contrast_limit=0, p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.Sharpen(p=1),
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                # contrast only (brightness left untouched)
                albu.RandomBrightnessContrast(brightness_limit=0, p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return albu.Compose(train_transform)


def to_tensor(x, **kwargs):
    """Reorder an (H, W, C) array to (C, H, W) and cast it to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype(np.float32)


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied after augmentation.

    Args:
        preprocessing_fn (callable): data normalization function applied to
            the image (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose that first runs ``preprocessing_fn``
            on the image and then ``to_tensor`` on both image and mask

    """
    normalize_image = albu.Lambda(image=preprocessing_fn)
    channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_image, channels_first])

In [None]:
def get_validation_augmentation():
    """Resize every sample to 384x480 (both sides divisible by 32).

    No padding or random transform is applied; the resize always runs
    (``p=1`` replaces the deprecated ``always_apply=True``).

    Return:
        transform: albumentations.Compose
    """
    test_transform = [
        albu.Resize(384, 480, p=1),
    ]
    return albu.Compose(test_transform)

In [None]:
# Validation images wrapped with the *training* augmentation, for visual inspection only
augmented_dataset = Dataset(
    x_valid_dir, 
    y_valid_dir, 
    augmentation=get_training_augmentation(), 
)

# first three samples, each shown after one random pass through the augmentation
for i in range(3):
    image, mask = augmented_dataset[i]
    visualize(image=image, mask=mask.squeeze())

# **Training**

In [None]:
import torch
import numpy as np
import segmentation_models_pytorch as smp

In [None]:
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['rust',]
ACTIVATION = 'sigmoid'
# Fall back to the CPU when no GPU is available instead of failing later
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# FPN segmentation model with one output channel per entry in CLASSES
model = smp.FPN(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

# Input normalisation matching the chosen encoder and its pretrained weights
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Training samples: random augmentation followed by encoder preprocessing
train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn)
)

# Validation samples: deterministic resize followed by encoder preprocessing
valid_dataset = Dataset(
    x_valid_dir,
    y_valid_dir,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn)
)

# Never request more loader workers than the machine has CPUs
# (os.cpu_count() may return None, hence the fallback to 1).
TRAIN_NUM_WORKERS = min(12, os.cpu_count() or 1)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=TRAIN_NUM_WORKERS)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=1)

In [None]:
# Optimise the Dice loss; report IoU computed with a 0.5 threshold.
loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

# RAdam over all model parameters as a single parameter group (lr = 1e-4).
optimizer = torch.optim.RAdam([ 
    dict(params=model.parameters(), lr=0.0001),
])

In [None]:
# Both epoch runners share the same loss, metrics, device and verbosity;
# only the training runner is given the optimizer.
_epoch_kwargs = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_epoch_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **_epoch_kwargs)

In [None]:
# Sanity check: display the mask shape of one fully preprocessed training sample
image, mask = train_dataset[1]
mask.shape

In [None]:
NUM_EPOCHS = 40        # total number of training epochs
LR_DECAY_EPOCH = 25    # the learning rate is lowered once this epoch has finished
DECAYED_LR = 1e-5      # learning rate used after LR_DECAY_EPOCH

max_score = 0

for i in range(NUM_EPOCHS):

    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # keep the checkpoint with the best validation IoU seen so far
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model_rust_seg.pth')
        print('Model saved!')

    if i == LR_DECAY_EPOCH:
        # The optimizer holds the whole model, so this lowers the learning
        # rate for every parameter, not just the decoder.
        for param_group in optimizer.param_groups:
            param_group['lr'] = DECAYED_LR
        print('Decrease learning rate to 1e-5!')

# **Test best saved model**

In [None]:
# Load the checkpoint from the same relative path the training loop saved it to.
# The file holds a whole pickled model (not a state_dict), which recent PyTorch
# only unpickles with weights_only=False. Only load checkpoints you created
# yourself: unpickling can execute arbitrary code.
# NOTE(review): the weights_only keyword needs a reasonably recent PyTorch - confirm
# against the runtime's version.
best_model = torch.load('./best_model_rust_seg.pth', map_location=DEVICE, weights_only=False)

In [None]:
# Test samples go through the same resize + preprocessing as the validation set
test_dataset = Dataset(
    x_test_dir, 
    y_test_dir, 
    augmentation=get_validation_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
)


In [None]:
# Loader over the test set using DataLoader's default settings
test_dataloader = DataLoader(test_dataset)

In [None]:
# Evaluate the reloaded best checkpoint on the test set with the training loss and metrics
test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
)

# One pass over the test loader; the returned logs are kept in `logs`
logs = test_epoch.run(test_dataloader)

In [None]:
# Same test files without augmentation or preprocessing, used only for display
test_dataset_vis = Dataset(
    x_test_dir, y_test_dir, 
)

In [None]:
# Show up to five distinct, randomly chosen test samples with their predictions.
# Sampling without replacement avoids showing the same sample twice.
n_previews = min(5, len(test_dataset))
for n in np.random.choice(len(test_dataset), size=n_previews, replace=False):
    n = int(n)

    # un-preprocessed image for display; preprocessed image and mask for the model
    image_vis = test_dataset_vis[n][0].astype('uint8')
    image, gt_mask = test_dataset[n]

    gt_mask = gt_mask.squeeze()

    # add a batch axis, predict, then round the output to a 0/1 mask
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = (pr_mask.squeeze().cpu().numpy().round())

    visualize(
        image=image_vis,
        ground_truth_mask=gt_mask,
        predicted_mask=pr_mask
    )