# 0. 초기 세팅

In [2]:
import os
import cv2
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Select the compute device used throughout the notebook.
# The CUDA calls are only made when CUDA is actually available: calling
# torch.cuda.set_device()/current_device() on a CPU-only machine raises.
GPU_INDEX = 4  # preferred GPU index on the original training host
if torch.cuda.is_available():
    # Fall back to GPU 0 when the preferred index does not exist on this host.
    device = torch.device(f"cuda:{GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0}")
    torch.cuda.set_device(device)
    print('Current cuda device ', torch.cuda.current_device())
else:
    device = torch.device("cpu")
    print('CUDA is not available; using device', device)

Current cuda device  4


In [3]:
from matplotlib import pyplot as plt
import tifffile as tiff
from PIL import Image
import random

In [4]:
import tensorflow as tf
from tensorflow import keras
from keras.metrics import MeanIoU

In [5]:
import random

def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and, when available, all
    CUDA devices), sets ``PYTHONHASHSEED`` and asks cuDNN for deterministic
    kernels so repeated runs are comparable.

    Parameters
    ----------
    seed : int
        Seed value shared by all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    # The model, dropout and DataLoader shuffling all draw from torch's RNG,
    # which the original helper left unseeded.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade some speed for run-to-run reproducibility of convolution kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)  # fix the seed

# 1. 데이터 준비

In [6]:
# RLE decoding function
def rle_decode(mask_rle, shape):
    """Decode a run-length-encoded mask into a binary 2-D array.

    Parameters
    ----------
    mask_rle : str
        Space-separated ``start length`` pairs; starts are 1-based positions in
        the row-major flattened mask. Any non-string value (for example the
        integer ``-1`` this notebook stores for "no pixels", or a NaN coming
        from a CSV) as well as ``""`` and ``"-1"`` decode to an empty mask.
    shape : tuple
        ``(height, width)`` of the mask to build.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of the given shape with 1 inside the runs and 0 elsewhere.
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    # Explicitly handle the "empty mask" markers instead of failing on .split().
    if not isinstance(mask_rle, str) or mask_rle.strip() in ("", "-1"):
        return img.reshape(shape)

    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1  # convert 1-based starts to 0-based indices
    ends = starts + lengths
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape)


# RLE encoding function
def rle_encode(mask):
    """Encode a binary mask as a run-length string.

    The mask is flattened in row-major order and each run of non-zero pixels
    is written as ``start length`` with a 1-based start. A mask without any
    set pixel yields the empty string.
    """
    flat = mask.flatten()
    # Zero sentinels on both ends make every run produce a rising and a falling edge.
    padded = np.concatenate([[0], flat, [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Even slots are run starts, odd slots are run ends: turn the ends into lengths.
    edges[1::2] = edges[1::2] - edges[::2]
    return " ".join(map(str, edges))

In [None]:
# Load the test-set CSV.
df_test = pd.read_csv('./test.csv')

In [None]:
# Directories holding the pre-split training images and masks.
train_img_dir = './pro_data/data_for_training_and_testing/train/images/'
train_mask_dir = './pro_data/data_for_training_and_testing/train/masks/'

# Directories holding the pre-split validation images and masks.
valid_img_dir = './pro_data/data_for_training_and_testing/val/images/'
valid_mask_dir = './pro_data/data_for_training_and_testing/val/masks/'

In [None]:
# Sorted file-name listings; image i and mask i only line up if both folders
# sort in the same order -- TODO confirm the naming scheme guarantees that.
data_train_img = sorted(os.listdir(train_img_dir))
data_train_mask = sorted(os.listdir(train_mask_dir))

data_val_img = sorted(os.listdir(valid_img_dir))
data_val_mask = sorted(os.listdir(valid_mask_dir))

In [None]:
# Pair images and masks row by row. The columns hold bare file names
# (os.listdir output), not full paths.
df_train = pd.DataFrame({'img_path': data_train_img, 'mask_path': data_train_mask})
df_valid = pd.DataFrame({'img_path': data_val_img, 'mask_path': data_val_mask})

`1` Quick understanding of the dataset

In [None]:
# (Re)load the CSV indexes. The cells below read the columns img_id, img_path and
# mask_rle from df_train, so this rebinds any df_train/df_test defined earlier.
df_train, df_test = pd.read_csv('./train.csv'), pd.read_csv('./test.csv')

In [None]:
# Inspect the first training sample: one colour channel of the image and the
# label distribution of its decoded mask.
temp_img = cv2.imread(df_train.loc[0, 'img_path']) # 3 channels, in OpenCV's BGR order
plt.imshow(temp_img[:,:,2]) # show a single channel (index 2)
temp_mask = rle_decode(df_train.loc[0, 'mask_rle'], shape = (1024, 1024)) # single-channel mask; assumes 1024x1024 images -- TODO confirm
labels, count = np.unique(temp_mask[:, :], return_counts=True) # distinct label values and how often each occurs
print("Labels are: ", labels, " and the counts are: ", count)

`2` Now, crop each large image into patches of 224x224. Save them into a directory

so we can use data augmentation and read directly from the drive.

In [None]:
# Side length, in pixels, of the square patches cut from every image and mask.
patch_size = 224

In [None]:
# Cut every training image into non-overlapping patch_size x patch_size tiles and
# write them to disk.
# NOTE: `patchify` must already be defined when this cell runs; in this notebook it
# is defined in a later cell, so run that cell first.
image_patch_dir = "224_patches/images/"
os.makedirs(image_patch_dir, exist_ok=True)  # cv2.imwrite does not create missing folders

for i in range(len(df_train)):
    img_path = df_train.loc[i, 'img_path']
    image = cv2.imread(img_path)       # read each image as BGR
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    SIZE_X = (image.shape[1] // patch_size) * patch_size  # largest width divisible by the patch size
    SIZE_Y = (image.shape[0] // patch_size) * patch_size  # largest height divisible by the patch size
    # Keep the top-left SIZE_Y x SIZE_X region. A plain slice replaces the PIL round
    # trip, so the cell no longer depends on which `Image` is currently imported.
    image = np.ascontiguousarray(image[:SIZE_Y, :SIZE_X])

    # Extract patches from each image
    print("Now patchifying image:", img_path)
    patches_img = patchify(image, (patch_size, patch_size, 3), step=patch_size)

    for j in range(patches_img.shape[0]):
        for k in range(patches_img.shape[1]):
            # Index 0 drops the singleton axis created by windowing over the channel axis.
            single_patch_img = patches_img[j, k, 0]

            out_path = image_patch_dir + df_train.loc[i, 'img_id'] + "patch_" + str(j) + str(k) + ".png"
            if not cv2.imwrite(out_path, single_patch_img):
                raise IOError(f"Could not write patch: {out_path}")

In [None]:
# Cut every decoded training mask into non-overlapping patch_size x patch_size tiles
# and write them to disk.
# NOTE: `patchify` must already be defined when this cell runs; in this notebook it
# is defined in a later cell, so run that cell first.
mask_patch_dir = "./224_patches/masks/"
os.makedirs(mask_patch_dir, exist_ok=True)  # cv2.imwrite does not create missing folders

for i in range(len(df_train)):
    # Masks are decoded at a fixed 1024x1024 size -- assumes every training image has that size.
    mask = rle_decode(df_train.loc[i, 'mask_rle'], shape = (1024, 1024))
    SIZE_X = (mask.shape[1] // patch_size) * patch_size  # largest width divisible by the patch size
    SIZE_Y = (mask.shape[0] // patch_size) * patch_size  # largest height divisible by the patch size
    # Keep the top-left SIZE_Y x SIZE_X region. A plain slice replaces the PIL round
    # trip, so the cell no longer depends on which `Image` is currently imported.
    mask = np.ascontiguousarray(mask[:SIZE_Y, :SIZE_X])

    # Extract patches from each mask
    print("Now patchifying mask:", i)
    patches_mask = patchify(mask, (patch_size, patch_size), step=patch_size)

    for j in range(patches_mask.shape[0]):
        for k in range(patches_mask.shape[1]):
            single_patch_mask = patches_mask[j, k]

            # Mask files are keyed by the last four characters of img_id.
            out_path = mask_patch_dir + "MASK_" + df_train.loc[i, 'img_id'][-4:] + "patch_" + str(j) + str(k) + ".png"
            if not cv2.imwrite(out_path, single_patch_mask):
                raise IOError(f"Could not write patch: {out_path}")

In [None]:
# Reload one exported image/mask patch pair as a sanity check of the export
# (imread flag 1 = colour, flag 0 = grayscale).
image_test = cv2.imread("./224_patches/images/TRAIN_0000patch_01.png", 1)
image_test = cv2.cvtColor(image_test, cv2.COLOR_BGR2RGB)
mask_test = cv2.imread("./224_patches/masks/MASK_0000patch_01.png", 0)

In [None]:
# Show the reloaded image patch.
plt.imshow(image_test)

In [None]:
# Show the reloaded mask patch.
plt.imshow(mask_test)

In [None]:
# From here on the "train" directories point at the exported patches.
train_img_dir = "224_patches/images/"
train_mask_dir = "224_patches/masks/"

# os.listdir returns names in arbitrary order. Sort both listings so that position i
# refers to the same patch in both lists (assuming the file names follow the scheme
# written by the patch-export cells).
img_list = sorted(os.listdir(train_img_dir))
msk_list = sorted(os.listdir(train_mask_dir))

num_images = len(img_list)

In [None]:
# Number of image patches vs. mask patches; a mismatch would mean the lists cannot be paired by index.
print(len(img_list), len(msk_list))

In [None]:
# Pick one random list position and show the image patch next to the mask stored
# at the same position.
img_num = random.randint(0, num_images-1)

img_for_plot = cv2.cvtColor(
    cv2.imread(train_img_dir + img_list[img_num], 1),  # colour read (BGR)
    cv2.COLOR_BGR2RGB,
)
mask_for_plot = cv2.imread(train_mask_dir + msk_list[img_num], 0)  # grayscale read

plt.figure(figsize=(12, 8))

# Left panel: image patch.
plt.subplot(1, 2, 1)
plt.title('Image')
plt.imshow(img_for_plot)

# Right panel: mask patch.
plt.subplot(1, 2, 2)
plt.title('Mask')
plt.imshow(mask_for_plot, cmap='gray')

plt.show()

### Custom Dataset

In [None]:
class SatelliteDataset(Dataset):
    """Image (and optional mask) dataset backed by a CSV file.

    Columns are read by position: column 1 is the image path and column 2 the
    run-length-encoded mask (column 2 is only read when ``infer`` is False).

    Parameters
    ----------
    csv_file : str
        Path of the CSV file to load with ``pandas.read_csv``.
    transform : callable, optional
        Called as ``transform(image=...)`` or ``transform(image=..., mask=...)``
        and expected to return a dict with ``'image'`` (and ``'mask'``).
    infer : bool
        When True only the image is returned.
    """

    def __init__(self, csv_file, transform=None, infer=False):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.infer = infer

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``image`` (infer mode) or ``(image, mask)`` for row ``idx``.

        Raises
        ------
        FileNotFoundError
            If the image file cannot be read.
        """
        img_path = self.data.iloc[idx, 1]
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None on failure; fail with the offending path
            # instead of an opaque error further down.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.infer:
            if self.transform:
                image = self.transform(image=image)['image']
            return image

        mask_rle = self.data.iloc[idx, 2]
        # The mask is decoded at the image's own height and width.
        mask = rle_decode(mask_rle, (image.shape[0], image.shape[1]))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

