# Unet
source: https://amaarora.github.io/2020/09/13/unet.html

<img src="https://i.imgur.com/LQORH9i.png" alt="drawing" width="500"/>


In [None]:
# Global settings shared by the cells below.
BATCH_SIZE = 32        # initial batch size (a later cell reassigns it before the DataLoaders are built)
NUM_LABELS = 1         # number of output channels requested from the UNet head
WIDTH = HEIGHT = 128   # size that images and masks are resized to

In [None]:
# Notebook shell magic: quietly install albumentations into the user site-packages.
!pip install -q --user albumentations

In [None]:
import cv2
import numpy as np
import torch # 1.9
import torch.nn as nn
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
import os
import torch.optim as optim
from utils import *

In [None]:
import warnings
# Silence every warning for the rest of the session; this keeps the notebook
# output short but also hides deprecation notices.
warnings.filterwarnings("ignore")

In [None]:
import os
import subprocess

# Download and unpack the capsule archive once. The whole step is skipped when
# ./data/capsule already exists. Each command is an argument list (no shell
# involved) and check=True makes a failed download or extraction raise instead
# of leaving a half-prepared data directory behind silently.
cmd = [
    ['wget', '-q',
     'https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420937454-1629951595/capsule.tar.xz',
     '-O', 'data/MVtech-capsule.tar.xz'],
    ['tar', '-Jxf', 'data/MVtech-capsule.tar.xz', '--overwrite', '--directory', './data'],
]
if not os.path.isdir('./data/capsule'):
    # exist_ok avoids the error `mkdir ./data` printed when the folder was already there.
    os.makedirs('./data', exist_ok=True)
    for i in cmd:
        subprocess.run(i, check=True)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu' # for debugging, running torch on the CPU is recommended
device  # notebook display of the selected device

## Chapter1 : UNet網路構建

### ConvBlock
- 加入Instance Norm.
- <img src="https://miro.medium.com/max/983/1*p84Hsn4-e60_nZPllkxGZQ.png" width="50%">

> 上圖為一整個batch的feature-map。輸入6張圖片，輸入6chs, 輸出也是6chs(C方向看進去是channel, N方向看進去是圖片)

In [None]:
# # 原始版本
# class convBlock(nn.Module):
#     def __init__(self, in_ch, out_ch):
#         super().__init__()
#         self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
#         self.relu  = nn.ReLU()
#         self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
    
#     def forward(self, x):
#         return self.relu(self.conv2(self.relu(self.conv1(x))))

In [None]:
## Convolution block with instance normalization
class convBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU and instance normalization.

    ``padding=1`` keeps the spatial size unchanged, so only the channel count
    changes (``in_ch`` -> ``out_ch``). One affine ``InstanceNorm2d`` layer is
    applied after both convolutions, so its scale/shift parameters are shared
    between the two stages.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.INorm = nn.InstanceNorm2d(out_ch, affine=True)

    def forward(self, x):
        """Apply conv -> ReLU -> instance norm twice and return the result."""
        for conv in (self.conv1, self.conv2):
            x = self.INorm(self.relu(conv(x)))
        return x

In [None]:
# Sanity check: a 1-channel input goes through the block; the displayed shape
# shows the 64 output channels with the spatial size unchanged.
block = convBlock(1, 64)
x = torch.randn(1, 1, WIDTH, HEIGHT)
block(x).shape

## Encoder(DownStream)
將影像進行編碼，過程中解析度會縮小(maxpooling、convolution)

In [None]:
class Encoder(nn.Module):
    """Contracting path: a stack of ``convBlock``s, each followed by 2x2 max pooling.

    ``chs`` lists the channel counts; block ``i`` maps ``chs[i]`` channels to
    ``chs[i+1]`` channels.
    """

    def __init__(self, chs=(3,32,64,128,256,512)):
        super().__init__()
        stages = [convBlock(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])]
        self.list_of_blocks = nn.ModuleList(stages)
        # Functional pooling: it has no parameters, so it is stored as a plain callable.
        self.pool = torch.max_pool2d

    def forward(self, x):
        """Return the output of every block (taken before pooling), shallowest first."""
        skips = []
        out = x
        for stage in self.list_of_blocks:
            out = stage(out)
            skips.append(out)
            # Halve the resolution before handing the map to the next block.
            out = self.pool(out, kernel_size=2)
        return skips

