In [1]:
import pandas as pd
import cv2
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torchvision.transforms import functional
!pip install -U segmentation-models-pytorch
import segmentation_models_pytorch as smp



In [7]:
train_df = pd.read_csv("./data/train_df.csv")
val_df = pd.read_csv("./data/val_df.csv")
test_df = pd.read_csv("./data/test_df.csv")

In [8]:
# 本データセットはラベル有りと無しがほぼ１：１になっているので、全てのデータを使用すると上手く学習が進まなかった為、
# ラベル有りのデータのみで学習を行った。
train_df = train_df[train_df.label == 1].reset_index(drop=True)
val_df = val_df[val_df.label == 1].reset_index(drop=True)
test_df = test_df[test_df.label == 1].reset_index(drop=True)

In [25]:
train_df

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,id,case_day,imgpath,height,width,hpix,wpix,labelpath,label
0,64,64,case123_day20_slice_0065,case123_day20,./data/train/case123/case123_day20/scans/slice...,266,266,1.5,1.5,./data/train/case123/case123_day20/label/case1...,1
1,65,65,case123_day20_slice_0066,case123_day20,./data/train/case123/case123_day20/scans/slice...,266,266,1.5,1.5,./data/train/case123/case123_day20/label/case1...,1
2,66,66,case123_day20_slice_0067,case123_day20,./data/train/case123/case123_day20/scans/slice...,266,266,1.5,1.5,./data/train/case123/case123_day20/label/case1...,1
3,67,67,case123_day20_slice_0068,case123_day20,./data/train/case123/case123_day20/scans/slice...,266,266,1.5,1.5,./data/train/case123/case123_day20/label/case1...,1
4,68,68,case123_day20_slice_0069,case123_day20,./data/train/case123/case123_day20/scans/slice...,266,266,1.5,1.5,./data/train/case123/case123_day20/label/case1...,1
...,...,...,...,...,...,...,...,...,...,...,...
15054,38484,38484,case30_day0_slice_0133,case30_day0,./data/train/case30/case30_day0/scans/slice_01...,266,266,1.5,1.5,./data/train/case30/case30_day0/label/case30_d...,1
15055,38485,38485,case30_day0_slice_0134,case30_day0,./data/train/case30/case30_day0/scans/slice_01...,266,266,1.5,1.5,./data/train/case30/case30_day0/label/case30_d...,1
15056,38486,38486,case30_day0_slice_0135,case30_day0,./data/train/case30/case30_day0/scans/slice_01...,266,266,1.5,1.5,./data/train/case30/case30_day0/label/case30_d...,1
15057,38487,38487,case30_day0_slice_0136,case30_day0,./data/train/case30/case30_day0/scans/slice_01...,266,266,1.5,1.5,./data/train/case30/case30_day0/label/case30_d...,1


In [19]:
df = pd.read_csv("./data/moddf.csv")
df

Unnamed: 0.1,Unnamed: 0,id,case_day,imgpath,height,width,hpix,wpix,labelpath,label
0,0,case123_day20_slice_0001,case123_day20,./data/train/case123/case123_day20/scans/slice...,266,266,1.5,1.5,./data/train/case123/case123_day20/label/case1...,0
1,1,case123_day20_slice_0002,case123_day20,./data/train/case123/case123_day20/scans/slice...,266,266,1.5,1.5,./data/train/case123/case123_day20/label/case1...,0
2,2,case123_day20_slice_0003,case123_day20,./data/train/case123/case123_day20/scans/slice...,266,266,1.5,1.5,./data/train/case123/case123_day20/label/case1...,0
3,3,case123_day20_slice_0004,case123_day20,./data/train/case123/case123_day20/scans/slice...,266,266,1.5,1.5,./data/train/case123/case123_day20/label/case1...,0
4,4,case123_day20_slice_0005,case123_day20,./data/train/case123/case123_day20/scans/slice...,266,266,1.5,1.5,./data/train/case123/case123_day20/label/case1...,0
...,...,...,...,...,...,...,...,...,...,...
38491,38491,case30_day0_slice_0140,case30_day0,./data/train/case30/case30_day0/scans/slice_01...,266,266,1.5,1.5,./data/train/case30/case30_day0/label/case30_d...,0
38492,38492,case30_day0_slice_0141,case30_day0,./data/train/case30/case30_day0/scans/slice_01...,266,266,1.5,1.5,./data/train/case30/case30_day0/label/case30_d...,0
38493,38493,case30_day0_slice_0142,case30_day0,./data/train/case30/case30_day0/scans/slice_01...,266,266,1.5,1.5,./data/train/case30/case30_day0/label/case30_d...,0
38494,38494,case30_day0_slice_0143,case30_day0,./data/train/case30/case30_day0/scans/slice_01...,266,266,1.5,1.5,./data/train/case30/case30_day0/label/case30_d...,0


