In [1]:
import pandas as pd
import numpy as np
import cv2
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torchvision.transforms import functional
import albumentations as A
import segmentation_models_pytorch as smp

In [2]:
train_3D8_df = pd.read_csv("./data/train_3D8_df.csv")
val_3D8_df = pd.read_csv("./data/val_3D8_df.csv")
test_3D8_df = pd.read_csv("./data/test_3D8_df.csv")

train_df = train_3D8_df[train_3D8_df.label == 1].reset_index(drop=True)  #ラベル有のデータのみ抽出
val_df = val_3D8_df[val_3D8_df.label == 1].reset_index(drop=True)  #ラベル有のデータのみ抽出
test_df = test_3D8_df[test_3D8_df.label == 1].reset_index(drop=True)  #ラベル有のデータのみ抽出

print(train_df.shape)
print(train_df.columns)

(2100, 22)
Index(['Unnamed: 0', 'id', 'caseday', 'imgpath_0', 'imgpath_1', 'imgpath_2',
       'imgpath_3', 'imgpath_4', 'imgpath_5', 'imgpath_6', 'imgpath_7',
       'labelpath_0', 'labelpath_1', 'labelpath_2', 'labelpath_3',
       'labelpath_4', 'labelpath_5', 'labelpath_6', 'labelpath_7', 'height',
       'width', 'label'],
      dtype='object')


In [3]:
# データローダーの作成
additional_image_targets = {f"image{i+1}": "image" for i in range(8 - 1)}
additional_label_targets = {f"mask{i+1}": "mask" for i in range(8 - 1)}
additional_targets8 = dict(additional_image_targets, **additional_label_targets)

data_transforms8 = A.Compose([
      A.HorizontalFlip(p=0.5),
      A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
      A.OneOf([
          A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
          A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
      ], p=0.25),],
      additional_targets = additional_targets8, p=1.0)

## 以下のコードの説明
1. 各深さの画像トラベルを読み込み、(256,256)にリサイズした後深さ方向に結合
2. Albumentationsの引数とするための辞書型のデータを作成
3. Albumentationsによるデータ拡張
4. 変換後のデータをサイド深さ方向に結合
5. 画像は各RGBの値を0~1に変換（255で割る）
6. ラベルは0(background),1(large_bowel),2(small_bowel),3(stomach)で表されているため、これをOne-Hotエンコーディング
7. 画像・ラベルデータを（バッチサイズ、チャンネル数、深さ、高さ、幅）に変換

In [4]:
class Dataset(BaseDataset):
    def __init__(self, df):
        self.depth = 8
        self.imgpath_list = [[df.iloc[i,:][f"imgpath_{h}"] for h in range(self.depth)] for i in range(len(df))]
        self.labelpath_list = [[df.iloc[i,:][f"labelpath_{h}"] for h in range(self.depth)] for i in range(len(df))]
        self.transform8 = data_transforms8

    def __getitem__(self, i):
        imgpaths = self.imgpath_list[i]
        labelpaths = self.labelpath_list[i]
        for j in range(self.depth): # 画像とラベルデータの読み込み
            img = cv2.imread(imgpaths[j])
            img = cv2.resize(img, dsize=(256,256))
            label = Image.open(labelpaths[j])
            label = np.asarray(label)
            label = cv2.resize(label,dsize=(256,256))
            if j == 0: # 深さ方向に結合
                img_3D = [img]
                label_3D = [label]
            else:
                img_3D = np.vstack([img_3D, [img]])
                label_3D = np.vstack([label_3D, [label]])

        d1 = {"image": img_3D[0,:,:,:]} # Albumentationsに代入するための辞書型データを作成
        d2 = {f"image{i+1}": img_3D[i+1,:,:,:] for i in range(self.depth - 1)}
        d3 = {"mask": additional_label_targets[0,:,:]}
        d4 = {f"mask{i+1}": label_3D[i+1,:,:] for i in range(self.depth - 1)}
        dic = dict(d1, **d2, **d3, **d4)

        transformed = self.transform8(**dic)

        for j in range(self.depth):
            if j == 0: # データ拡張後のデータを再度深さ方向に結合
                img_3D = [transformed["image"]]
                label_3D = [transformed["mask"]]
            else:
                img_3D = np.vstack([img_3D], [transformed[f"image{j}"]])
                label_3D = np.vstack([label_3D, [transformed[f"mask{j}"]]])

        img_3D = img_3D/255 # RGBの値を0~1に
        img_3D = torch.from_numpy(img_3D.astype(np.float32)).clone()
        img_3D = img_3D.permute(3, 0, 1, 2) # (チャンネル数、深さ、高さ、幅)に変換
        label_3D = torch.from_numpy(label_3D.astype(np.float32)).clone()
        label_3D = torch.nn.functional.one_hot(label_3D.long(), num_classes=4)
        label_3D = label_3D.to(torch.float32)
        label_3D = label_3D.permute(3, 0, 1, 2) # チャンネル数、深さ、高さ、幅）に変換
        data = {"img": img_3D, "label": label_3D}
        return data
    
    def __len__(self):
        return len(self.imgpath_list)
    
