## GPU 스펙 확인

In [None]:
!nvidia-smi

## 캐글 데이터 경로 확인

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the input mount recursively and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## 폴더 경로 설정

In [None]:
workspace_path = '/kaggle/input/clouds-segmentation2024spring'  # set this to your own dataset path

In [None]:
!pip install segmentation_models_pytorch

In [None]:
!pip install albumentations==0.4.6
!pip install yacs

In [None]:
import os
import cv2
import numpy as np
from tqdm import tqdm
import warnings
import random
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import torch
import segmentation_models_pytorch.utils as utils
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.utils import base
from segmentation_models_pytorch.utils import functional as F
from segmentation_models_pytorch.base.modules import Activation
from segmentation_models_pytorch.utils.metrics import IoU
import albumentations as albu
import torch.nn as nn
from sklearn.model_selection import train_test_split
from PIL import Image
import torchvision.transforms as T

## patch

In [29]:
def save_patches(image, mask, patch_size, stride, save_dir, base_name):
    """Cut an image/mask pair into square patches and write them as PNG files.

    Patches are taken on a regular grid starting at the top-left corner. Only
    patches that fit entirely inside the image are produced, so a strip at the
    right/bottom edge is left out when the size is not covered by the grid.

    Args:
        image: Image array whose first two axes are (height, width).
        mask: Mask array with the same height and width as ``image``.
        patch_size: Side length of each square patch, in pixels.
        stride: Step between neighbouring patch origins, in pixels.
        save_dir: Directory that already contains the 'ngr' and 'label' sub-folders.
        base_name: File-name prefix; patches are named '<base_name>_<index>.png'.

    Returns:
        The number of patch pairs written.

    Raises:
        ValueError: If ``image`` and ``mask`` differ in height or width.
        OSError: If OpenCV reports that a patch could not be written.
    """
    img_height, img_width = image.shape[:2]
    if tuple(mask.shape[:2]) != (img_height, img_width):
        # Slicing both arrays with the same window only makes sense when they
        # are pixel-aligned; otherwise image and mask patches would not match.
        raise ValueError(
            f"Image and mask sizes differ for '{base_name}': "
            f"{image.shape[:2]} vs {mask.shape[:2]}"
        )

    patch_id = 0
    for y in range(0, img_height - patch_size + 1, stride):
        for x in range(0, img_width - patch_size + 1, stride):
            rows = slice(y, y + patch_size)
            cols = slice(x, x + patch_size)
            file_name = f"{base_name}_{patch_id}.png"

            outputs = (
                (os.path.join(save_dir, 'ngr', file_name), image[rows, cols]),
                (os.path.join(save_dir, 'label', file_name), mask[rows, cols]),
            )
            for out_path, pixels in outputs:
                # cv2.imwrite signals failure through its return value, not an exception.
                if not cv2.imwrite(out_path, pixels):
                    raise OSError(f"Failed to write patch: {out_path}")

            patch_id += 1

    return patch_id

def patch(data_dir, patch_dir):
    """Split every image/mask pair under ``data_dir`` into patches.

    PNG files are read from '<data_dir>/ngr' (images) and '<data_dir>/label'
    (masks) and the patches are written to '<patch_dir>/224_194/ngr' and
    '<patch_dir>/224_194/label'.

    Args:
        data_dir: Directory containing the 'ngr' and 'label' sub-folders.
        patch_dir: Root directory under which the patch folder is created.

    Raises:
        ValueError: If the number of image and mask files differs.
        FileNotFoundError: If OpenCV cannot read an image or a mask.
    """
    patch_size = 224
    stride = 194
    save_dir = f'{patch_dir}/{patch_size}_{stride}'  # directory the patches are written to

    # Creating the sub-folders also creates save_dir itself.
    for sub_dir in ('label', 'ngr'):
        os.makedirs(os.path.join(save_dir, sub_dir), exist_ok=True)

    # os.listdir() returns entries in arbitrary order; sort both listings so that
    # the i-th image is paired with the i-th mask.
    # NOTE(review): this assumes image and mask names sort into matching order - confirm.
    image_files = sorted(f for f in os.listdir(os.path.join(data_dir, 'ngr')) if f.endswith('.png'))
    mask_files = sorted(f for f in os.listdir(os.path.join(data_dir, 'label')) if f.endswith('.png'))

    if len(image_files) != len(mask_files):
        raise ValueError(
            f"Found {len(image_files)} images but {len(mask_files)} masks under {data_dir}"
        )

    for img_file, mask_file in tqdm(zip(image_files, mask_files), total=len(image_files)):
        image_path = os.path.join(data_dir, 'ngr', img_file)
        mask_path = os.path.join(data_dir, 'label', mask_file)

        # cv2.imread returns None instead of raising when a file cannot be decoded.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        mask = cv2.imread(mask_path)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        base_name = os.path.splitext(img_file)[0]
        save_patches(image, mask, patch_size, stride, save_dir, base_name)