In [9]:
# データローダーの作成
class Dataset(BaseDataset):
    def __init__(self, df, transform = None, classes = None, augmentation = None):
        self.imgpath_list = df.imgpath
        self.labelpath_list = df.labelpath

    def __getitem__(self, i):
        imgpath = self.imgpath_list[i]
        img = cv2.imread(imgpath)
        img = cv2.resize(img, dsize=(256, 256))
        img = img/255
        img = torch.from_numpy(img.astype(np.float32)).clone()
        img = img.permute(2, 0, 1)

        labelpath = self.labelpath_list[i]
        label = Image.open(labelpath)
        label = np.asarray(label)
        label = cv2.resize(label, dsize=(256, 256))
        label = torch.from_numpy(label.astype(np.float32)).clone()
        label = torch.nn.functional.one_hot(label.long(), num_classes=4)
        label = label.to(torch.float32)
        label = label.permute(2, 0, 1)

        data = {"img": img, "label": label}
        return data
    
    def __len__(self):
        return len(self.imgpath_list)

In [20]:
BATCH_SIZE = 8
train_dataset = Dataset(df)
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=True)

val_dataset = Dataset(val_df)
val_loader = DataLoader(val_dataset,
                        batch_size=BATCH_SIZE,
                        num_workers=4,
                        shuffle=True)

test_dataset = Dataset(test_df)
test_loader = DataLoader(test_dataset,
                         batch_size=1,
                         num_workers=4)


## Unetの構築
nn.ModuleListを使用することで短く描くことも可能だが、可読性が低下する為以下のように書く
今回、デコーダーのup-Convolution（高さと幅を2倍にしつつ、チャンネル数を半分にする）については以下の方法で実装している
* nn.Upsampleを使用してup-Convolutionを行い、直後にnn.Conv2d（カーネルサイズは2*2を採用しているが、1*1でも良い）でチャンネル数を半分にする
- 以下の2つの方法でも実装可能
* up-Convolutionの直前でConvolutionブロックでチャンネル数を半分にし、その後nn.Upsampleを使用してup-Convolutionを行う
* nn.ConvTranspose2dを使用してup-Convolutionを行う

In [21]:
class TwoConvBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, kernel_size = 3, padding="same")
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.rl = nn.ReLU()
        self.conv2 = nn.Conv2d(middle_channels, out_channels, kernel_size = 3, padding="same")
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.rl(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.rl(x)
        return x

class UpConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = 2, padding="same")
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.up(x)
        x = self.bn1(x)
        x = self.conv(x)
        x = self.bn2(x)
        return x