class valtest_Dataset(BaseDataset): # Albumentationsによるデータ拡張を行わない
    def __init__(self, df):
        self.depth = 8
        self.imgpath_list = [[df.iloc[i,:][f"imgpath_{h}"] for h in range(self.depth)] for i in range(len(df))]
        self.labelpath_list = [[df.iloc[i,:][f"labelpath_{h}"] for h in range(self.depth)] for i in range(len(df))]

    def __getitem__(self, i):
        imgpaths = self.imgpath_list[i]
        labelpaths = self.labelpath_list[i]
        for j in range(self.depth):
            img = cv2.imread(imgpaths[j])
            label = Image.open(labelpaths[j])
            label = np.asarray(label)
            img = cv2.resize(img, dsize=(256,256))
            label = cv2.resize(label, dsize=(256,256))
            if j == 0:
                img_3D = [img]
                label_3D = [label]
            else:
                img_3D = np.vstack([img_3D, [img]])
                label_3D = np.vstack([label_3D, [label]])

        img_3D = img_3D/255
        img_3D = torch.from_numpy(img_3D.astype(np.float32)).clone()
        img_3D = img_3D.permute(3, 0, 1, 2)
        label_3D = torch.from_numpy(label_3D.astype(np.float32)).clone()
        label_3D = torch.nn.functional.one_hot(label_3D.long(), num_classes=4)
        label_3D = label_3D.to(torch.float32)
        label_3D = label_3D.permute(3, 0, 1, 2)
        data = {"img": img_3D, "label": label_3D}
        return data
    
    def __len__(self):
        return len(self.imgpath_list)

In [5]:
BATCH_SIZE = 3
train_dataset = Dataset(train_df)
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=True)

val_dataset = valtest_Dataset(val_df)
val_loader = DataLoader(val_dataset,
                        batch_size=BATCH_SIZE,
                        num_workers=4,
                        shuffle=True)

test_dataset = valtest_Dataset(test_df)
test_loader = DataLoader(test_dataset,
                         batch_size=1,
                         num_workers=4)

## 2.5DのUNetモデルを構築

In [6]:
class TwoConvBlock_2D(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding="same")
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.rl = nn.ReLU()
        self.conv2 = nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding="same")
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.rl(x)
        x = self.conv2(x)
        x = self.bn2
        x = self.rl(x)
        return x
    
class TwoConvBlock_3D(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, middle_channels, kernel_size=3, padding="same")
        self.bn1 = nn.BatchNorm3d(middle_channels)
        self.rl = nn.ReLU()
        self.conv2 = nn.Conv3d(middle_channels, out_channels, kernel_size=3, padding="same")
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.rl(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.rl(x)
        return x
    
class UpConv_2D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=2, padding="same")
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.up(x)
        x = self.bn1(x)
        x = self.conv(x)
        x = self.bn2(x)
        return x
    
class UpConv_3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
        self.bn1 = nn.BatchNorm3d(in_channels)
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=2, padding="same")
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        x = self.up(x)
        x = self.bn1(x)
        x = self.conv(x)
        x = self.bn2(x)
        return x
    