In [None]:
class SatelliteDatasetForValid(Dataset):
    """Image (and optional mask) dataset backed by an in-memory DataFrame.

    Same row layout as the CSV-based dataset: column 1 is the image path and
    column 2 the run-length-encoded mask (only read when ``infer`` is False).
    Rows are addressed by position, so a shuffled split can be passed as is.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Table of samples.
    transform : callable, optional
        Called as ``transform(image=...)`` or ``transform(image=..., mask=...)``
        and expected to return a dict with ``'image'`` (and ``'mask'``).
    infer : bool
        When True only the image is returned.
    """

    def __init__(self, dataset, transform=None, infer=False):
        self.data = dataset
        self.transform = transform
        self.infer = infer

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``image`` (infer mode) or ``(image, mask)`` for row ``idx``.

        Raises
        ------
        FileNotFoundError
            If the image file cannot be read.
        """
        img_path = self.data.iloc[idx, 1]
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None on failure; fail with the offending path
            # instead of an opaque error further down.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.infer:
            if self.transform:
                image = self.transform(image=image)['image']
            return image

        mask_rle = self.data.iloc[idx, 2]
        # The mask is decoded at the image's own height and width.
        mask = rle_decode(mask_rle, (image.shape[0], image.shape[1]))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

In [None]:
class SatelliteDatasetForTest(Dataset):
    """CSV-backed dataset that super-resolves each image before the transform.

    Column 1 of the CSV is the image path and column 2 the run-length-encoded
    mask (only read when ``infer`` is False). Upsampling uses the module-level
    ``sr`` object (created in the super-resolution cell of this notebook), so
    that cell must have been run before items are requested. As in the original
    code, upsampling only happens when a ``transform`` is set.

    Parameters
    ----------
    csv_file : str
        Path of the CSV file to load with ``pandas.read_csv``.
    transform : callable, optional
        Called as ``transform(image=...)`` or ``transform(image=..., mask=...)``
        and expected to return a dict with ``'image'`` (and ``'mask'``).
    infer : bool
        When True only the image is returned.
    """

    def __init__(self, csv_file, transform=None, infer=False):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.infer = infer

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``image`` (infer mode) or ``(image, mask)`` for row ``idx``.

        Raises
        ------
        FileNotFoundError
            If the image file cannot be read.
        """
        img_path = self.data.iloc[idx, 1]
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None on failure; fail with the offending path.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.infer:
            if self.transform:
                image = sr.upsample(image)
                image = self.transform(image=image)['image']
            return image

        mask_rle = self.data.iloc[idx, 2]
        mask = rle_decode(mask_rle, (image.shape[0], image.shape[1]))

        if self.transform:
            image = sr.upsample(image)
            # The mask was decoded at the pre-upsampling size. Resize it with
            # nearest-neighbour (keeps labels binary) so image and mask reach the
            # transform with matching height and width.
            if mask.shape[:2] != image.shape[:2]:
                mask = cv2.resize(mask, (image.shape[1], image.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)
            augmented = self.transform(image=image, mask=mask)

            image = augmented['image']
            mask = augmented['mask']

        return image, mask

# 1. 데이터 정의

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Data loader
# Normalisation and tensor conversion only; the resize step is disabled, so images
# keep their stored size.
transform = A.Compose([
    # A.Resize(224, 224),
    A.Normalize(),
    ToTensorV2()
])

# 80/20 train/validation split of train.csv with a fixed random_state.
data = pd.read_csv("./train.csv")
train, valid = train_test_split(data, test_size=0.2, random_state=123)

# Both splits share the same DataFrame-backed dataset class and the same transform.
trainset = SatelliteDatasetForValid(dataset = train, transform=transform)
validset = SatelliteDatasetForValid(dataset = valid, transform=transform)

# Only the training loader shuffles.
train_dataloader = DataLoader(trainset, batch_size=8, shuffle=True, num_workers=4)
valid_dataloader = DataLoader(validset, batch_size=8, shuffle=False, num_workers=4)

#### 실험

In [None]:
# First row of the training split. Use positional access: after the shuffled
# train/valid split the index label 0 is not guaranteed to be in `train`.
img_path = train['img_path'].iloc[0]

In [None]:
# Load the sample image and convert OpenCV's BGR channel order to RGB.
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [None]:
# Height, width and channel count of the sample image.
image.shape

In [None]:
import numbers
import numpy as np
from numpy.lib.stride_tricks import as_strided


def view_as_windows(arr_in, window_shape, step=1):
    """Return a rolling-window view of an n-dimensional array.

    Parameters
    ----------
    arr_in : numpy.ndarray
        Array to take windows from.
    window_shape : int or tuple of int
        Window size per axis; a single number is used for every axis.
    step : int or tuple of int
        Distance between consecutive windows per axis; a single number is used
        for every axis. Every step must be >= 1.

    Returns
    -------
    numpy.ndarray
        Array of shape ``windows_per_axis + window_shape`` built with
        ``as_strided``, i.e. it shares memory with ``arr_in`` rather than copying.

    Raises
    ------
    TypeError
        If ``arr_in`` is not a numpy array.
    ValueError
        If ``window_shape`` or ``step`` does not match ``arr_in.ndim``, a step is
        smaller than 1, or a window is smaller than 1 or larger than the array.
    """
    # -- basic checks on arguments
    if not isinstance(arr_in, np.ndarray):
        raise TypeError("`arr_in` must be a numpy ndarray")

    ndim = arr_in.ndim

    if isinstance(window_shape, numbers.Number):
        window_shape = (window_shape,) * ndim
    if not (len(window_shape) == ndim):
        raise ValueError("`window_shape` is incompatible with `arr_in.shape`")

    if isinstance(step, numbers.Number):
        step = (step,) * ndim
    if len(step) != ndim:
        raise ValueError("`step` is incompatible with `arr_in.shape`")
    # Validate every step, including those passed as a tuple.
    if any(st < 1 for st in step):
        raise ValueError("`step` must be >= 1")

    arr_shape = np.array(arr_in.shape)
    window_shape = np.array(window_shape, dtype=arr_shape.dtype)

    if ((arr_shape - window_shape) < 0).any():
        raise ValueError("`window_shape` is too large")

    if ((window_shape - 1) < 0).any():
        raise ValueError("`window_shape` is too small")

    # -- build rolling window view
    slices = tuple(slice(None, None, st) for st in step)
    # Strides inside a window are those of the source array.
    window_strides = np.array(arr_in.strides)

    # Strides between windows are those of the array subsampled by `step`.
    indexing_strides = arr_in[slices].strides

    win_indices_shape = (
        (np.array(arr_in.shape) - np.array(window_shape)) // np.array(step)
    ) + 1

    new_shape = tuple(list(win_indices_shape) + list(window_shape))
    strides = tuple(list(indexing_strides) + list(window_strides))

    arr_out = as_strided(arr_in, shape=new_shape, strides=strides)
    return arr_out

In [None]:
from typing import Tuple, Union, cast
import numpy as np


Imsize = Union[Tuple[int, int], Tuple[int, int, int]]


def patchify(image: np.ndarray, patch_size: Imsize, step: int = 1) -> np.ndarray:
    """
    Split a 2D or 3D image into small patches given the patch size.

    This is a thin wrapper: it forwards its arguments unchanged to
    ``view_as_windows`` and returns that result.

    Parameters
    ----------
    image: the image to be split. It can be 2d (m, n) or 3d (k, m, n)
    patch_size: the size of a single patch; one entry per axis of ``image``
    step: the step size between patches

    Returns
    -------
    The array returned by ``view_as_windows``: the leading axes index the patch
    position, the trailing axes hold the patch contents.

    Examples
    --------
    >>> image = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
    >>> patches = patchify(image, (2, 2), step=1)  # split image into 2*3 small 2*2 patches.
    >>> assert patches.shape == (2, 3, 2, 2)

    NOTE(review): the next two lines rely on ``unpatchify``, which is not defined
    in the part of this notebook visible here -- confirm it exists before running them.
    >>> reconstructed_image = unpatchify(patches, image.shape)
    >>> assert (reconstructed_image == image).all()
    """
    return view_as_windows(image, patch_size, step)

In [None]:
# Overlapping 224x224x3 windows taken every 100 pixels along each axis.
patches = patchify(image, (224, 224, 3), step=100)

In [None]:
# Patch-grid dimensions followed by the patch dimensions.
patches.shape

In [None]:
# Shape of a single patch (grid position row 0, column 1).
(patches[0][1][0]).shape

In [None]:
# Shape of a single patch (grid position row 0, column 0).
(patches[0][0][0]).shape

In [None]:
import matplotlib.pyplot as plt

# Show the full sample image.
plt.figure()
plt.imshow(image)
plt.show()

In [None]:
# Flatten the patch grid into a row-major list of patches. The grid size is taken
# from `patches` itself instead of a hard-coded 9x9.
data224 = [patches[i][j][0] for i in range(patches.shape[0]) for j in range(patches.shape[1])]

In [None]:
import matplotlib.pyplot as plt

# Plot every patch of `data224` in a grid with three columns.
fig = plt.figure(figsize = (100, 100))
cols = 3
rows = -(-len(data224) // cols)  # ceiling division: enough rows for every patch

for i, patch in enumerate(data224):
    ax = fig.add_subplot(rows, cols, i+1)
    # The patches are cut from `image`, which was already converted to RGB, so they
    # are shown as is; converting BGR->RGB again would swap red and blue.
    ax.imshow(patch)
    ax.set_xticks([]), ax.set_yticks([])

plt.show()

In [None]:
import matplotlib.pyplot as plt

# Show the image patch at grid position (8, 2).
plt.figure()
plt.imshow(patches[8][2][0])
plt.show()

In [None]:
# Decode the mask that belongs to the sample image. The row is looked up through
# `img_path` rather than the index label 0, which may be missing from the shuffled
# training split.
mask0 = rle_decode(train.loc[train['img_path'] == img_path, 'mask_rle'].iloc[0], shape = (1024, 1024))

In [None]:
# Same windowing as for the image (224x224 windows, step 100), applied to the 2-D mask.
patches_mask = patchify(mask0, (224, 224), step=100)

In [None]:
# Patch-grid dimensions followed by the patch dimensions.
patches_mask.shape

In [None]:
import matplotlib.pyplot as plt

# Show the mask patch at grid position (8, 2), the counterpart of the image patch shown above.
plt.figure()
plt.imshow(patches_mask[8][2])
plt.show()

# 2. 데이터 학습

### 0) import

In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel.
# TODO: pin the versions used for training so the notebook stays reproducible.
%pip install einops
%pip install timm

In [None]:
import torch.nn as nn
import torchvision.models

import torch
from torch import Tensor
import torch.nn.functional as F
from einops import rearrange, repeat

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import timm

### 1) First Model

In [None]:
class ConvBNReLU(nn.Sequential):
    """Conv2d followed by a normalisation layer and ReLU6.

    The padding is derived from kernel size, dilation and stride; for stride 1
    and an odd kernel it keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d, bias=False):
        pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                         dilation=dilation, stride=stride, padding=pad)
        super().__init__(conv, norm_layer(out_channels), nn.ReLU6())


class ConvBN(nn.Sequential):
    """Conv2d followed by a normalisation layer (no activation).

    The padding is derived from kernel size, dilation and stride; for stride 1
    and an odd kernel it keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d, bias=False):
        pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                         dilation=dilation, stride=stride, padding=pad)
        super().__init__(conv, norm_layer(out_channels))


