# <center>⭐ Config ⭐</center>

In [None]:
!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
import segmentation_models_pytorch as smp

In [None]:
import torch
import torch.nn as nn
import albumentations as A
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.cuda.amp import GradScaler, autocast

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import random
import sys
import cv2

from tqdm.notebook import tqdm

In [None]:
# Directories holding the training images and their label masks.
train_data_path = '/kaggle/input/contrail-data-torch/train'
train_label_path = '/kaggle/input/contrail-data-torch/train_labels'

# Directories holding the validation images and their label masks.
val_data_path = '/kaggle/input/contrail-data-torch/val'
val_label_path = '/kaggle/input/contrail-data-torch/val_labels'

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def set_seed(seed=42):
    """Seed every random number generator the notebook relies on.

    Sets ``PYTHONHASHSEED`` in ``os.environ`` and seeds Python's ``random``,
    NumPy, and PyTorch (CPU, current CUDA device and all CUDA devices).

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


# Seed once at import time with the default seed.
set_seed()

In [None]:
def visualize(**images):
    """Show an image, its mask, and the mask overlaid on the image in one row.

    Args:
        **images: Must contain the keys ``'image'`` and ``'mask'``; both are
            passed straight to ``Axes.imshow``.
    """
    _, (ax_image, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=(18, 6))

    ax_image.imshow(images['image'])
    ax_image.set_title('False color image')

    ax_mask.imshow(images['mask'])
    ax_mask.set_title('Ground truth contrail mask')

    # Overlay: image first, then the semi-transparent mask on top of it.
    ax_overlay.imshow(images['image'])
    ax_overlay.imshow(images['mask'], cmap='Reds', alpha=.4, interpolation='none')

    plt.show()
    
def one_hot_encode(label, label_values):
    """Turn a colour-coded label array into a stack of boolean class maps.

    Args:
        label: Array whose last axis holds the colour of each pixel.
        label_values: Iterable of colours, one per class.

    Returns:
        Boolean array with one channel per entry of ``label_values`` on the
        last axis; a channel is True where every component of the pixel equals
        that class colour.
    """
    class_maps = [
        np.all(np.equal(label, color), axis=-1)
        for color in label_values
    ]
    return np.stack(class_maps, axis=-1)

def reverse_one_hot(image):
    """Return the index of the largest value along the last axis of ``image``."""
    return np.argmax(image, axis=-1)

def color_codr_segmentation(image, label_values):
    """Look up a colour for every class index in ``image``.

    Args:
        image: Array of class indices (cast to ``int`` before indexing).
        label_values: Sequence of colours, indexed by class.

    Returns:
        Array of ``label_values`` entries selected by ``image``.
    """
    palette = np.array(label_values)
    indices = image.astype(int)
    return palette[indices]

# <center>⭐ DataAugmentation & DataSet ⭐</center>

In [None]:
# A_transform = [
#     A.OneOf([
#         A.RandomBrightnessContrast(brightness_limit=(-0.3, 0.3), contrast_limit=(-0.3, 0.3), p=1),  # change both brightness and contrast
#         A.RandomBrightnessContrast(brightness_limit=(-0.8, 0.8), contrast_limit=0, p=1),  # change brightness only
#         A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=(-0.8, 0.8), p=1),  # change contrast only
        
#         # Change hue, saturation and value. default(hue_shift_limit=(-20, 20), sat_shift_limit=(-30, 30), val_shift_limit=(-20, 20))
#         A.HueSaturationValue(p=1),
#         # Randomly shift each of the R, G and B values within its range. default(r_shift_limit=(-20, 20), g_shift_limit=(-20, 20), b_shift_limit=(-20, 20))
#         A.RGBShift(p=1),
#         # Randomly shuffle the RGB channels
#         A.ChannelShuffle(p=1),
        
#         # Add noise drawn from a Gaussian distribution
#         A.GaussNoise(p=1, var_limit=(100, 200)),
#         # Add square cut-out noise (holes)
#         A.Cutout(p=1, num_holes=8, max_h_size=24, max_w_size=24),
        
#         # Use CLAHE, a histogram-equalization technique, to produce a sharper image
#         A.CLAHE(p=1),
#         # The larger blur_limit is, the blurrier the image
#         A.Blur(p=1, blur_limit=(50, 60)),
#     ])
# ]

In [None]:
def get_training_augmentation():
    """Build the training-time augmentation pipeline.

    With probability 0.5 exactly one of a horizontal flip, a vertical flip or
    a random 90-degree rotation is applied. The sample is then padded to at
    least 256x256 with a constant border and normalised with the default
    ``A.Normalize`` settings.

    Returns:
        A.Compose: Pipeline called as ``pipeline(image=..., mask=...)``.
    """
    random_flip_or_rotate = A.OneOf(
        [
            A.HorizontalFlip(p=1),
            A.VerticalFlip(p=1),
            A.RandomRotate90(p=1),
        ],
        p=0.5,
    )

    return A.Compose([
        random_flip_or_rotate,
        # `p=1.0` replaces `always_apply=True`, which newer Albumentations
        # releases deprecate; the transform is still applied to every sample.
        # cv2.BORDER_CONSTANT is the named form of the previous literal 0.
        A.PadIfNeeded(
            min_height=256,
            min_width=256,
            border_mode=cv2.BORDER_CONSTANT,
            p=1.0,
        ),
        A.Normalize(),
    ])

def get_val_augmentation():
    """Build the validation-time pipeline: pad to at least 256x256, then normalise.

    No random augmentation is applied, so validation samples are transformed
    the same way on every call.

    Returns:
        A.Compose: Pipeline called as ``pipeline(image=..., mask=...)``.
    """
    return A.Compose([
        # `p=1.0` replaces `always_apply=True`, which newer Albumentations
        # releases deprecate; the transform is still applied to every sample.
        # cv2.BORDER_CONSTANT is the named form of the previous literal 0.
        A.PadIfNeeded(
            min_height=256,
            min_width=256,
            border_mode=cv2.BORDER_CONSTANT,
            p=1.0,
        ),
        A.Normalize(),
    ])


def to_tensor(x, **kwargs):
    """Reorder the axes of a 3-D array from ``(0, 1, 2)`` to ``(2, 0, 1)``.

    For an H x W x C array this yields the channels-first C x H x W layout.
    The dtype is left untouched.

    This function used to be defined twice in a row; the second definition
    (without the ``float32`` cast) shadowed the first, so its behaviour is the
    one kept here.

    Args:
        x: Array with exactly three axes.
        **kwargs: Ignored; accepted so the function can be used as an
            ``A.Lambda`` callback.

    Returns:
        The transposed array.
    """
    return x.transpose(2, 0, 1)

def get_preprocessing(preprocessing_fn=None):
    """Compose the optional encoder preprocessing with the axis reordering step.

    Args:
        preprocessing_fn: Optional callable applied to the image only. When
            falsy, only the ``to_tensor`` step is included.

    Returns:
        A.Compose: ``preprocessing_fn`` (if given) on the image, followed by
        ``to_tensor`` on both image and mask.
    """
    steps = [A.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    steps.append(A.Lambda(image=to_tensor, mask=to_tensor))
    return A.Compose(steps)

In [None]:
class ContrailDataset(Dataset):
    """Image/mask pairs loaded with ``np.load`` from two directories.

    Each directory is listed in sorted order and the two listings are paired
    by position, so both directories are expected to sort the same way.

    Args:
        images_dir: Directory containing the image arrays.
        masks_dir: Directory containing the mask arrays.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
        preprocessing: Used purely as a flag; when truthy, image and mask are
            transposed with axes ``(2, 0, 1)``.
    """

    def __init__(self, images_dir, masks_dir, augmentation=None, preprocessing=None):
        super().__init__()
        self.image_path = self._sorted_files(images_dir)
        self.mask_path = self._sorted_files(masks_dir)

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    @staticmethod
    def _sorted_files(directory):
        """Return the full paths of the entries of ``directory`` in sorted order."""
        return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]

    def __len__(self):
        """Number of samples, taken from the image listing."""
        return len(self.image_path)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)`` as float32 arrays."""
        image = np.load(self.image_path[idx])
        mask = np.load(self.mask_path[idx])

        if self.augmentation:
            augmented = self.augmentation(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        if self.preprocessing:
            channels_first = (2, 0, 1)
            image = np.transpose(image, channels_first)
            mask = np.transpose(mask, channels_first)

        return image.astype(np.float32), mask.astype(np.float32)

In [None]:
# Preview the first three validation samples, first raw and then after the
# training augmentation. The datasets are built once up front: constructing
# them inside the loop re-listed and re-sorted both directories every pass.
prac_train = ContrailDataset(val_data_path, val_label_path)
train_ds = ContrailDataset(val_data_path, val_label_path, get_training_augmentation())

for i in range(3):
    # Swap `i` for `random_idx` below to preview random samples instead.
    random_idx = random.randint(0, len(prac_train) - 1)

    image, mask = prac_train[i]
    visualize(
        image=image,
        mask=mask
    )

    image, mask = train_ds[i]
    visualize(
        image=image,
        mask=mask
    )

# <center>⭐ Model Config & Model ⭐</center>

In [None]:
class Meter:
    """Interface for statistic meters; every method here is a no-op placeholder."""

    def reset(self):
        """Clear any accumulated state (does nothing in the base class)."""

    def add(self, value):
        """Record ``value`` (does nothing in the base class)."""

    def value(self):
        """Return the current statistic (``None`` in the base class)."""

class AverageValueMeter(Meter):
    """Running mean and standard deviation of the values passed to ``add``.

    Attributes written by ``add``/``reset``: ``n`` (accumulated count),
    ``sum``, ``var`` (sum of squares), ``val`` (last value), ``mean``,
    ``mean_old``, ``m_s`` (running sum used for the std) and ``std``.
    """

    def __init__(self):
        super().__init__()
        self.reset()
        self.val = 0

    def reset(self):
        """Return every accumulator to its initial state."""
        self.n = 0
        self.sum = 0.0
        self.var = 0.0
        self.val = 0.0
        self.mean = np.nan
        self.mean_old = 0.0
        self.m_s = 0.0
        self.std = np.nan

    def add(self, value, n=1):
        """Fold ``value`` into the running statistics.

        Args:
            value: Value to accumulate.
            n: Amount added to the sample count.
        """
        self.val = value
        self.sum += value
        self.var += value * value
        self.n += n

        count = self.n
        if count == 0:
            self.mean = np.nan
            self.std = np.nan
            return

        if count == 1:
            self.mean = 0.0 + self.sum  # This is to force a copy in torch/numpy
            self.std = np.inf
            self.mean_old = self.mean
            self.m_s = 0.0
            return

        # Incremental update of the mean and of the running sum behind the std.
        previous_mean = self.mean_old
        self.mean = previous_mean + (value - n * previous_mean) / float(count)
        self.m_s += (value - previous_mean) * (value - self.mean)
        self.mean_old = self.mean
        self.std = np.sqrt(self.m_s / (count - 1.0))

    def value(self):
        """Return the current ``(mean, std)`` pair."""
        return self.mean, self.std

In [None]:
class Epoch:
    """Base class that runs one pass over a dataloader and aggregates logs.

    Subclasses provide ``batch_update`` and may override ``on_epoch_start``.

    Args:
        model: Module to run; moved to ``device`` on construction.
        loss: Loss object exposing ``__name__`` and ``to``.
        metrics: Iterable of metric objects exposing ``__name__`` and ``to``.
        stage_name: Label shown on the progress bar.
        device: Device the model, loss, metrics and batches are moved to.
        verbose: When False the progress bar is disabled.
    """

    def __init__(self, model, loss, metrics, stage_name, device="cpu", verbose=True):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.stage_name = stage_name
        self.verbose = verbose
        self.device = device

        self._to_device()

    def _to_device(self):
        """Move the model, the loss and every metric to ``self.device``."""
        for module in (self.model, self.loss, *self.metrics):
            module.to(self.device)

    def _format_logs(self, logs):
        """Render ``logs`` as ``'name - value'`` pairs joined by commas."""
        return ", ".join("{} - {:.4}".format(name, score) for name, score in logs.items())

    def batch_update(self, x, y):
        """Process one batch and return ``(loss, prediction)``; subclass hook."""
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook called once at the beginning of ``run``; no-op by default."""

    def run(self, dataloader):
        """Iterate over ``dataloader`` once and return the averaged logs.

        Returns:
            dict: Running mean of the loss and of each metric, keyed by their
            ``__name__``.
        """
        self.on_epoch_start()

        logs = {}
        loss_meter = AverageValueMeter()
        metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics}

        progress = tqdm(
            dataloader,
            desc=self.stage_name,
            file=sys.stdout,
            disable=not self.verbose,
        )
        with progress as iterator:
            for x, y in iterator:
                x = x.to(self.device)
                y = y.to(self.device)
                loss, y_pred = self.batch_update(x, y)

                # Loss first, so it is also the first key in `logs`.
                loss_meter.add(loss.cpu().detach().numpy())
                logs.update({self.loss.__name__: loss_meter.mean})

                # Then every metric, evaluated on the same prediction.
                for metric_fn in self.metrics:
                    score = metric_fn(y_pred, y).cpu().detach().numpy()
                    metrics_meters[metric_fn.__name__].add(score)
                logs.update({name: meter.mean for name, meter in metrics_meters.items()})

                if self.verbose:
                    iterator.set_postfix_str(self._format_logs(logs))
        return logs

class TrainEpoch(Epoch):
    """Training pass with mixed precision and a scheduler stepped every batch.

    The optimised objective is ``loss(prediction, y)`` plus a binary
    cross-entropy term. Note that the returned (and therefore logged) loss is
    this sum.

    Args:
        model, loss, metrics, device, verbose: Forwarded to ``Epoch``.
        optimizer: Optimizer stepped through the ``GradScaler``.
        scheduler: LR scheduler; ``step()`` is called once per batch.
        from_logits: Set to True when the model outputs raw logits. The
            default (False) matches a model built with a sigmoid activation,
            whose outputs are already probabilities.
    """

    def __init__(self, model, loss, metrics, optimizer, scheduler, device="cpu", verbose=True,
                 from_logits=False):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name="train",
            device=device,
            verbose=verbose,
        )
        self.scaler = GradScaler()
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.from_logits = from_logits

    def on_epoch_start(self):
        """Put the model in training mode."""
        self.model.train()

    def batch_update(self, x, y):
        """Run one optimisation step and return ``(loss, prediction)``."""
        self.optimizer.zero_grad()
        with autocast():  # amp
            prediction = self.model.forward(x)
            loss = self.loss(prediction, y)

        # The previous code always used binary_cross_entropy_with_logits, which
        # applies a sigmoid itself. The model in this notebook already ends in
        # a sigmoid, so the activation was applied twice. Plain BCE is used for
        # probabilities instead; it is not autocast-safe, so the term is
        # computed in float32 with autocast disabled.
        if self.from_logits:
            bce_fn = torch.nn.functional.binary_cross_entropy_with_logits
        else:
            bce_fn = torch.nn.functional.binary_cross_entropy
        with autocast(enabled=False):
            bce_loss = bce_fn(prediction.float(), y.float())
        loss = loss + bce_loss

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()

        return loss, prediction

class ValidEpoch(Epoch):
    """Evaluation pass: model in eval mode, no gradients, loss only."""

    def __init__(self, model, loss, metrics, device="cpu", verbose=True):
        super().__init__(model, loss, metrics, "valid", device=device, verbose=verbose)

    def on_epoch_start(self):
        """Put the model in evaluation mode."""
        self.model.eval()

    @torch.no_grad()
    def batch_update(self, x, y):
        """Return ``(loss, prediction)`` for one batch without tracking gradients."""
        prediction = self.model.forward(x)
        return self.loss(prediction, y), prediction

In [None]:
import segmentation_models_pytorch.utils.metrics
import segmentation_models_pytorch.utils

# --- Model -----------------------------------------------------------------
# ENCODER = 'resnet50'
ENCODER = 'efficientnet-b1'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['Contrail']
ACTIVATION = 'sigmoid'  # the model therefore outputs probabilities, not logits

model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

# preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# --- Training hyper-parameters ----------------------------------------------
batch_size = 64
# Batches per epoch is ceil(num_files / batch_size). The previous
# `int(n / batch_size) + 1` counted one step too many whenever n was an exact
# multiple of batch_size. At least one step is kept so the scheduler can be built.
num_train_files = len(os.listdir(train_data_path))
steps_per_epoch = max(1, -(-num_train_files // batch_size))
TRAINING = True
Epochs = 20
loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]
optim = torch.optim.Adam([dict(params=model.parameters(), lr=0.0001)])
# optim = torch.optim.AdamW(model.parameters(), lr=0.003, weight_decay=0.0)
# One-cycle schedule sized for one scheduler step per training batch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, max_lr=0.01, epochs=Epochs, steps_per_epoch=steps_per_epoch)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=500, eta_min=1e-06, last_epoch=-1)

# model = torch.load('/kaggle/input/contrail-best-model-saved/EffNet_b0_BCE_0.01_Add_Zero.pth', map_location=DEVICE)

In [None]:
# Datasets. `preprocessing=True` only switches on the (2, 0, 1) transpose
# inside ContrailDataset.__getitem__.
train_ds = ContrailDataset(
    images_dir=train_data_path,
    masks_dir=train_label_path,
    augmentation=get_training_augmentation(),
    preprocessing=True,
)
val_ds = ContrailDataset(
    images_dir=val_data_path,
    masks_dir=val_label_path,
    augmentation=get_val_augmentation(),
    preprocessing=True,
)

# Loaders: only the training data is shuffled.
train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
val_dl = DataLoader(dataset=val_ds, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Runner for one training pass (optimizer and scheduler are stepped per batch).
train_epoch = TrainEpoch(
    model=model,
    loss=loss,
    metrics=metrics,
    optimizer=optim,
    scheduler=scheduler,
    device=DEVICE,
    verbose=True,
)

# Runner for one validation pass over the same model, loss and metrics.
val_epoch = ValidEpoch(
    model=model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# <center>⭐ Train & Val ⭐</center>

In [None]:
if TRAINING:
    # Best validation IoU seen so far; a checkpoint is written whenever it improves.
    best_iou_score = 0.0
    train_logs_list, val_logs_list = [], []

    for epoch in tqdm(range(Epochs)):
        print('\nEpoch: {}'.format(epoch))

        train_logs = train_epoch.run(train_dl)
        val_logs = val_epoch.run(val_dl)
        print(train_logs)
        print(val_logs)

        train_logs_list.append(train_logs)
        val_logs_list.append(val_logs)

        current_iou = val_logs['iou_score']
        if current_iou > best_iou_score:
            best_iou_score = current_iou
            # The whole model object is saved, not just its state_dict.
            torch.save(model, './best_model.pth')
            print('Model Saved!')