class UNet_3D(nn.Module):
    def __init__(self):
        super().__init__()
        self.TCB1 = TwoConvBlock_3D(3, 64, 64)
        self.TCB2 = TwoConvBlock_3D(64, 128, 128)
        self.TCB3 = TwoConvBlock_3D(128, 256, 256)
        self.TCB4 = TwoConvBlock_2D(256, 512, 512)
        self.TCB5 = TwoConvBlock_2D(512, 1024, 1024)
        self.TCB6 = TwoConvBlock_2D(1024, 512, 512)
        self.TCB7 = TwoConvBlock_3D(512, 256, 256)
        self.TCB8 = TwoConvBlock_3D(256, 128, 128)
        self.TCB9 = TwoConvBlock_3D(128, 64, 64)

        self.maxpool_3D = nn.MaxPool3d(2, stride=2)
        self.maxpool_2D = nn.MaxPool2d(2, stride=2)

        self.UC1 = UpConv_2D(1024, 512)
        self.UC2 = UpConv_2D(512, 256)
        self.UC3 = UpConv_3D(256, 128)
        self.UC4 = UpConv_3D(128, 64)

        self.conv1 = nn.Conv3d(64, 4, kernel_size=1)

    def forward(self, x):
        x = self.TCB1(x)
        x1 = x
        x = self.maxpool_3D(x)

        x = self.TCB2(x)
        x2 = x
        x = self.maxpool_3D(x)

        x = self.TCB3(x)
        x3 = x
        x_1, x_2 = x[:,:,0,:,:], x[:,:,1,:,:]
        x_1, x_2 = self.maxpool_2D(x_1), self.maxpool_2D(x_2)

        x_1, x_2 = self.TCB4(x_1), self.TCB4(x_1)
        x4_1, x4_2 = x_1, x_2
        x_1, x_2 = self.maxpool_2D(x_1), self.maxpool_2D(x_2)

        x_1, x_2 = self.TCB5(x_1), self.maxpool_2D(x_1)

        x_1, x_2 = self.UC1(x_1), self.UC1(x_2)
        x_1, x_2 = torch.cat([x4_1, x_1], dim=1), torch.cat([x4_2, x_2], dim=1)
        x_1, x_2 = self.TCB6(x_1), self.TCB6(x_1)

        x_1, x_2 = self.UC2(x_1), self.UC2(x_2)
        x = torch.cat([torch.unsqueeze(x_1, 2), torch.unsqueeze(x_2, 2)], dim=2)
        x = torch.cat([x3, x], dim=1)
        x = self.TCB7(x)

        x = self.UC3(x)
        x = torch.cat([x2, x], dim=1)
        x = self.TCB8(x)

        x = self.UC4(x)
        x = torch.cat([x1, x], dim=1)
        x = self.TCB9(x)

        x = self.conv1(x)

        return x


In [7]:
# GPU, 最適化アルゴリズムの設定を行う
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

unet = UNet_3D().to(device)
optimizer = optim.Adam(unet.parameters(), lr=0.001)

In [8]:
# 損失関数の設定
# 損失はTversky LossとBCEWithLogits Lossの平均とした。
# これらの関数は損失関数内でソフトマックス関数を処理する為、UNetの最後にソフトマックス関数を適用しない
TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)
BCELoss = smp.losses.SoftBCEWithLogitsLoss()
def criterion(pred, target):
    return 0.5*BCELoss(pred, target) + 0.5*TverskyLoss(pred, target)


In [9]:
# 学習を行う
history = {"train_loss": []}
n = 0
m = 0

for epoch in range(15):
    train_loss = 0
    val_loss = 0

    unet.train()
    for i, data in enumerate(train_loader):
        inputs, labels = data["img"].to(device), data["label"].to(device)
        optimizer.zero_grad()
        outputs = unet(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        history["train_loss"].append(loss.item())
        n += 1
        if i % ((len(train_df)//BATCH_SIZE)//10) == (len(train_df)//BATCH_SIZE)//10 - 1:
            print(f"epoch:{epoch+1} index:{i+1} train_loss:{train_loss/n:.5f}")
            n = 0
            train_loss = 0

    unet.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            inputs, labels = data["img"].to(device), data["label"].to(device)
            outputs = unet(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            m += 1
            if i % (len(val_df)//BATCH_SIZE) == len(val_df)//BATCH_SIZE - 1:
                print(f"epoch:{epoch+1} index:{i+1} val_loss:{val_loss/m:.5f}")
                m = 0
                val_loss = 0

    torch.save(unet.state_dict(), f"./train_depth8_{epoch+1}.pth")
print("finish trainning")

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/t-jun/.pyenv/versions/3.10.1/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/t-jun/.pyenv/versions/3.10.1/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'Dataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 

In [None]:
# 損失の推移をプロット。
plt.plot(history["train_loss"])
plt.xlabel('batch')
plt.ylabel('loss')

In [None]:
# testデータに対して予測を行う
model = UNet_3D()
model.load_state_dict(torch.load("./train_"))