# 各種モジュールのimport

・このプログラムで使うモジュールです。  
・tqdmについては、notebook形式では`tqdm.notebook.tqdm`、通常のPythonファイルでは`tqdm.tqdm`を使います（下のコードで使っている`tqdm_notebook`は旧名で、実行時に非推奨の警告が出ます）。

In [1]:
import torch
from torch import nn
from torchsummary import summary
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

import os
import cv2
import csv
import glob
import random
import numpy as np
from PIL import Image
#from tqdm import tqdm
from tqdm import tqdm_notebook as tqdm
#import matplotlib.pyplot as plt
#%matplotlib inline

# UNetクラスの定義

・UNetの構造をBlock単位で細分化しています(UNetDownBlockクラス、UNetUpBlockクラス)。  
・Pythonのクラス定義の勉強も兼ねています。

In [2]:
# Encoder building block
class UNetDownBlock(nn.Module):
    """One encoder stage: (conv -> BN -> ReLU) x 2, with optional 2x downsampling.

    With ``down_mode="max_pooling"`` (the default) the first convolution is a
    3x3 / stride 1 / padding 1 conv and a 2x2 max-pool is applied at the end of
    ``forward``. With any other ``down_mode`` the first convolution uses the
    caller-supplied ``kernel_size`` / ``stride`` / ``padding`` (so a strided conv
    can do the downsampling) and no pooling layer is created.
    """

    def __init__(self, in_channels, out_channels, down_mode="max_pooling", kernel_size=None, stride=None, padding=None):
        super().__init__()

        self.down_mode = down_mode

        if down_mode == "max_pooling":
            # Resolution is halved after the convolutions, by pooling.
            self.mp = nn.MaxPool2d(2, 2)
            first_conv_args = dict(kernel_size=3, stride=1, padding=1)
        else:
            # Any resolution change comes from the first convolution itself.
            first_conv_args = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **first_conv_args)

        # Remaining layers, identical for both modes.
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        # A strided first conv (non-pooling mode) downsamples inside this loop.
        for layer in (self.conv1, self.bn1, self.relu1, self.conv2, self.bn2, self.relu2):
            x = layer(x)
        # Pooling mode downsamples only at the very end.
        return self.mp(x) if self.down_mode == "max_pooling" else x


    