In [None]:
# Directory whose 'ngr' and 'label' sub-folders patch() reads.
data_dir = f'{workspace_path}/train'

# Root folder for the generated patches, under the Kaggle working directory.
patch_dir = '/kaggle/working/cache'
os.makedirs(patch_dir, exist_ok=True)

# Cut every training image/mask pair into patches.
patch(data_dir, patch_dir)

## seg_train

In [None]:
# Suppress every Python warning from here on.
warnings.filterwarnings("ignore")
    
seed = 42  # value passed to set_seed()
num_workers = 0  # forwarded to every DataLoader as num_workers

# Patch geometry; together these name the cache sub-folder '<patch_size>_<patch_stride>' read below.
patch_size = 224
patch_stride = 194

# Folders holding the image patches ('ngr') and the mask patches ('label').
data_dir = f'{patch_dir}/{patch_size}_{patch_stride}/ngr'
mask_dir = f'{patch_dir}/{patch_size}_{patch_stride}/label'

batch_size = 32  # forwarded to every DataLoader
epochs = 1  # number of iterations of each training loop

# Optimizer hyper-parameters passed to AdamW.
learning_rate = 0.001
weight_decay = 0

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators for reproducible runs.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Disable the cuDNN auto-tuner and ask for deterministic kernels;
        # seeding alone does not make cuDNN convolutions reproducible.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

# Seed all generators before the split, the datasets and the model are created.
set_seed(seed)

class Dataset(BaseDataset):
    """Dataset of image/mask file pairs yielding an image and a per-class binary mask stack.

    Each item is read from disk with OpenCV and converted from BGR to RGB; the
    colour-coded mask is then translated into one binary channel per class.

    Args:
        images_fps: Sequence of image file paths.
        masks_fps: Sequence of mask file paths, index-aligned with ``images_fps``.
        augmentation: Optional callable invoked as ``augmentation(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        preprocessing: Optional callable with the same calling convention, applied
            after ``augmentation``.
        classes: List of class names; each name is looked up lower-cased in this
            same list, and one mask channel is produced per entry.
        palette: Mapping from an RGB colour tuple to a class index.
        add_cloud_agumentation_TF: Stored on the instance; not used by this class.
    """

    def __init__(
            self,
            images_fps,
            masks_fps,
            augmentation=None,
            preprocessing=None,
            classes=None,
            palette=None,
            add_cloud_agumentation_TF=False,
    ):
        self.images_fps = images_fps
        self.masks_fps = masks_fps

        self.CLASSES = classes
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
        self.PALETTE = palette

        # Default so the attribute exists even when there are no mask files.
        self.mask_ids = np.array([], dtype=np.uint8)
        # Scan the masks (grayscale) and keep the grey levels, minus the smallest,
        # of the first mask whose count equals the number of classes. If no mask
        # qualifies, every mask is read and the last one's values are kept.
        for mask_fp in self.masks_fps:
            self.mask_ids = np.unique(cv2.imread(mask_fp, 0))[1:]
            if len(self.mask_ids) == len(self.class_values):
                break

        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.add_cloud_agumentation_TF = add_cloud_agumentation_TF

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it as an RGB array."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread returns None for missing or undecodable files.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i``.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        image = self._read_rgb(self.images_fps[i])
        mask = self._read_rgb(self.masks_fps[i])

        # Convert the colour mask to class indices using PALETTE.
        # Pixels whose colour is not in the palette keep index 0.
        mask_class = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.float32)
        for color, class_idx in self.PALETTE.items():
            mask_class[np.all(mask == color, axis=-1)] = class_idx

        # Build one binary channel per requested class.
        masks = [(mask_class == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of image files."""
        return len(self.images_fps)

def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast it to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')

def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied to every sample.

    Args:
        preprocessing_fn: Callable applied to the image only.

    Returns:
        An ``albu.Compose`` that first applies ``preprocessing_fn`` to the image
        and then ``to_tensor`` to both image and mask.
    """
    normalize_image = albu.Lambda(image=preprocessing_fn)
    channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_image, channels_first])

def get_training_augmentation():
    """Return the training augmentation: random flips and 90-degree rotation, each with p=0.5."""
    return albu.Compose([
        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),
        albu.RandomRotate90(p=0.5),
    ])