class Conv(nn.Sequential):
    """A single Conv2d wrapped in a Sequential, with padding derived from its geometry.

    For stride 1 and an odd kernel the padding keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                         dilation=dilation, stride=stride, padding=pad)
        super().__init__(conv)


class SeparableConvBNReLU(nn.Sequential):
    """Depthwise conv -> normalisation -> 1x1 pointwise conv -> ReLU6.

    The normalisation layer sits between the depthwise and the pointwise
    convolution, so it sees ``in_channels`` feature maps and is sized
    accordingly. The original code sized it with ``out_channels``, which only
    worked when both channel counts were equal; for that case the layers
    created here are unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBNReLU, self).__init__(
            # Depthwise: one filter per input channel (groups=in_channels).
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            norm_layer(in_channels),
            # Pointwise: mixes channels and sets the output width.
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.ReLU6()
        )


class SeparableConvBN(nn.Sequential):
    """Depthwise conv -> normalisation -> 1x1 pointwise conv (no activation).

    The normalisation layer sits between the depthwise and the pointwise
    convolution, so it sees ``in_channels`` feature maps and is sized
    accordingly. The original code sized it with ``out_channels``, which only
    worked when both channel counts were equal; for that case the layers
    created here are unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBN, self).__init__(
            # Depthwise: one filter per input channel (groups=in_channels).
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            norm_layer(in_channels),
            # Pointwise: mixes channels and sets the output width.
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )


class SeparableConv(nn.Sequential):
    """Depthwise convolution followed by a 1x1 pointwise convolution, both without bias."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
        pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2
        # groups=in_channels gives one spatial filter per input channel.
        depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                              padding=pad, groups=in_channels, bias=False)
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        super().__init__(depthwise, pointwise)