In [None]:
# Run the default encoder on a random 3-channel image and print the shape of
# every feature map it returns (one per block).
encoder = Encoder()
x = torch.randn(1, 3, WIDTH, HEIGHT)
features = encoder(x)
for f in features:
    print(f.shape)

## Decoder(UpStream)
將編碼還原成影像，過程中解析度會放大直到回復成輸入影像解析度(transposed Convolution)。
- 將編碼還原成影像是因為影像分割是pixel-wise的精度進行預測，解析度被還原後，就可以知道指定pixel位置所對應的類別
- 類別資訊通常用feature-map的channels(chs)去劃分，一個channel代表一個class
- 有許多UNet模型架構(例如原始U-Net論文)會有輸入572x572，但輸出只有388x388的情況，是因為他們沒有對卷積過程做padding，導致解析度自然下降。最後只要把mask resize到388x388就能繼續計算loss。

### Transposed Conv and UpsampleConv
<img src="https://i.imgur.com/eIIJxre.png" alt="drawing" width="300"/>
<img src="https://i.imgur.com/uLo7icF.png" alt="drawing" width="300"/>

Transposed Conv 
- 透過上面的操作做轉置卷積，feature-map上的數值會作為常數與kernel相乘

UpsampleConv
- 先做上採樣(Upsample/ Unpooling)
- 然後作卷積(padding = same)
<!-- #### 替代方案 UpSampling(Unpooling)+Convolution -->


In [None]:
# With kernel_size=2, stride=2 and output_padding=0, ConvTranspose2d doubles the
# resolution: a 28x28 map becomes 56x56.
x = torch.randn(1, 3, 28, 28)
x = nn.ConvTranspose2d(3, 30, kernel_size=2, stride=2, output_padding=0)(x)
x.shape