def get_validation_augmentation():
    """Return the validation pipeline, which contains no transforms."""
    return albu.Compose([])

# Encoder name and weights, passed to smp.FPN and to get_preprocessing_fn.
ENCODER = 'timm-efficientnet-b0'
ENCODER_WEIGHTS = 'imagenet' # 'imagenet', 'pre-trained from:..', None
# Class names; the position in this list is the class index and the mask channel.
CLASSES = ['background', 'thick_cloud', 'thin_cloud', 'cloud_shadow']
# Colour -> class index. Dataset compares these tuples against masks after
# converting them from BGR to RGB, so the tuples are in RGB order.
PALETTE = {
    (0, 0, 0): 0, # background
    (255, 0, 0): 1,  # thick_cloud
    (0, 255, 0): 2,  # thin_cloud
    (255, 255, 0): 3  # cloud_shadow
}
ACTIVATION = 'softmax'  # passed to smp.FPN as its activation argument
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# FPN model with 3 input channels and one output channel per entry of CLASSES, moved to DEVICE.
model = smp.FPN(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    in_channels=3,
    classes=len(CLASSES), 
    activation=ACTIVATION,
).to(DEVICE)

# Image preprocessing function matching the chosen encoder and weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# os.listdir() returns entries in arbitrary order, so listing the two folders
# independently can pair an image with the wrong mask. Sort both listings and
# require identical file names so that index i always refers to the same patch.
image_names = sorted(f for f in os.listdir(data_dir) if f.endswith('.png'))
mask_names = sorted(f for f in os.listdir(mask_dir) if f.endswith('.png'))
if image_names != mask_names:
    raise ValueError(
        f"Image and mask patches do not match: {len(image_names)} files in {data_dir}, "
        f"{len(mask_names)} files in {mask_dir}"
    )

image_files = [os.path.join(data_dir, f) for f in image_names]
mask_files = [os.path.join(mask_dir, f) for f in mask_names]

# 80/20 train/validation split; the sorted input makes the seeded split reproducible.
train_images, valid_images, train_masks, valid_masks = train_test_split(image_files, mask_files, test_size=0.2, random_state=42)

# Training dataset: augmentation followed by encoder preprocessing.
train_dataset = Dataset(
    train_images, 
    train_masks, 
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
    palette=PALETTE,
)

# Validation dataset: same preprocessing, with the validation augmentation pipeline.
valid_dataset = Dataset(
    valid_images, 
    valid_masks, 
    augmentation=get_validation_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
    palette=PALETTE,
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

class CombinedLoss(nn.Module):
    """Weighted sum of two losses evaluated on the same ``(output, target)`` pair.

    Args:
        loss_a: First loss callable.
        loss_b: Second loss callable.
        weight_a: Multiplier applied to ``loss_a``.
        weight_b: Multiplier applied to ``loss_b``.

    The instance exposes ``__name__`` as '<class of loss_a>_<class of loss_b>'.
    """

    def __init__(self, loss_a, loss_b, weight_a=0.5, weight_b=0.5):
        super().__init__()
        self.loss_a, self.loss_b = loss_a, loss_b
        self.weight_a, self.weight_b = weight_a, weight_b
        self.__name__ = f"{type(loss_a).__name__}_{type(loss_b).__name__}"

    def forward(self, output, target):
        """Return ``weight_a * loss_a(output, target) + weight_b * loss_b(output, target)``."""
        term_a = self.weight_a * self.loss_a(output, target)
        term_b = self.weight_b * self.loss_b(output, target)
        return term_a + term_b

class DiceScore(base.Metric):
    """Metric that reports the F-score with beta=1.0 (the Dice coefficient).

    Args:
        eps: Small constant forwarded to ``F.f_score``.
        threshold: Threshold forwarded to ``F.f_score``.
        activation: Name handed to ``Activation`` and applied to predictions first.
        ignore_channels: Channels forwarded to ``F.f_score`` to be ignored.
        **kwargs: Passed through to the base metric constructor.
    """

    __name__ = "dice_score"

    def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps, self.threshold = eps, threshold
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        """Apply the activation to ``y_pr`` and score it against ``y_gt``."""
        activated = self.activation(y_pr)
        score_options = dict(
            eps=self.eps,
            beta=1.0,
            threshold=self.threshold,
            ignore_channels=self.ignore_channels,
        )
        return F.f_score(activated, y_gt, **score_options)

class ProbabilityCrossEntropyLoss(nn.Module):
    """Cross-entropy for predictions that are already class probabilities.

    The model is built with a softmax activation, so its output is a probability
    map. ``torch.nn.CrossEntropyLoss`` expects raw logits and applies log-softmax
    itself; feeding it probabilities applies softmax twice and flattens the
    gradients. This loss takes the log of the probabilities directly.

    Args:
        eps: Lower clamp applied to the probabilities before the log, to avoid log(0).
    """

    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, output, target):
        """Return the mean over all pixels of ``-sum_c target_c * log(output_c)``.

        Both tensors are expected with the class axis at dim 1, which is how the
        dataset masks are laid out after ``to_tensor``.
        """
        log_prob = torch.log(output.clamp(min=self.eps))
        return -(target * log_prob).sum(dim=1).mean()