# Decoder building block
class UNetUpBlock(nn.Module):
    """One decoder stage: 2x upsample, concatenate the encoder skip, then (conv -> BN -> ReLU) x 2.

    ``up_mode="interpolation"`` (default) uses bilinear upsampling followed by a
    1x1 conv that maps ``in_channels`` to ``out_channels``. Any other value uses a
    learned ``ConvTranspose2d`` (kernel 4, stride 2, padding 1), which doubles
    H and W exactly.

    The first 3x3 conv takes ``out_channels * 2`` inputs, so the skip tensor
    passed to ``forward`` is expected to have ``out_channels`` channels and the
    same spatial size as the upsampled tensor.
    """

    def __init__(self, in_channels, out_channels, up_mode = "interpolation"):
        super().__init__()

        if up_mode != "interpolation":
            # Learned 2x upsampling.
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        else:
            # Fixed bilinear 2x upsampling, then a 1x1 conv to set the channel count.
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            )

        # Layers applied after the skip connection is concatenated.
        self.conv1 = nn.Conv2d(in_channels=out_channels*2, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()

    def forward(self, x, x_fromE):
        # Upsample, then stack the encoder feature map along the channel axis.
        merged = torch.cat((self.up(x), x_fromE), 1)
        for stage in (self.conv1, self.bn1, self.relu1, self.conv2, self.bn2, self.relu2):
            merged = stage(merged)
        return merged
        
        

# UNet body
class UNet(nn.Module):
    """Three-level UNet assembled from UNetDownBlock / UNetUpBlock.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output classes (output channels). Required.
        filter_num: base filter count; deeper stages use 2x and 4x this value.
        up_mode: forwarded to every UNetUpBlock.
        down_mode: forwarded to down2..down4 (down1 never changes resolution).

    Raises:
        ValueError: if n_classes is not given.

    ``forward(x)`` returns per-pixel class probabilities (softmax over dim 1)
    with the same spatial size as ``x``. With the default "max_pooling"
    down_mode, the input H and W must be divisible by 8 (three 2x poolings).
    Otherwise the decoder's skip-connection concatenation fails on a size
    mismatch.

    NOTE(review): ``nn.CrossEntropyLoss`` applies log-softmax internally and
    expects raw logits. Training this softmax output with it therefore applies
    softmax twice. Consider dropping the Softmax layer (argmax predictions are
    unaffected and the state_dict keys stay the same). It is kept here so
    existing callers still receive probabilities.
    """

    def __init__(self, n_channels=3, n_classes=None, filter_num=16, up_mode="transpose", down_mode="max_pooling"):
        super(UNet, self).__init__()
        if n_classes is None:
            raise ValueError("n_classes must be specified")

        # Encoder. down1 uses a stride-1 3x3 conv, so it keeps the input resolution.
        self.down1 = UNetDownBlock(n_channels, filter_num, down_mode="conv", kernel_size=3, stride=1, padding=1)
        self.down2 = UNetDownBlock(filter_num, filter_num*2, down_mode=down_mode)
        self.down3 = UNetDownBlock(filter_num*2, filter_num*4, down_mode=down_mode)
        self.down4 = UNetDownBlock(filter_num*4, filter_num*4, down_mode=down_mode)
        # (A deeper variant would add a fifth down block here and a matching up block below.)

        # Decoder. Each block upsamples 2x and merges the encoder output of the same size.
        self.up2 = UNetUpBlock(filter_num*4, filter_num*4, up_mode=up_mode)
        self.up3 = UNetUpBlock(filter_num*4, filter_num*2, up_mode=up_mode)
        self.up4 = UNetUpBlock(filter_num*2, filter_num, up_mode=up_mode)

        # 1x1 conv to per-class scores, then softmax over the class axis.
        # dim=1 is explicit: an implicit dim is deprecated (UserWarning). For
        # 4-D input PyTorch's implicit choice was already dim=1, so outputs are unchanged.
        self.conv_final = nn.Sequential(nn.Conv2d(filter_num, n_classes, 1, 1, 0), nn.Softmax(dim=1))

    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections.
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)

        # Decoder: upsample and merge with the encoder feature map of matching size.
        x = self.up2(x4, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.conv_final(x)

# 自前データセットクラスの定義

・自前のデータセットの読み込み方に関する部分です。  
・Data augmentationはFlipとrotateの2種類しか使っていませんが、他にも使いたいものがあれば追加してもらって大丈夫です。  
・low_resource_cropping関数は画像を切り取って入力することで学習時のメモリ不足を解消します。必要なければコメントアウトしてください。  
・教師ラベルの値はひび割れを1、背景を0としています。(nn.CrossEntropyLossはターゲットとしてクラス番号を表すlong(int64)型のテンソルを要求するため、long型に変換しています)

In [3]:
class CrackDataset(Dataset):
    """Crack segmentation dataset of grayscale image/label pairs.

    Files are read from ``<data_source_dir>/<mode>/img/*.jpg`` and
    ``<data_source_dir>/<mode>/label/*.jpg``. Images and labels are paired by
    position after sorting their paths. Both directories must therefore hold
    the same number of files in matching sorted order (presumably identical
    file names -- confirm against the dataset).

    Args:
        mode: sub-directory to read. "train" applies random flip/rotation and a
            random crop. "test" resizes each pair with resize_test_img. Any other
            mode (e.g. "validation") uses the images as loaded.
        data_source_dir: dataset root directory (required).
        img_preprocessing: callable that turns the PIL image into a tensor.

    Raises:
        ValueError: if data_source_dir is missing, or the image and label
            directories contain different numbers of files.
    """

    def __init__(self, mode="train", data_source_dir=None, img_preprocessing=None):
        if data_source_dir is None:
            raise ValueError("data_source_dir must be specified")

        self.mode = mode

        self.image_dir = sorted(glob.glob(os.path.join(data_source_dir, mode, "img", "*.jpg")))
        self.label_dir = sorted(glob.glob(os.path.join(data_source_dir, mode, "label", "*.jpg")))

        # Pairing is by index. With a count mismatch some images would get the
        # wrong label (or __getitem__ would raise IndexError for the last ones).
        if len(self.image_dir) != len(self.label_dir):
            raise ValueError(
                f"{mode}: found {len(self.image_dir)} images but {len(self.label_dir)} labels"
            )

        self.img_preprocessing = img_preprocessing

    def __len__(self):
        return len(self.image_dir)

    # Loading pipeline for a single sample
    def __getitem__(self , index):
        img_path = self.image_dir[index]
        label_path = self.label_dir[index]

        # Load both as single-channel grayscale images
        img = Image.open(img_path).convert("L")
        label = Image.open(label_path).convert("L")

        # Training only: random augmentation, then a random crop
        if self.mode == "train":
            img, label = data_augmentation(img, label)
            img, label = low_resource_cropping(img, label)

        # Test images vary in size: resize to a size the network accepts
        elif self.mode == "test":
            img, label = resize_test_img(img, label)

        # Convert to tensors: image via img_preprocessing, label to a 0/1 LongTensor
        img = self.img_preprocessing(img)
        label = label_preprocessing(label)

        return img, label
        

        
# Data augmentation for training image/label pairs
def data_augmentation(input_img, input_label, low_resource_ver=False):
    """Apply the same random flip and rotation to an image and its label.

    Args:
        input_img: PIL image.
        input_label: PIL label image; receives exactly the same transform.
        low_resource_ver: unused; kept for backward compatibility.

    Returns:
        (img, label). First a horizontal flip with probability 0.5, then a
        counter-clockwise rotation by 0, 90, 180 or 270 degrees, chosen
        uniformly from one random draw.
    """
    # random horizontal flip
    if random.random() < 0.5:
        input_img = input_img.transpose(Image.FLIP_LEFT_RIGHT)
        input_label = input_label.transpose(Image.FLIP_LEFT_RIGHT)

    # Random rotation. transpose() performs exact, lossless 90-degree rotations
    # and swaps width/height for non-square inputs. Image.rotate() without
    # expand=True keeps the original canvas size, which crops/pads non-square
    # images. For square images the two give the same result.
    r = random.random()
    if r > 0.75:
        method = Image.ROTATE_270
    elif r > 0.5:
        method = Image.ROTATE_180
    elif r > 0.25:
        method = Image.ROTATE_90
    else:
        method = None

    if method is not None:
        input_img = input_img.transpose(method)
        input_label = input_label.transpose(method)

    return input_img, input_label


        
# Cropping for GPUs with limited memory (e.g. 256x256 -> 192x192).
# Recommended when the GPU has around 4 GB of memory.
def low_resource_cropping(input_img, input_label, crop_size=192):
    """Randomly crop the same square window from an image and its label.

    Args:
        input_img: PIL image (anything exposing ``.size`` and ``.crop(box)``).
        input_label: label image with the same size as input_img.
        crop_size: side length of the square crop (default 192).

    Returns:
        (cropped_img, cropped_label), both crop_size x crop_size, cut from the
        same randomly chosen position.

    Raises:
        ValueError: if the image and label sizes differ, or the image is smaller
            than crop_size. random.randint would otherwise fail with an unhelpful
            "empty range" message, and a size mismatch would misalign the crop.
    """
    w, h = input_img.size
    if tuple(input_label.size) != (w, h):
        raise ValueError(f"image size {(w, h)} and label size {tuple(input_label.size)} differ")
    if w < crop_size or h < crop_size:
        raise ValueError(f"image size {(w, h)} is smaller than crop_size={crop_size}")

    x1 = random.randint(0, w - crop_size)
    y1 = random.randint(0, h - crop_size)
    box = (x1, y1, x1 + crop_size, y1 + crop_size)

    return input_img.crop(box), input_label.crop(box)



# Preprocessing of ground-truth labels
def label_preprocessing(label):
    """Binarize a grayscale label image into a LongTensor of class indices.

    Pixels brighter than 127 become class 1 (crack); all others become class 0
    (background). The result is int64 (torch.long), the dtype
    nn.CrossEntropyLoss expects for class-index targets.
    """
    crack_mask = np.asarray(label) > 127
    return torch.from_numpy(crack_mask.astype(np.int64))
        

# 学習時の各種パラメータ設定


In [4]:
# Device: prefer the first GPU, fall back to CPU when CUDA is unavailable
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Number of input channels (CrackDataset loads images as grayscale)
n_channels = 1
# Number of classes (crack or background)
n_classes = 2
# Base number of UNet filters
filter_num = 16
# Optimizer learning rate
lr = 5e-4
# Training batch size
batch_size = 16
# Input preprocessing: PIL image -> tensor in [0, 1] -> normalized to [-1, 1]
img_preprocessing = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
# Dataset root directory
data_source_dir = "/home/es1video10/datasets/Crack/segmentation_test"
# Number of training epochs
num_epochs = 50
# Best validation loss so far (starts at inf so the first epoch is always saved)
best_loss = float("inf")
# Directory for logs and checkpoints
result_dir = "/home/es1video10/datasets/Crack/result"
# The training loop writes the CSV log and checkpoint here; create it up front
# so the first epoch does not fail with FileNotFoundError.
os.makedirs(result_dir, exist_ok=True)


# Build the UNet
model = UNet(n_channels=n_channels, n_classes=n_classes, filter_num=filter_num, up_mode="interpolation")
# Move the UNet parameters to the selected device
model = model.to(device)
# Loss function (kept on the same device as the model rather than hard-coding .cuda())
criterion = nn.CrossEntropyLoss().to(device)
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# Lower the learning rate when the validation loss plateaus
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=3, verbose=True)
# Print the network summary for one 192x192 training crop
summary(model, (n_channels, 192, 192), device=device.type)