In [None]:
class upSampleConvs(nn.Module):
    """Upsample-then-convolve alternative to a transposed convolution.

    Pipeline: 2x upsampling -> ReLU -> 2x2 convolution with 'same' padding
    (``in_ch`` -> ``out_ch``) -> ReLU -> instance normalization. The output has
    twice the spatial size of the input.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.relu = nn.ReLU()
        self.upSample = nn.Upsample(scale_factor=2)
        self.conv2 = nn.Conv2d(in_ch, out_ch, kernel_size=2, padding='same')
        self.INorm = nn.InstanceNorm2d(out_ch)

    def forward(self, x):
        """Return the upsampled, convolved and normalized feature map."""
        enlarged = self.relu(self.upSample(x))
        activated = self.relu(self.conv2(enlarged))
        return self.INorm(activated)

### decoder(上採樣) module

In [None]:
class Decoder(nn.Module):
    """Expanding path: upsample, concatenate the matching encoder map, then convolve.

    Stage ``i`` upsamples from ``chs[i]`` to ``chs[i+1]`` channels, concatenates
    ``encoder_features[i]`` along the channel axis and feeds the result to a
    ``convBlock(chs[i], chs[i+1])``. The concatenated tensor must therefore have
    ``chs[i]`` channels.
    """

    def __init__(self, chs=(512, 256, 128, 64, 32)):
        super().__init__()
        self.chs = chs
        pairs = list(zip(chs[:-1], chs[1:]))
        # Upsample followed by convolution (used instead of nn.ConvTranspose2d).
        self.upconvs = nn.ModuleList([upSampleConvs(c_in, c_out) for c_in, c_out in pairs])
        self.list_of_convBlock = nn.ModuleList([convBlock(c_in, c_out) for c_in, c_out in pairs])

    def forward(self, x, encoder_features):
        """Decode ``x`` using ``encoder_features`` ordered deepest-first."""
        stages = zip(self.upconvs, self.list_of_convBlock)
        for idx, (upsample, fuse) in enumerate(stages):
            skip = encoder_features[idx]
            x = upsample(x)
            x = fuse(torch.cat([x, skip], dim=1))
        return x

In [None]:
# Print the shape of each encoder feature map computed in the encoder demo cell.
for i in features:
    print(i.shape)

In [None]:
# Decoder smoke test. A random 512-channel map at 1/16 resolution stands in for
# the deepest encoder output; the remaining encoder maps are passed deepest-first
# (features reversed, with the deepest one dropped).
decoder = Decoder()
decoder
x = torch.randn(1, 512, WIDTH//16, HEIGHT//16)
decoder(x, features[::-1][1:]).shape

## Unet構建
結合encoder和decoder組成Unet。
- 在輸出層如果用softmax做多元分類問題預測的話，類別數量要+1(num_classes+background)

In [None]:
class UNet(nn.Module):
    """UNet assembled from an ``Encoder``, a ``Decoder`` and a 1x1 output convolution.

    Args:
        enc_chs: channel counts for the encoder blocks.
        dec_chs: channel counts for the decoder stages.
        num_class: number of output channels produced by the head.
        retain_dim: when true, the output is resized to ``out_sz``.
        out_sz: target size used when ``retain_dim`` is true.
    """

    def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(256,256)):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim  = retain_dim
        # Stored so forward() can use it; it used to be read as an undefined name.
        self.out_sz      = out_sz

    def forward(self, x):
        """Return the raw (un-activated) prediction map for ``x``."""
        # Deepest feature map first: it seeds the decoder, the rest are the
        # skip connections that get concatenated stage by stage.
        enc_ftrs = self.encoder(x)[::-1]
        out      = self.decoder(enc_ftrs[0], enc_ftrs[1:])
        out      = self.head(out)
        if self.retain_dim:
            out = nn.functional.interpolate(out, self.out_sz)
        return out

In [None]:
# Smoke test: build a single-output UNet on the selected device and display the
# output shape for one random 3-channel input.
unet = UNet(num_class=1)
unet.to(device)
x    = torch.randn(1, 3, WIDTH, HEIGHT).to(device)
unet(x).shape

## Chapter2 : Test on dataset

In [None]:
from utils import show_image_mask

In [None]:
# Locate the "scratch" test images and their ground-truth masks under ./data/<item>.
item = 'capsule'
path = os.getcwd()
img_dir = f'{path}/data/{item}/test/scratch/'
print(len(os.listdir(img_dir)))  # number of files in the image folder
anno_dir = f'{path}/data/{item}/ground_truth/scratch/'

# File names without extension, one per entry in img_dir.
defective_number = [i.split('.')[0] for i in os.listdir(img_dir)]

### 取得image list
輸出: data_dic (字典)
- key: X_train, X_test, y_train, y_test

In [None]:
from sklearn.model_selection import train_test_split
# Build mask paths from image file names: <anno_dir><stem>_mask.png
mask_dir = lambda anno_dir, X_lis:[anno_dir+i.split('.')[0]+'_mask.png' for i in X_lis]

imgs_path_list = sorted(os.listdir(img_dir))

# NOTE(review): this first value is immediately overwritten by the next line,
# so the split always uses train_size=0.1.
size = 1/len(imgs_path_list)
size = 0.1
# NOTE(review): no random_state is passed, so the split presumably changes between runs.
train, test = train_test_split(imgs_path_list, train_size = size)

key = 'X_train, X_test, y_train, y_test'.split(', ')

# lis is filled in the same order as `key`: image paths for train/test first,
# then mask paths for train/test.
lis = []
for number in [train, test]:
    lis.append([img_dir+i for i in number]) # X
for number in [train, test]:
    lis.append(mask_dir(anno_dir, number)) # y
data_dic = dict(zip(key, lis))
data_dic['X_train']

### Build torch dataset

In [None]:
#  https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

class CustomImageDataset(Dataset):
    """Dataset of (image, mask) pairs read from disk with OpenCV.

    Masks are looked up by file name rather than by position, so the image list
    may contain repeated entries while the mask list does not.

    Args:
        imgs_path_list: list of image file paths.
        anno_path_list: list of mask file paths.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            and returning a mapping with ``'image'`` and ``'mask'`` keys.
        target_transform: accepted for signature compatibility; not applied.

    Raises:
        ValueError: if either path argument is not a list.
    """

    def __init__(self, imgs_path_list, anno_path_list, transform=None, target_transform=None):
        if not isinstance(imgs_path_list, list):
            raise ValueError('Need Input a list')
        if not isinstance(anno_path_list, list):
            raise ValueError('Need Input a list')
        self.imgs_path_list = imgs_path_list
        self.anno_path_list = anno_path_list
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.imgs_path_list)

    def _find_mask_path(self, img_path):
        """Return the mask path that belongs to ``img_path``."""
        file_name = img_path.split('/')[-1].split('.')[0]
        wanted = file_name + '_mask'
        # Prefer an exact "<stem>_mask" file name: a plain substring test on the
        # whole path can match a directory component or a longer file name.
        for candidate in self.anno_path_list:
            if candidate.split('/')[-1].split('.')[0] == wanted:
                return candidate
        # Fall back to the substring match for other naming schemes.
        for candidate in self.anno_path_list:
            if file_name in candidate:
                return candidate
        raise IndexError(f'no mask found for image {img_path!r}')

    def __getitem__(self, idx):
        """Return the RGB image and grayscale mask at ``idx``, transformed if configured."""
        img_path = self.imgs_path_list[idx]
        mask_path = self._find_mask_path(img_path)

        # cv2.imread returns None instead of raising when a file cannot be read.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f'cannot read image: {img_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'cannot read mask: {mask_path}')

        if self.transform is not None:
            # Use the transform given to this dataset, not a global of the same name.
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        return image, mask
    

#### 補充: 如何從dataset抽image, mask出來

In [None]:
# a = CustomImageDataset(data_dic['X_train'], data_dic['y_train'])
# for idx in range(a.__len__):
#     print(idx)

#     test_image = a.__getitem__(0)[0]
#     test_mask = a.__getitem__(0)[1]
#     print(test_image.shape)
#     print(test_mask.shape)
#     show_image_mask(test_image, test_mask)

#### 補充: 取得指定影像的瑕疵位置(x_min, y_min, x_max, y_max)

In [None]:
# from utils import find_objects_contours, center_to_4point
# arr = find_objects_contours(test_mask)
# print(arr)
# points = center_to_4point(test_mask[:, :], arr, WIDTH, pad=50)
# print('4 POINTS , WIDTH + pad', points, points[2]-points[0])
# x_min, y_min, x_max, y_max = points

# show_image_mask(test_mask, test_mask[y_min:y_max, x_min:x_max])

In [None]:
# Load every training image and its mask into memory, in random order, as
# (image, mask) tuples. Paths are built from img_dir / anno_dir so the cell does
# not depend on one machine's absolute directory layout.
file = data_dic['X_train']
file = np.random.choice(file,size=len(file), replace=False)
file = [i.split('/')[-1].split('.')[0] for i in file]
print(file)
img_mask_list = []
for i in file:
    print(i)
    image = cv2.imread(f'{img_dir}{i}.png')
    if image is None:
        # cv2.imread signals failure by returning None.
        raise FileNotFoundError(f'cannot read image: {img_dir}{i}.png')
    print(image.shape)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = cv2.imread(f'{anno_dir}{i}_mask.png', cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f'cannot read mask: {anno_dir}{i}_mask.png')

    img_mask_list.append((image, mask))
len(img_mask_list)

In [None]:
# https://albumentations.ai/docs/getting_started/mask_augmentation/
from utils import mask_CutMix
# Augmentation pipeline applied jointly to an image and its mask.
transform = A.Compose([
#     mask_CutMix(img_mask_list, p=0.5),
#     A.Crop(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, p=0.5),

    A.CenterCrop(300, 900, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=[-0.05, 0.05], p=0.2),
    A.Rotate((-30, 30), interpolation=0),

    A.ToFloat(always_apply=True),
    # Resize takes (height, width); keywords keep the order right if the two
    # constants ever differ.
    A.Resize(height=HEIGHT, width=WIDTH),
])

# target_transform = A.Compose([
#     A.ToFloat(always_apply=True),
#     A.Resize(WIDTH, HEIGHT),
# ])

In [None]:
# Repeat the training paths cyclically until there are BATCH_SIZE entries.
lis = data_dic['X_train']
lis = [lis[i%len(lis)] for i in range(BATCH_SIZE)]

In [None]:
# Display the current list of training image paths.
data_dic['X_train']

In [None]:
# Recommended: process at most 16 samples of size (128,128) at a time
# (Total = BATCH_SIZE*MULTIPLE_BATCH).
BATCH_SIZE = 16
MULTIPLE_BATCH = 1 # [MULTIPLE_BATCH>1]

# With drop_last=True a training set smaller than one batch would yield no
# batches at all, so the paths are repeated cyclically up to BATCH_SIZE; the
# random transform in CustomImageDataset then makes the repeated samples differ.
if len(data_dic['X_train']) < BATCH_SIZE:
    lis = data_dic['X_train']
    lis = [lis[i%len(lis)] for i in range(BATCH_SIZE)]
    data_dic['X_train'] = lis

dataset_train = CustomImageDataset(data_dic['X_train'], data_dic['y_train'], transform=transform)
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, pin_memory=True)
# dataloader_train = iter(dataloader_train)

# NOTE(review): the test set goes through the same random augmentation pipeline
# as the training set; confirm that this is intended.
dataset_test = CustomImageDataset(data_dic['X_test'], data_dic['y_test'], transform=transform)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False, pin_memory=True)

In [None]:
'''test transform on dataloader_train'''
# Iterate every training batch and show each (image, mask) pair after augmentation.
for data in dataloader_train:
#     print(data.shape)
#     i = 0
    # data is an (images, masks) pair; zip(*data) walks the samples of the batch.
    for x, y in zip(*data):
#         print(i)
#         i += 1
        print(x.shape)
        print(y.shape)
        show_image_mask(x, y)

In [None]:
import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

# The graph builder needs an example input to trace the model. `batch.text` is
# not defined anywhere in this notebook, so a dummy image batch with the same
# shape as the earlier UNet smoke test is used instead.
dummy_input = torch.zeros([1, 3, WIDTH, HEIGHT]).to(device)
graph = hl.build_graph(unet, dummy_input, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')

In [None]:
# Model to be trained, moved to the selected device.
model = UNet(num_class=NUM_LABELS)
model.to(device)

# Binary cross-entropy on raw logits; pos_weight=1000 up-weights positive pixels.
criterion = nn.BCEWithLogitsLoss(pos_weight = torch.Tensor([1000]).to(device))
optimizer = optim.Adam(model.parameters(), lr = 1e-4)

# (optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])

In [None]:
# model = torch.load('save.pt')

In [None]:
class patience():
    """Minimal early-stopping counter.

    ``record()`` is called once per step and ``reset()`` whenever the monitored
    value improves. ``early_stop`` becomes True on the call where
    ``record_value`` has already reached ``patience``. A falsy ``patience``
    (``None`` or ``0``) disables early stopping: the first ``record()`` prints a
    notice and sets ``patience`` to ``-1``, after which ``record()`` does nothing.
    """

    def __init__(self, patience=None):
        self.patience = patience
        self.record_value = 0
        self.early_stop = False
        self.ini = True

    def record(self):
        """Count one step without improvement and raise the stop flag if due."""
        if self.patience == -1:
            # Early stopping has been disabled.
            return
        if self.ini and not self.patience:
            print('No early stop')
            self.ini = False
            self.patience = -1
            return
        if self.record_value >= self.patience:
            print('early stop')
            self.early_stop = True
        self.record_value += 1

    def reset(self):
        """Restart the count (the stop flag itself is left untouched)."""
        self.record_value = 0

In [None]:
best_loss = None
monitor = patience(None)

for epoch in range(500):  # loop over the dataset multiple times
    torch.cuda.empty_cache()

    if monitor.early_stop:
        break

    # Gather every batch of the epoch (MULTIPLE_BATCH passes over the loader)
    # and concatenate once: the model takes a single optimizer step per epoch
    # on the combined tensor. Collecting into lists avoids re-copying a growing
    # tensor on every batch.
    batch_inputs = []
    batch_masks = []
    for multiple_batch in range(MULTIPLE_BATCH):
        for inputs, mask in dataloader_train:  # drop_last is set on the loader
            batch_inputs.append(inputs)
            batch_masks.append(mask)
    if not batch_inputs:
        raise RuntimeError('dataloader_train produced no batches')
    multiple_inputs = torch.cat(batch_inputs, dim=0)
    multiple_mask = torch.cat(batch_masks, dim=0)

    inputs = multiple_inputs.to(device).float()
    mask = multiple_mask.to(device)

    # Move channels first (NHWC -> NCHW) and add a channel axis to the mask.
    inputs = inputs.permute(0,3,1,2)
    mask = mask.unsqueeze(1)
    # Masks are stored as 0/255 grayscale; scale to [0, 1] for the BCE target.
    mask = mask.float()/255

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = model(inputs)
    loss = criterion(outputs, mask)
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss = loss.item()/len(multiple_inputs)
    print(f'epochs: {epoch}, average sample loss: {running_loss:.8f}')

    monitor.record()

    # `is None` rather than truthiness, so a loss of exactly 0.0 is not mistaken
    # for "no best loss recorded yet".
    if best_loss is None:
        best_loss = running_loss
    elif running_loss < 0.9*best_loss:
        # Save only when the loss improved by more than 10% over the best so far.
        best_loss = running_loss
        torch.save(model, 'best_save.pt')
        monitor.reset()
        print('save torch model [best_save.pt]')

torch.save(model, 'last_save.pt')
        
# print('Finished Training')

In [None]:
# Iterator over the training loader, used below to pull one batch for inspection.
test_set = iter(dataloader_train)

In [None]:
# Pull one batch. The built-in next() works on every iterator, whereas the
# `.next()` method is a Python-2 style alias that DataLoader iterators no longer
# provide in recent PyTorch releases.
image, mask = next(test_set)
print(image.shape)
mask.max()

In [None]:
# Run the model on the fetched batch, report the loss and plot each
# image / mask / prediction triple.
image = image.permute(0,3,1,2).to(device)  # NHWC -> NCHW
mask = mask.unsqueeze(1).to(device)
mask = mask.float()
print(image.shape)
print(mask.shape)

# Inference only, so gradient tracking is switched off. The loss target is
# scaled to [0, 1] exactly as in the training loop; the unscaled mask is kept
# for plotting.
with torch.no_grad():
    outputs = model(image)
    loss = criterion(outputs, mask / 255)
print(outputs.shape)
print(loss)

# Back to channels-last numpy arrays, one entry per sample in the batch.
a = image.permute(0,2,3,1).squeeze(0).cpu().numpy()
b = mask.permute(0,2,3,1).squeeze(0).squeeze(-1).cpu().numpy()
c = outputs.permute(0,2,3,1).squeeze(0).squeeze(-1).cpu().detach().numpy()
for x, m, p in zip(a, b, c):
    show_image_mask(x, m, p)