DiceLoss = utils.losses.DiceLoss()
# NOTE: valid only while the model output is a probability distribution
# (ACTIVATION = 'softmax'); switch back to a logit-based loss if that changes.
CE_Loss = ProbabilityCrossEntropyLoss()
combined_criterion = CombinedLoss(DiceLoss, CE_Loss, weight_a=0.5, weight_b=0.5)

# Metrics reported for every epoch; the logs are read below under
# 'dice_score' and 'iou_score'.
metrics = [
    DiceScore(),
    IoU(),
]

# AdamW over all model parameters, in a single parameter group.
optimizer = torch.optim.AdamW([
    dict(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay),
])

# Runner for one training pass over a loader.
train_epoch = utils.train.TrainEpoch(
    model, 
    loss=combined_criterion, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

# Runner for one validation pass; it takes no optimizer.
valid_epoch = utils.train.ValidEpoch(
    model, 
    loss=combined_criterion, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

# save_dir = os.path.join('./', 'ckpt')
# Checkpoints are written under the Kaggle working directory.
save_dir = os.path.join('/kaggle/working/', 'ckpt')
os.makedirs(save_dir, exist_ok=True)

# Tag used in the checkpoint file name of this (augmented) run.
dataset = f'aug_{patch_size}_{patch_stride}'

# Best validation Dice score seen so far.
max_dice_score = 0

for epoch in range(epochs):
    print('\nEpoch: {}'.format(epoch))
    
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    print('train dice score:', train_logs['dice_score'], 'train iou score:', train_logs['iou_score'])
    print('valid dice score:', valid_logs['dice_score'], 'valid iou score:', valid_logs['iou_score'])

    # Overwrite the checkpoint only when validation Dice improves.
    # torch.save(model, ...) stores the whole model object, not just its state_dict.
    if valid_logs['dice_score'] > max_dice_score:
        max_dice_score = valid_logs['dice_score']
        torch.save(model, os.path.join(save_dir, f'{dataset}_best_model.pth'))

# ============================ Train a model without augmentation (for the ensemble) ===========================================

def get_training_augmentation():
    """Return an empty training pipeline; this redefinition replaces the earlier augmented one."""
    return albu.Compose([])

# Rebuild the training dataset on the same split, now with the empty augmentation pipeline.
train_dataset = Dataset(
    train_images, 
    train_masks, 
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
    palette=PALETTE,
)

# New training loader; valid_loader from the first run is reused below.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Fresh model with the same architecture, so the second run does not continue from the first.
model = smp.FPN(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    in_channels=3,
    classes=len(CLASSES), 
    activation=ACTIVATION,
).to(DEVICE)

# New optimizer bound to the new model's parameters.
optimizer = torch.optim.AdamW([
    dict(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay),
])

# Epoch runners rebuilt for the new model and optimizer; loss and metrics are reused.
train_epoch = utils.train.TrainEpoch(
    model, 
    loss=combined_criterion, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = utils.train.ValidEpoch(
    model, 
    loss=combined_criterion, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

# Tag used in the checkpoint file name of the non-augmented run.
dataset = f'non_aug_{patch_size}_{patch_stride}'

# Reset the best score so this run is tracked independently of the first.
max_dice_score = 0

for epoch in range(epochs):
    print('\nEpoch: {}'.format(epoch))
    
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    print('train dice score:', train_logs['dice_score'], 'train iou score:', train_logs['iou_score'])
    print('valid dice score:', valid_logs['dice_score'], 'valid iou score:', valid_logs['iou_score'])

    # Overwrite the checkpoint only when validation Dice improves.
    if valid_logs['dice_score'] > max_dice_score:
        max_dice_score = valid_logs['dice_score']
        torch.save(model, os.path.join(save_dir, f'{dataset}_best_model.pth'))