class Mlp(nn.Module):
    """Two-layer channel MLP built from 1x1 convolutions.

    ``hidden_features`` and ``out_features`` default to ``in_features``. The same
    dropout module is applied after the activation and after the second layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        out_width = out_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_width, 1, 1, 0, bias=True)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_width, out_width, 1, 1, 0, bias=True)
        self.drop = nn.Dropout(drop, inplace=True)

    def forward(self, x):
        """Apply fc1 -> activation -> dropout -> fc2 -> dropout to an (N, C, H, W) tensor."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class GlobalLocalAttention(nn.Module):
    """Windowed multi-head self-attention combined with a convolutional local branch.

    The input (N, C, H, W) is processed by two paths whose results are added:

    * local: the sum of a 3x3 and a 1x1 ConvBN of the input;
    * attention: the input is reflect-padded to a multiple of ``window_size``,
      split into non-overlapping ``window_size`` x ``window_size`` windows,
      attended within each window (optionally with a learned relative position
      bias), cropped back to (H, W) and smoothed by average pooling along each
      axis.

    The sum is passed through a depthwise-separable projection and returned with
    the input's spatial size.

    NOTE(review): the pooling/projection padding arithmetic below returns the
    input size for even ``window_size`` (such as the default 8); odd window sizes
    are not obviously supported -- confirm before using one.
    """
    def __init__(self,
                 dim=256,
                 num_heads=16,
                 qkv_bias=False,
                 window_size=8,
                 relative_pos_embedding=True
                 ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // self.num_heads
        # Standard 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = head_dim ** -0.5
        self.ws = window_size

        # A single 1x1 conv produces queries, keys and values (3*dim channels).
        self.qkv = Conv(dim, 3*dim, kernel_size=1, bias=qkv_bias)
        self.local1 = ConvBN(dim, dim, kernel_size=3)
        self.local2 = ConvBN(dim, dim, kernel_size=1)
        self.proj = SeparableConvBN(dim, dim, kernel_size=window_size)

        # Average pooling along the height (attn_x) and width (attn_y) axis.
        self.attn_x = nn.AvgPool2d(kernel_size=(window_size, 1), stride=1,  padding=(window_size//2 - 1, 0))
        self.attn_y = nn.AvgPool2d(kernel_size=(1, window_size), stride=1, padding=(0, window_size//2 - 1))

        self.relative_pos_embedding = relative_pos_embedding

        if self.relative_pos_embedding:
            # define a parameter table of relative position bias
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.ws)
            coords_w = torch.arange(self.ws)
            # NOTE(review): meshgrid is called without an explicit `indexing` argument;
            # newer PyTorch versions warn about this -- confirm the installed version.
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.ws - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.ws - 1
            # Scale the row offset so that (row, col) pairs map to unique flat indices.
            relative_coords[:, :, 0] *= 2 * self.ws - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            # Buffer (not a parameter): moves with the module but is not trained.
            self.register_buffer("relative_position_index", relative_position_index)

            trunc_normal_(self.relative_position_bias_table, std=.02)

    def pad(self, x, ps):
        """Reflect-pad the right and bottom of ``x`` so H and W become multiples of ``ps``."""
        _, _, H, W = x.size()
        # print(f"W: {W}, H: {H}, ps: {ps}, W % ps: {W % ps}, H % ps: {H % ps}")
        if W % ps != 0:
            x = F.pad(x, (0, ps - W % ps), mode='reflect')
        if H % ps != 0:
            x = F.pad(x, (0, 0, 0, ps - H % ps), mode='reflect')
        return x

    def pad_out(self, x):
        """Reflect-pad one column on the right and one row at the bottom."""
        x = F.pad(x, pad=(0, 1, 0, 1), mode='reflect')
        return x

    def forward(self, x):
        """Return the attended features with the same (N, C, H, W) shape as ``x``."""
        # x = x.permute(0,3,1,2)
        B, C, H, W = x.shape

        # Convolutional local branch, computed on the unpadded input.
        local = self.local2(x) + self.local1(x)

        x = self.pad(x, self.ws)
        B, C, Hp, Wp = x.shape
        qkv = self.qkv(x)

        # Split channels into (q|k|v, head, head_dim) and space into windows: every
        # window becomes a batch entry holding ws*ws tokens.
        q, k, v = rearrange(qkv, 'b (qkv h d) (hh ws1) (ww ws2) -> qkv (b hh ww) h (ws1 ws2) d', h=self.num_heads,
                            d=C//self.num_heads, hh=Hp//self.ws, ww=Wp//self.ws, qkv=3, ws1=self.ws, ws2=self.ws)

        dots = (q @ k.transpose(-2, -1)) * self.scale

        if self.relative_pos_embedding:
            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.ws * self.ws, self.ws * self.ws, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            dots += relative_position_bias.unsqueeze(0)

        attn = dots.softmax(dim=-1)
        attn = attn @ v

        # Merge the windows back into a (B, C, Hp, Wp) feature map.
        attn = rearrange(attn, '(b hh ww) h (ws1 ws2) d -> b (h d) (hh ws1) (ww ws2)', h=self.num_heads,
                         d=C//self.num_heads, hh=Hp//self.ws, ww=Wp//self.ws, ws1=self.ws, ws2=self.ws)


        # Drop the rows/columns added by self.pad.
        attn = attn[:, :, :H, :W]

        # One extra reflected row (resp. column) plus the pooling padding keeps the
        # pooled maps at H x W for an even window size.
        out = self.attn_x(F.pad(attn, pad=(0, 0, 0, 1), mode='reflect')) + \
              self.attn_y(F.pad(attn, pad=(0, 1, 0, 0), mode='reflect'))

        out = out + local
        # Same trick for the even-sized projection kernel: pad by one, project, crop.
        out = self.pad_out(out)
        out = self.proj(out)
        # print(out.size())
        out = out[:, :, :H, :W]

        return out


class Block(nn.Module):
    """Pre-norm transformer-style block: attention and MLP, each in a residual branch.

    ``drop_path`` > 0 wraps both residual branches in DropPath; otherwise an
    identity is used. ``attn_drop`` is accepted but not used by this block.
    """

    def __init__(self, dim=256, num_heads=16,  mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d, window_size=8):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = GlobalLocalAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, window_size=window_size)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        # The hidden width of the MLP is dim * mlp_ratio.
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), out_features=dim,
                       act_layer=act_layer, drop=drop)
        self.norm2 = norm_layer(dim)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual branches (shape unchanged)."""
        attended = x + self.drop_path(self.attn(self.norm1(x)))
        return attended + self.drop_path(self.mlp(self.norm2(attended)))


class WF(nn.Module):
    """Weighted fusion of an upsampled decoder feature with an encoder skip feature.

    Two learnable scalars are passed through ReLU and normalised to sum to
    (almost) one; they weight the 1x1-projected skip feature and the 2x bilinearly
    upsampled decoder feature before a ConvBNReLU.
    """

    def __init__(self, in_channels=128, decode_channels=128, eps=1e-8):
        super().__init__()
        self.pre_conv = Conv(in_channels, decode_channels, kernel_size=1)

        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = eps
        self.post_conv = ConvBNReLU(decode_channels, decode_channels, kernel_size=3)

    def forward(self, x, res):
        """Fuse decoder feature ``x`` (upsampled 2x) with skip feature ``res``."""
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        # ReLU keeps the fusion weights non-negative; eps guards the division.
        positive = F.relu(self.weights)
        normalised = positive / (positive.sum(dim=0) + self.eps)
        fused = normalised[0] * self.pre_conv(res) + normalised[1] * upsampled
        return self.post_conv(fused)


class FeatureRefinementHead(nn.Module):
    """Fuse the last decoder feature with the shallowest skip feature and refine it.

    After the same weighted fusion as ``WF``, the feature is re-weighted by a
    per-pixel gate (``pa``) and a per-channel gate (``ca``); the two gated
    features are summed, projected and added to a 1x1 shortcut before ReLU6.
    """

    def __init__(self, in_channels=64, decode_channels=64):
        super().__init__()
        self.pre_conv = Conv(in_channels, decode_channels, kernel_size=1)

        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = 1e-8
        self.post_conv = ConvBNReLU(decode_channels, decode_channels, kernel_size=3)

        # Per-pixel gate: depthwise 3x3 conv followed by a sigmoid.
        spatial_gate = nn.Conv2d(decode_channels, decode_channels, kernel_size=3, padding=1, groups=decode_channels)
        self.pa = nn.Sequential(spatial_gate, nn.Sigmoid())
        # Per-channel gate: global average pool, 16x bottleneck, sigmoid.
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Conv(decode_channels, decode_channels//16, kernel_size=1),
            nn.ReLU6(),
            Conv(decode_channels//16, decode_channels, kernel_size=1),
            nn.Sigmoid(),
        )

        self.shortcut = ConvBN(decode_channels, decode_channels, kernel_size=1)
        self.proj = SeparableConvBN(decode_channels, decode_channels, kernel_size=3)
        self.act = nn.ReLU6()

    def forward(self, x, res):
        """Fuse decoder feature ``x`` (upsampled 2x) with skip feature ``res`` and refine it."""
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        positive = F.relu(self.weights)
        normalised = positive / (positive.sum(dim=0) + self.eps)
        fused = self.post_conv(normalised[0] * self.pre_conv(res) + normalised[1] * upsampled)

        residual = self.shortcut(fused)
        spatially_gated = self.pa(fused) * fused
        channel_gated = self.ca(fused) * fused
        refined = self.proj(spatially_gated + channel_gated) + residual
        return self.act(refined)


class AuxHead(nn.Module):
    """Auxiliary prediction head: ConvBNReLU, dropout, a 1x1 classifier and a resize."""

    def __init__(self, in_channels=64, num_classes=1):
        super().__init__()
        self.conv = ConvBNReLU(in_channels, in_channels)
        self.drop = nn.Dropout(0.1)
        self.conv_out = Conv(in_channels, num_classes, kernel_size=1)

    def forward(self, x, h, w):
        """Return logits for ``x`` bilinearly resized to ``(h, w)``."""
        logits = self.conv_out(self.drop(self.conv(x)))
        return F.interpolate(logits, size=(h, w), mode='bilinear', align_corners=False)


class Decoder(nn.Module):
    """Decoder that turns four encoder feature maps into segmentation logits.

    The deepest feature (``res4``) is projected to ``decode_channels`` and passed
    through an attention block; each following stage fuses the upsampled result
    with the next shallower encoder feature (``WF`` / ``FeatureRefinementHead``).
    The segmentation head output is resized to the requested ``(h, w)``.

    In training mode ``forward`` additionally returns auxiliary logits computed
    from the sum of the three intermediate decoder features.
    """
    def __init__(self,
                 encoder_channels=(64, 128, 256, 512),
                 decode_channels=64,
                 dropout=0.1,
                 window_size=8,
                 num_classes=1):
        super(Decoder, self).__init__()

        self.pre_conv = ConvBN(encoder_channels[-1], decode_channels, kernel_size=1)
        self.b4 = Block(dim=decode_channels, num_heads=8, window_size=window_size)

        self.b3 = Block(dim=decode_channels, num_heads=8, window_size=window_size)
        self.p3 = WF(encoder_channels[-2], decode_channels)

        self.b2 = Block(dim=decode_channels, num_heads=8, window_size=window_size)
        self.p2 = WF(encoder_channels[-3], decode_channels)

        # NOTE(review): this is evaluated once, at construction time. A freshly
        # constructed nn.Module is presumably in training mode, so the auxiliary
        # layers are always created -- confirm; later eval() calls do not remove them.
        if self.training:
            self.up4 = nn.UpsamplingBilinear2d(scale_factor=4)
            self.up3 = nn.UpsamplingBilinear2d(scale_factor=2)
            self.aux_head = AuxHead(decode_channels, num_classes)

        self.p1 = FeatureRefinementHead(encoder_channels[-4], decode_channels)

        self.segmentation_head = nn.Sequential(ConvBNReLU(decode_channels, decode_channels),
                                               nn.Dropout2d(p=dropout, inplace=True),
                                               Conv(decode_channels, num_classes, kernel_size=1))
        self.init_weight()

    def forward(self, res1, res2, res3, res4, h, w):
        """Decode the encoder features ``res1`` (shallowest) .. ``res4`` (deepest).

        Returns ``(logits, aux_logits)`` in training mode and ``logits`` otherwise;
        both are resized to ``(h, w)``.
        """
        if self.training:
            x = self.b4(self.pre_conv(res4))
            # Keep an upsampled copy of each stage for the auxiliary head.
            h4 = self.up4(x)

            x = self.p3(x, res3)
            x = self.b3(x)
            h3 = self.up3(x)

            x = self.p2(x, res2)
            x = self.b2(x)
            h2 = x
            x = self.p1(x, res1)
            x = self.segmentation_head(x)
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)

            # h4 (x4) and h3 (x2) were upsampled to the resolution of h2 before summing.
            ah = h4 + h3 + h2
            ah = self.aux_head(ah, h, w)

            return x, ah
        else:
            # Same main path as above, without the auxiliary branch.
            x = self.b4(self.pre_conv(res4))
            x = self.p3(x, res3)
            x = self.b3(x)

            x = self.p2(x, res2)
            x = self.b2(x)

            x = self.p1(x, res1)

            x = self.segmentation_head(x)
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)

            return x

    def init_weight(self):
        """Kaiming-initialise the direct child modules that are ``nn.Conv2d``.

        NOTE(review): ``self.children()`` only yields direct children, and every
        child assigned in ``__init__`` is a container or custom module rather than
        a bare ``nn.Conv2d``, so this loop appears to initialise nothing. Using
        ``self.modules()`` would reach the nested convolutions but would change
        the training behaviour -- confirm the intent before changing it.
        """
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

class UNetFormer(nn.Module):
    """Segmentation network: a timm feature-extraction backbone plus ``Decoder``.

    The backbone is created with ``features_only=True`` and returns four feature
    maps (``out_indices=(1, 2, 3, 4)``); their channel counts configure the decoder.
    """

    def __init__(self,
                 decode_channels=64,
                 dropout=0.1,
                 backbone_name='swsl_resnet18',
                 pretrained=True,
                 window_size=8,
                 num_classes=1
                 ):
        super().__init__()

        self.backbone = timm.create_model(
            backbone_name,
            features_only=True,
            output_stride=32,
            out_indices=(1, 2, 3, 4),
            pretrained=pretrained,
        )
        self.decoder = Decoder(self.backbone.feature_info.channels(), decode_channels,
                               dropout, window_size, num_classes)

    def forward(self, x):
        """Return logits at the input resolution.

        In training mode a tuple ``(logits, aux_logits)`` is returned, otherwise
        only ``logits``.
        """
        height, width = x.size()[-2:]
        res1, res2, res3, res4 = self.backbone(x)
        decoded = self.decoder(res1, res2, res3, res4, height, width)
        if not self.training:
            return decoded
        # In training mode the decoder yields the main and the auxiliary logits.
        main_logits, aux_logits = decoded
        return main_logits, aux_logits

In [None]:
import torch

# Load a previously saved segmentation model onto `device`; the cells below call it
# directly as a model, so the file presumably holds a whole pickled module -- TODO confirm.
# NOTE(review): torch.load unpickles arbitrary Python objects; only load checkpoints
# from a trusted source. Newer PyTorch releases may also need an explicit
# weights_only=False for a whole-module pickle -- confirm against the installed version.
seg_model = torch.load("./model(unetf1).pth", map_location=device)

In [None]:
from torchsummary import summary

# Layer-by-layer summary of the loaded model for a 3x1024x1024 input.
summary(seg_model, input_size=(3, 1024, 1024))

### 2) Second Model - Gan (upsampling)

#### 실험

In [None]:
from IPython.display import Image, display

In [None]:
# Display the first test image. Note that this rebinds `img_path` to a test-set path.
test_data = pd.read_csv('./test.csv')
img_path = (test_data['img_path'])[0]
display(Image(filename = img_path))

In [None]:
import PIL

In [None]:
def tensor_to_image(tensor):
    """Convert an array-like of pixel intensities into a PIL image.

    The input is multiplied by 255 and cast to ``uint8``. A leading batch axis
    of length one (input with more than three dimensions) is removed.

    NOTE(review): the scaling suggests the input is expected to hold floats in
    [0, 1]; data that is already in 0..255 would overflow -- confirm with callers.
    """
    pixels = np.array(tensor*255, dtype=np.uint8)
    if np.ndim(pixels) <= 3:
        return PIL.Image.fromarray(pixels)
    # Only a batch holding exactly one image can be converted.
    assert pixels.shape[0] == 1
    return PIL.Image.fromarray(pixels[0])

In [None]:
# Preview the first test image after a plain resize to 1024x1024.
tmp_transform = A.Compose([A.Resize(1024, 1024)])

test_image = cv2.imread(img_path)
test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
result = tmp_transform(image = test_image)['image']

# `result` is already a uint8 image in 0..255. tensor_to_image() multiplies by 255,
# which would overflow uint8 and scramble the picture, so convert it directly.
PIL.Image.fromarray(result)


#### 모델 정의

In [None]:
import cv2
from cv2 import dnn_superres

In [None]:
# Path of the image used for the super-resolution experiment below.
img_path

In [None]:
# Create an SR object
# Create an SR object
sr = dnn_superres.DnnSuperResImpl_create()

# Read image
image = cv2.imread(img_path)
if image is None:
    # cv2.imread returns None on failure; stop here with the offending path.
    raise FileNotFoundError(f"Could not read image: {img_path}")

# Read the desired model
path = "FSRCNN_x4.pb"
sr.readModel(path)

# Set the desired model and scale to get correct pre- and post-processing
sr.setModel("fsrcnn", 4)

# Upscale the image
result = sr.upsample(image)
# Report the sizes instead of dumping the whole pixel array into the cell output.
print(image.shape, '->', result.shape)

# Save the image
cv2.imwrite("./upscaled.png", result)

In [None]:
# Show the saved super-resolved image.
display(Image(filename = './upscaled.png'))

In [None]:
# Upscale the already upscaled image once more.
# NOTE: this cell is not idempotent -- every run enlarges `result` again.
result = sr.upsample(result)
print(result)

# Save the image (overwrites the previous ./upscaled.png)
cv2.imwrite("./upscaled.png", result)

In [None]:
# Resize-only transform that brings the upscaled image to 1024x1024.
tmp2_transform = A.Compose([
    A.Resize(1024, 1024),
])

In [None]:
# Resize the super-resolved image to 1024x1024.
result = tmp2_transform(image = result)['image']

# Save the image (overwrites the previous ./upscaled.png)
cv2.imwrite("./upscaled.png", result)

In [None]:
# Show the saved, resized super-resolved image.
display(Image(filename = './upscaled.png'))

In [None]:
# Show the saved image again (same file as in the previous cell).
display(Image(filename = './upscaled.png'))

### 3) Model 학습

In [None]:
# Initialise the model
model = UNetFormer().to(device)

# Binary segmentation: BCE on raw logits, optimised with Adam.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# training loop
# NOTE: the inference cells below use `seg_model` (loaded from disk), not the
# `model` trained here.
for epoch in range(8):  # train for 8 epochs
    model.train()
    epoch_loss = 0
    for images, masks in tqdm(train_dataloader):
        images = images.float().to(device)
        masks = masks.float().to(device)

        optimizer.zero_grad()
        # In training mode the model returns (main logits, auxiliary logits).
        outputs = model(images)
        # Only the main logits contribute to the loss; the auxiliary output is unused.
        # unsqueeze(1) adds the channel axis so the mask matches the logits.
        loss = criterion(outputs[0], masks.unsqueeze(1))
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean batch loss of the epoch.
    print(f'Epoch {epoch+1}, Loss: {epoch_loss/len(train_dataloader)}')

### 4) valid (성능 평가를 위한)

In [None]:
# Run the loaded segmentation model over the validation loader and collect one
# RLE string per image in `result`. The ground-truth masks yielded by the loader
# are not used here.
seg_model.eval()
result = []
with torch.no_grad():
    for images, mask in tqdm(valid_dataloader):
        images = images.float().to(device)

        outputs = seg_model(images)
        # Probabilities -> drop the channel axis -> binarise at the 0.35 threshold.
        masks = np.squeeze(torch.sigmoid(outputs).cpu().numpy(), axis=1)
        masks = (masks > 0.35).astype(np.uint8) # Threshold = 0.35

        for predicted in masks:
            mask_rle = rle_encode(predicted)
            # -1 marks an image without any predicted building pixel.
            result.append(mask_rle if mask_rle != '' else -1)

### 5) 예측

In [None]:
class SatelliteDatasetForTest(Dataset):
    """Dataset that super-resolves every image before applying the transform.

    Samples are described by a CSV file: positional column 1 holds the image
    path and, when ``infer`` is False, positional column 2 holds the
    RLE-encoded mask.

    Relies on the notebook-level names ``sr`` (object providing ``upsample``)
    and ``rle_decode`` being defined before items are requested.

    Args:
        csv_file: Path of the CSV file listing the samples.
        transform: Optional callable invoked as ``transform(image=...)`` or
            ``transform(image=..., mask=...)`` and returning a dict.
            Upsampling is only performed when a transform is supplied.
        infer: When True, ``__getitem__`` returns only the image.
    """

    def __init__(self, csv_file, transform=None, infer=False):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.infer = infer

    def __len__(self):
        """Number of rows in the CSV file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the image (inference mode) or an ``(image, mask)`` pair.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img_path = self.data.iloc[idx, 1]
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.infer:
            if self.transform:
                image = sr.upsample(image)
                image = self.transform(image=image)['image']
            return image

        mask_rle = self.data.iloc[idx, 2]
        # The RLE refers to the original resolution, so decode before upsampling.
        mask = rle_decode(mask_rle, (image.shape[0], image.shape[1]))

        if self.transform:
            image = sr.upsample(image)
            if mask.shape[:2] != image.shape[:2]:
                # Keep the mask aligned with the enlarged image; nearest-neighbour
                # interpolation preserves the binary labels.
                mask = cv2.resize(
                    mask,
                    (image.shape[1], image.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
            augmented = self.transform(image=image, mask=mask)

            image = augmented['image']
            mask = augmented['mask']

        return image, mask

In [None]:
# Test-time preprocessing: resize to 1024x1024, apply A.Normalize() with its
# default arguments, then ToTensorV2().
test_transform = A.Compose([
    A.Resize(1024, 1024),
    A.Normalize(),
    ToTensorV2()
])

In [None]:
# Inference dataset/loader over ./test.csv; shuffle=False keeps the file order
# so predictions can later be written back row by row.
# NOTE(review): this instantiates `SatelliteDataset`, while the cell above defines
# `SatelliteDatasetForTest` — confirm which class is intended here.
test_dataset = SatelliteDataset(csv_file='./test.csv', transform=test_transform, infer=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8)

In [None]:
# Test inference: run `seg_model` over `test_dataloader` and store one entry per
# image in `result` — its RLE string, or -1 when nothing was predicted.
seg_model.eval()
result = []
with torch.no_grad():
    for images in tqdm(test_dataloader):
        images = images.float().to(device)
        outputs = seg_model(images)

        # Sigmoid output -> numpy, drop axis 1, then binarize.
        masks = np.squeeze(torch.sigmoid(outputs).cpu().numpy(), axis=1)
        masks = (masks > 0.35).astype(np.uint8) # Threshold = 0.35

        for i in range(len(images)):
            mask_rle = rle_encode(masks[i])
            # An empty RLE string means no building pixel was predicted: store -1.
            result.append(mask_rle if mask_rle != '' else -1)

## 4) 성능 평가

### true_mask vs pred_mask 이미지 비교

In [None]:
import numpy as np
import pandas as pd
from typing import List, Union
from joblib import Parallel, delayed


def rle_decode(mask_rle: Union[str, int], shape=(224, 224)) -> np.ndarray:
    '''
    Decode a run-length-encoded mask into a binary array.

    mask_rle: run-length as string formatted (start length) with 1-based starts,
              or the integer -1 meaning "no mask"
    shape: (height,width) of array to return
    Returns a uint8 numpy array of that shape, 1 - mask, 0 - background
    '''
    if mask_rle == -1:
        # Same dtype as the regular path so callers always receive uint8.
        return np.zeros(shape, dtype=np.uint8)

    s = mask_rle.split()
    # Even tokens are run starts, odd tokens are run lengths.
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1  # convert 1-based starts to 0-based indices
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    # Runs reaching past the buffer are clipped by the slice assignment.
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape)

In [None]:
# Plain list of the 'img_path' values of `valid` (defined in an earlier cell).
lst_input_img_path = list(valid['img_path'])

In [None]:
# Inspect the second entry of `result` (an RLE string, or -1 for an empty prediction).
result[1]

In [None]:
# Inspect the 'img_path' entry of `test_data` (defined in an earlier cell).
test_data['img_path']

In [None]:
import matplotlib.pyplot as plt

def display(display_list, shape=(1024, 1024)):
    """Plot an input image side by side with up to two RLE-encoded masks.

    NOTE: this function reuses the name ``display``, which earlier cells call
    with an image object; once this cell has run, those calls resolve to this
    function instead.

    Args:
        display_list: ``[image_path, mask_rle, ...]``. The first entry is a
            path read with OpenCV; each following entry is passed to
            ``rle_decode``. Panels are titled 'Input Image', 'True Mask',
            'Predicted Mask' in that order.
        shape: ``(height, width)`` used to decode the masks.

    Raises:
        FileNotFoundError: If the input image cannot be read.
    """
    plt.figure(figsize=(15, 15))

    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        if i == 0:
            img = cv2.imread(display_list[i])
            if img is None:
                # cv2.imread returns None instead of raising on failure.
                raise FileNotFoundError(f"Could not read image: {display_list[i]}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            # Decode every mask entry; otherwise a third panel would silently
            # re-plot the previous panel's image.
            img = rle_decode(display_list[i], shape=shape)
        plt.imshow(img)
        plt.axis('off')
    plt.show()

# Show sample 9: its input image next to the corresponding entry of `result`.
# NOTE(review): `result` holds model predictions, yet display() titles the second
# panel 'True Mask' — confirm the intended contents of this list.
display_list = [test_data['img_path'][9], result[9]]
display(display_list)

### Dice Score

In [None]:
def dice_score(prediction: np.array, ground_truth: np.array, smooth=1e-7) -> float:
    '''
    Dice coefficient of two binary masks.

    `smooth` is added to both the numerator and the denominator, so two empty
    masks score 1.0 instead of dividing by zero.
    '''
    overlap = np.sum(prediction * ground_truth)
    numerator = 2.0 * overlap + smooth
    denominator = np.sum(prediction) + np.sum(ground_truth) + smooth
    return numerator / denominator


def calculate_dice_scores(ground_truth_df, prediction_df, img_shape=(224, 224)) -> float:
    '''
    Mean Dice score over a dataset.

    ground_truth_df / prediction_df: frames whose first column is the image id
        and whose second column is the RLE mask (or -1).
    img_shape: (height, width) used to decode every mask.
    Returns the mean of the per-image Dice scores, skipping images where both
    masks are empty; NaN when no image remains.
    '''
    # Keep only the prediction rows whose id also occurs in the ground truth.
    has_gt = prediction_df.iloc[:, 0].isin(ground_truth_df.iloc[:, 0])
    prediction_df = prediction_df[has_gt].reset_index(drop=True)

    # Extract the mask_rle columns.
    # NOTE(review): rows are paired by position, not by id — both frames are
    # assumed to list the same images in the same order; verify against caller.
    pred_mask_rle = prediction_df.iloc[:, 1]
    gt_mask_rle = ground_truth_df.iloc[:, 1]

    def calculate_dice(pred_rle, gt_rle):
        pred_mask = rle_decode(pred_rle, img_shape)
        gt_mask = rle_decode(gt_rle, img_shape)

        if np.sum(gt_mask) > 0 or np.sum(pred_mask) > 0:
            return dice_score(pred_mask, gt_mask)
        return None  # both masks empty: excluded from the average

    dice_scores = Parallel(n_jobs=-1)(
        delayed(calculate_dice)(pred_rle, gt_rle) for pred_rle, gt_rle in zip(pred_mask_rle, gt_mask_rle)
    )

    dice_scores = [score for score in dice_scores if score is not None]  # Exclude None values

    if not dice_scores:
        # Avoid numpy's "Mean of empty slice" warning; the result is NaN either way.
        return float('nan')
    return np.mean(dice_scores)

In [None]:
# `valid` without its 'img_path' column; only 'img_id' is read from it below.
# ground_truth_df = valid.drop('img_path', axis=1)
df = valid.drop('img_path', axis=1)

In [None]:
# Prediction table: the img_id values of `df` paired with the entries of `result`.
# valid_pred = {'img_id': ground_truth_df['img_id'], 'mask_rle': result}
valid_pred = {'img_id': df['img_id'], 'mask_rle': result}
prediction_df = pd.DataFrame(data = valid_pred)

In [None]:
# Ground-truth table: RLE-encode the second item returned by `validset` for each index.
# NOTE: `valid_pred` is reused here although it now holds ground-truth data.
lst_ground_truth_rle = [rle_encode((validset.__getitem__(i))[1]) for i in range(len(valid))]
valid_pred = {'img_id': df['img_id'], 'mask_rle': lst_ground_truth_rle}
ground_truth_df = pd.DataFrame(data = valid_pred)

In [None]:
# Shape check on one decoded ground-truth mask.
# NOTE(review): no `shape` is passed, so rle_decode falls back to its (224, 224)
# default, whereas the scoring call below uses (1024, 1024) — confirm.
rle_decode(ground_truth_df['mask_rle'][4458]).shape

In [None]:
# Mean Dice score of the predictions, decoding every mask as 1024x1024.
calculate_dice_scores(ground_truth_df, prediction_df, img_shape=(1024, 1024))

### Class 별 IOU

In [None]:
# Positions in `result` whose prediction is the -1 sentinel (no building predicted).
predNoBuildingIdx = [idx for idx, rle in enumerate(result) if rle == -1]

In [None]:
from sklearn.metrics import confusion_matrix

def generateConfusionMatrix(ground_truth_mask, pred_mask, shape=(1024, 1024)):
    """Pixel-wise 2x2 confusion matrix between two RLE-encoded masks.

    Args:
        ground_truth_mask: RLE string (or -1) of the reference mask.
        pred_mask: RLE string (or -1) of the predicted mask.
        shape: ``(height, width)`` used to decode both masks. Defaults to the
            1024x1024 size used by the rest of the evaluation instead of
            rle_decode's own 224x224 default, which would truncate the runs.

    Returns:
        2x2 array with ground-truth classes on the rows and predicted classes
        on the columns, in the order [background (0), building (1)].
    """
    # ravel() flattens without the quadratic cost of sum(list_of_lists, []).
    y_true = np.asarray(rle_decode(ground_truth_mask, shape), dtype=np.uint8).ravel()
    y_pred = np.asarray(rle_decode(pred_mask, shape), dtype=np.uint8).ravel()
    # labels=[0, 1] keeps the matrix 2x2 even when a class is absent from both masks.
    cMatrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    return cMatrix

def generateConfusionMatrixLst(lst_ground_truth_rle, lst_pred_rle):
    """Build one confusion matrix per image.

    Args:
        lst_ground_truth_rle: Ground-truth RLE masks, one per image.
        lst_pred_rle: Predicted RLE masks (or -1), in the same order.

    Returns:
        List of the matrices returned by ``generateConfusionMatrix``.
    """
    # Pair each ground truth with the prediction passed in as an argument,
    # not with the notebook-level `result`.
    lst_cMatrix = Parallel(n_jobs=1)(
        delayed(generateConfusionMatrix)(lst_ground_truth_rle[i], lst_pred_rle[i])
        for i in range(len(lst_ground_truth_rle))
    )
    return lst_cMatrix

In [None]:
# One confusion matrix per validation image (ground truth vs. `result`).
Lst_cMatrix = generateConfusionMatrixLst(lst_ground_truth_rle, result)

In [None]:
# Inspect the confusion matrix of the first image.
Lst_cMatrix[0]

In [None]:
def IoU(cMatrix):
    """Per-class intersection-over-union from a square confusion matrix.

    For class k the intersection is the diagonal entry and the union is
    (column k sum) + (row k sum) - (diagonal entry); the result is therefore
    the same whichever axis holds the ground truth.

    Args:
        cMatrix: Square confusion matrix (2x2 in this notebook).

    Returns:
        Float array with one IoU per class; NaN where the union is empty
        (the class occurs in neither mask).
    """
    cMatrix = np.asarray(cMatrix)
    diagonal = cMatrix.diagonal()
    intersection = diagonal.astype(np.float64)
    union = (cMatrix.sum(axis=0) + cMatrix.sum(axis=1) - diagonal).astype(np.float64)

    # Divide only where the union is non-empty; other entries stay NaN and no
    # divide-by-zero warning is emitted.
    iou = np.full(intersection.shape, np.nan)
    np.divide(intersection, union, out=iou, where=union > 0)
    return iou

# Average the per-image IoU values over all images.
def totalIoU(lst_cMatrix):
    """Mean per-class IoU over a list of confusion matrices.

    Images whose IoU for a class is NaN (class absent from both masks) are left
    out of that class's average instead of turning the whole mean into NaN.

    Args:
        lst_cMatrix: Confusion matrices, one per image.

    Returns:
        Float array with one mean IoU per class; NaN for a class that is
        undefined on every image. An empty input yields ``[nan, nan]``.
    """
    if len(lst_cMatrix) == 0:
        return np.full(2, np.nan)

    per_image = np.array([IoU(cMat) for cMat in lst_cMatrix], dtype='float64')
    defined = ~np.isnan(per_image)
    counts = defined.sum(axis=0)
    sums = np.where(defined, per_image, 0.0).sum(axis=0)

    # Divide only where at least one image defines the class.
    mean_iou = np.full(per_image.shape[1], np.nan)
    np.divide(sums, counts, out=mean_iou, where=counts > 0)
    return mean_iou

def eachIoU(lst_cMatrix):
    """Return the IoU() result of every confusion matrix, in input order."""
    return [IoU(cMat) for cMat in lst_cMatrix]

In [None]:
# Per-class IoU of the first image.
IoU(Lst_cMatrix[0])

In [None]:
# Per-class IoU averaged over all images (displayed only; stored in the next cell).
totalIoU(Lst_cMatrix)

In [None]:
# Keep the averaged per-class IoU for the report printed below.
totaliou = totalIoU(Lst_cMatrix)

In [None]:
def printClassScores(totaliou):
    """Print a small table of per-class IoU values and their average.

    Args:
        totaliou: Sequence of IoU values ordered as [background, building].
    """
    label = ['background', 'building']
    # nIoU is not computed here; the column shows a fixed placeholder. It is
    # bound before the loop so the summary line also works for empty input.
    niouStr = 'empty'
    print('classes          IoU      nIoU')
    print('--------------------------------')
    for i, iou in enumerate(totaliou):
        labelName = label[i]
        iouStr = f'{iou:>5.3f}'
        print('{:<14}: '.format(labelName) + iouStr + '    ' + niouStr)
    print('--------------------------------')
    # The average is taken over the two known classes.
    print(f'Score Average : {(np.sum(totaliou) / len(label)):>5.3f}' + '    ' + niouStr)

In [None]:
# Print the per-class IoU report.
printClassScores(totaliou)

# 제출 코드

In [None]:
# Load the sample submission and fill its 'mask_rle' column with the predictions.
# NOTE(review): `result` is rebuilt by both the validation and the test inference
# cells — make sure the test-set cell was the last one run before this assignment.
submit = pd.read_csv('./sample_submission.csv')
submit['mask_rle'] = result

In [None]:
# Write the submission file without the index column.
submit.to_csv('./submit(unetf1).csv', index=False)