# Datasets
train_dataset = CrackDataset(mode="train", data_source_dir=data_source_dir, img_preprocessing=img_preprocessing)
validation_dataset = CrackDataset(mode="validation", data_source_dir=data_source_dir, img_preprocessing=img_preprocessing)

# Data loaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=1, shuffle=False)




----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 192, 192]             160
       BatchNorm2d-2         [-1, 16, 192, 192]              32
              ReLU-3         [-1, 16, 192, 192]               0
            Conv2d-4         [-1, 16, 192, 192]           2,320
       BatchNorm2d-5         [-1, 16, 192, 192]              32
              ReLU-6         [-1, 16, 192, 192]               0
     UNetDownBlock-7         [-1, 16, 192, 192]               0
            Conv2d-8         [-1, 32, 192, 192]           4,640
       BatchNorm2d-9         [-1, 32, 192, 192]              64
             ReLU-10         [-1, 32, 192, 192]               0
           Conv2d-11         [-1, 32, 192, 192]           9,248
      BatchNorm2d-12         [-1, 32, 192, 192]              64
             ReLU-13         [-1, 32, 192, 192]               0
        MaxPool2d-14           [-1, 32,

  input = module(input)


# 学習時の処理の流れ

・プログレスバーをtqdmで表示している。もしtqdmを使わないなら、tbarの部分をdataloaderに置き換える。  
・損失が最小になったタイミングのみモデルを保存する。best_lossの初期値はinfなので1エポック目は絶対に保存される。

In [5]:
# Make sure the output directory exists before the first log write / checkpoint save.
os.makedirs(result_dir, exist_ok=True)
log_path = os.path.join(result_dir, 'training_log.csv')

for epoch in range(num_epochs):

    print(f"Starting epoch: {epoch}")

    # ---------------- training ----------------
    losses = 0.0
    total_batches = len(train_dataloader)
    # Put the UNet (in particular BatchNorm) into training mode
    model.train()
    # Progress bar over the training batches
    tbar = tqdm(train_dataloader)
    # Clear any gradients left over from before this epoch
    optimizer.zero_grad()

    for itr, (img, label) in enumerate(tbar):
        img = img.to(device)
        label = label.to(device)
        # Forward pass
        output = model(img)
        loss = criterion(output, label)
        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Show the running mean loss over the batches processed so far
        losses += loss.item()
        tbar.set_description('loss: %.7f' % (losses / (itr + 1)))

    # Mean training loss for this epoch
    train_epoch_loss = losses / total_batches

    # ---------------- validation ----------------
    # Gradients are not needed for evaluation
    with torch.no_grad():
        losses = 0.0
        total_batches = len(validation_dataloader)
        # Put the UNet (in particular BatchNorm) into eval mode
        model.eval()
        tbar = tqdm(validation_dataloader)

        for itr, (img, label) in enumerate(tbar):
            img = img.to(device)
            label = label.to(device)
            output = model(img)
            loss = criterion(output, label)
            losses += loss.item()
            tbar.set_description('loss: %.7f' % (losses / (itr + 1)))

        validation_epoch_loss = losses / total_batches
        # Let ReduceLROnPlateau react to the validation loss
        scheduler.step(validation_epoch_loss)

    # ---------------- CSV log ----------------
    # newline='' is what the csv module requires for files it writes.
    with open(log_path, 'a', newline='') as f:
        writer = csv.writer(f)
        # In append mode the position starts at end-of-file, so tell() == 0
        # exactly when the file is new/empty. Re-running the cell therefore
        # does not insert a second header row in the middle of the log.
        if f.tell() == 0:
            writer.writerow(["epoch", "train_loss", "validation_loss"])
        writer.writerow([epoch+1, train_epoch_loss, validation_epoch_loss])

    # ---------------- checkpoint ----------------
    # Save the parameters only when the validation loss improves
    if validation_epoch_loss < best_loss:

        best_loss = validation_epoch_loss
        state = {
            "epoch": epoch + 1,
            "state_dict": model.state_dict(),
            # Extra key so training can be resumed; loaders that read only
            # "state_dict" are unaffected.
            "optimizer": optimizer.state_dict(),
            "best_loss": best_loss,
        }
        filename = os.path.join(result_dir, "checkpoint.pth.tar")
        torch.save(state, filename)
        print("----------new optimal found! saving state---------- ")
        
    

Starting epoch: 0


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 1


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 2


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 3


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 4


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 5


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 6


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 7


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 8


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 9


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 10


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 11


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 12


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 13


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 14


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 15


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 16


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 17


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 18


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 19


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 20


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 21


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 22


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 23


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 24


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 25


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 26


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 27


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 28


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 29


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 30


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 31


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 32


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 33


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 34


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 35


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Epoch    36: reducing learning rate of group 0 to 5.0000e-05.
Starting epoch: 36


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 37


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 38


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 39


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 40


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Epoch    41: reducing learning rate of group 0 to 5.0000e-06.
Starting epoch: 41


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 42


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 43


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 44


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 45


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Epoch    46: reducing learning rate of group 0 to 5.0000e-07.
Starting epoch: 46


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 47


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 48


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 49


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Epoch    50: reducing learning rate of group 0 to 5.0000e-08.
Starting epoch: 50


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 51


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 52


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 53


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Epoch    54: reducing learning rate of group 0 to 5.0000e-09.
Starting epoch: 54


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 55


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 56


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 57


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 58


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 59


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 60


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 61


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 62


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 63


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 64


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 65


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 66


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 67


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 68


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 69


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 70


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 71


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 72


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 73


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 74


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 75


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 76


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 77


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 78


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 79


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 80


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 81


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 82


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 83


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


----------new optimal found! saving state---------- 
Starting epoch: 84


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 85


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 86


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 87


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 88


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 89


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 90


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 91


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 92


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 93


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 94


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 95


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 96


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 97


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 98


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Starting epoch: 99


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))