class UNet_2D(nn.Module):
    def __init__(self):
        super().__init__()
        self.TCB1 = TwoConvBlock(3, 64, 64)
        self.TCB2 = TwoConvBlock(64, 128, 128)
        self.TCB3 = TwoConvBlock(128, 256, 256)
        self.TCB4 = TwoConvBlock(256, 512, 512)
        self.TCB5 = TwoConvBlock(512, 1024, 1024)
        self.TCB6 = TwoConvBlock(1024, 512, 512)
        self.TCB7 = TwoConvBlock(512, 256, 256)
        self.TCB8 = TwoConvBlock(256, 128, 128)
        self.TCB9 = TwoConvBlock(128, 64, 64)
        self.maxpool = nn.MaxPool2d(2, stride = 2)
        
        self.UC1 = UpConv(1024, 512) 
        self.UC2 = UpConv(512, 256) 
        self.UC3 = UpConv(256, 128) 
        self.UC4= UpConv(128, 64)

        self.conv1 = nn.Conv2d(64, 4, kernel_size = 1)
        self.soft = nn.Softmax(dim = 1)

    def forward(self, x):
        x = self.TCB1(x)
        x1 = x
        x = self.maxpool(x)

        x = self.TCB2(x)
        x2 = x
        x = self.maxpool(x)

        x = self.TCB3(x)
        x3 = x
        x = self.maxpool(x)

        x = self.TCB4(x)
        x4 = x
        x = self.maxpool(x)

        x = self.TCB5(x)

        x = self.UC1(x)
        x = torch.cat([x4, x], dim = 1)
        x = self.TCB6(x)

        x = self.UC2(x)
        x = torch.cat([x3, x], dim = 1)
        x = self.TCB7(x)

        x = self.UC3(x)
        x = torch.cat([x2, x], dim = 1)
        x = self.TCB8(x)

        x = self.UC4(x)
        x = torch.cat([x1, x], dim = 1)
        x = self.TCB9(x)

        x = self.conv1(x)

        return x


In [22]:
# GPU、最適化アルゴリズムの設定
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
unet = UNet_2D().to(device)
optimizer = optim.Adam(unet.parameters(), lr=0.001)

In [23]:
# 損失関数の設定
# 損失はTversky LossとBCEWithLogits Lossの平均とした。これらの関数は損失関数内でソフトマックス関数を
# 処理する為、UNetの最後にソフトマックス関数を適用しない
TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)
BCELoss = UNet_2D().to(device)
def criterion(pred, target):
    return 0.5*BCELoss(pred, target) + 0.5*TverskyLoss(pred, target)

In [24]:
# 学習を行う
history = {"train_loss": []}
n = 0
m = 0

for epoch in range(15):
  train_loss = 0
  val_loss = 0

  unet.train()
  for i, data in enumerate(train_loader):
    inputs, labels = data["img"].to(device), data["label"].to(device)
    optimizer.zero_grad()
    outputs = unet(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    history["train_loss"].append(loss.item())
    n += 1
    if i % ((len(df)//BATCH_SIZE)//10) == (len(df)//BATCH_SIZE)//10 - 1:
      print(f"epoch:{epoch+1}  index:{i+1}  train_loss:{train_loss/n:.5f}")
      n = 0
      train_loss = 0
      train_acc = 0


  unet.eval()
  with torch.no_grad():
    for i, data in enumerate(val_loader):
      inputs, labels = data["img"].to(device), data["label"].to(device)
      outputs = unet(inputs)
      loss = criterion(outputs, labels)
      val_loss += loss.item()
      m += 1
      if i % (len(val_df)//BATCH_SIZE) == len(val_df)//BATCH_SIZE - 1:
        print(f"epoch:{epoch+1}  index:{i+1}  val_loss:{val_loss/m:.5f}")
        m = 0
        val_loss = 0
        val_acc = 0

  torch.save(unet.state_dict(), f"./train_{epoch+1}.pth")
print("finish training")

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/t-jun/.pyenv/versions/3.10.1/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/t-jun/.pyenv/versions/3.10.1/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'Dataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 

In [None]:
# 損失の推移をプロットする
plt.plot(history["train_loss"])
plt.xlabel('batch')
plt.ylabel('loss')