# テスト時のみ使う関数の定義

・テスト画像はサイズがバラバラなため、resize_test_imgでUNetに入力できるサイズに変形する。  
・ひび割れのクラスを対象にして適合率と再現率を計算する。他の指標も使いたい場合は各自追加してください。

In [120]:
def resize_test_img(img, label, img_size=192):
    """Resize an image/label pair up to the nearest multiple of img_size per side.

    Test images come in arbitrary sizes. Rounding each side up to a multiple of
    img_size (default 192, the training crop size) yields sizes the UNet can
    process, because 192 is divisible by 8. Each side becomes at least img_size.

    Args:
        img: PIL image (anything with ``.size`` and ``.resize((w, h))``).
        label: label image; resized to the same target size with the same
            default resampling as img. NOTE(review): not nearest-neighbour --
            label_preprocessing re-thresholds it at 127 afterwards.
        img_size: the unit that width and height are rounded up to.

    Returns:
        (resized_img, resized_label).
    """
    def _round_up(length):
        # Smallest multiple of img_size that is >= length, and at least img_size.
        # The previous fixed search range left width/height undefined for
        # images wider/taller than 99 * img_size.
        return max(1, -(-length // img_size)) * img_size

    width = _round_up(img.size[0])
    height = _round_up(img.size[1])

    img = img.resize((width, height))
    label = label.resize((width, height))

    return img, label



def save_img(output, label, directory, itr):
    """Write a predicted class map to ``<directory>/maxprob/<itr>.jpg``.

    Args:
        output: 2-D array of predicted class indices. The test loop passes an
            argmax map, so 0 = background and 1 = crack for the 2-class model.
        label: not used; kept so the call signature stays unchanged.
        directory: base output directory (the test loop passes result_dir).
        itr: image index, used as the file name.
    """
    # Use the directory argument rather than the global result_dir, so the
    # function writes where the caller asks it to.
    save_dir = os.path.join(directory, "maxprob")
    os.makedirs(save_dir, exist_ok=True)

    # Map class 1 to white (255), clip to the 8-bit range and store as uint8.
    # cv2.imwrite only saves 8-bit images in general.
    image = np.clip(np.asarray(output) * 255, 0, 255).astype(np.uint8)
    cv2.imwrite(os.path.join(save_dir, str(itr) + ".jpg"), image)
    
    

def get_score(output, label, result_dir, itr, zero_division=float("nan")):
    """Append crack-class precision and recall (in %) for one test image to eval.csv.

    Args:
        output: predicted class-index array (the test loop passes an argmax map).
        label: ground-truth class-index array, compared element-wise with output.
        result_dir: directory that holds eval.csv.
        itr: image index, written as "<itr>.jpg" in the first column.
        zero_division: value written when a metric is undefined because its
            denominator is 0. This happens for precision when no pixel is
            predicted as crack, and for recall when no pixel is labelled as
            crack. It avoids raising ZeroDivisionError and aborting the
            evaluation loop.

    Only class 1 (crack) is scored. For multi-class scoring, compute tp/fp/fn
    the same way for each class index.
    """
    # Crack pixels are class 1 (label_preprocessing maps values > 127 to 1).
    detected_label = 1
    pred_pos = output == detected_label
    true_pos = label == detected_label
    tp = (pred_pos & true_pos).sum().item()
    fp = (pred_pos & ~true_pos).sum().item()
    fn = (~pred_pos & true_pos).sum().item()

    precision = 100 * tp / (tp + fp) if tp + fp else zero_division
    recall = 100 * tp / (tp + fn) if tp + fn else zero_division

    with open(os.path.join(result_dir, 'eval.csv'), 'a', newline='') as f:
        writer = csv.writer(f)
        # Header only for a new/empty file. Append mode starts at end-of-file,
        # so re-running the evaluation does not duplicate the header.
        if f.tell() == 0:
            writer.writerow(["image_number", "precision", "recall"])
        writer.writerow([str(itr) + ".jpg", precision, recall])

# テスト時の各種パラメータ設定


In [121]:
# Prefer the first GPU; fall back to CPU when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# These must match the values used when the checkpoint was trained.
n_channels = 1
n_classes = 2
filter_num = 16
# Same input preprocessing as during training
img_preprocessing = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
data_source_dir = "/home/es1video10/datasets/Crack/segmentation_test"
result_dir = "/home/es1video10/datasets/Crack/result"
check_path = os.path.join(result_dir, "checkpoint.pth.tar")

model = UNet(n_channels=n_channels, n_classes=n_classes, filter_num=filter_num, up_mode="interpolation")
model = model.to(device)
# map_location remaps tensors saved from a GPU onto the current device, so the
# checkpoint also loads on a CPU-only machine.
checkpoint = torch.load(check_path, map_location=device)
model.load_state_dict(checkpoint["state_dict"])

test_dataset = CrackDataset(mode="test", data_source_dir=data_source_dir, img_preprocessing=img_preprocessing)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)


# テスト時の処理の流れ

In [123]:
tbar = tqdm(test_dataloader)
model.eval()
# Autograd is never needed at test time, so disable it for the whole pass.
with torch.no_grad():
    for itr, (img, label) in enumerate(tbar):
        # Model output for one image (batch_size=1): (1, n_classes, H, W)
        probs = model(img.to(device)).cpu().numpy()
        # Drop the batch axis, then take the most probable class per pixel -> (H, W)
        prediction = np.argmax(np.squeeze(probs, 0), axis=0)
        label = label.cpu().numpy()

        get_score(prediction, label, result_dir, itr)
        save_img(prediction, label, result_dir, itr)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  This is separate from the ipykernel package so we can avoid doing imports until


HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))


