In [1]:
import os
import zarr
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import sys

warnings.filterwarnings("ignore")
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import UNet_2D, aug
from src.utils import save_images
from src.metric import score, create_cls_pos, create_cls_pos_sikii, create_df
from metric import visualize_epoch_results

In [None]:
train_dataset = EziiDataset(
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/static",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.train_zarr_types,
)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

for data in train_loader:
    break

In [None]:
data["normalized_tomogram"].shape

In [None]:
# input:1, 92, 320, 320

# 2.5D-Unetの実装
import torch
from icecream import ic


class DoubleConv2D(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch, kernel_size=3, padding="same"):
        super(DoubleConv2D, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # x: (batch_size, in_ch, H, W)
        x = self.conv(x)
        return x


class DoubleConv3D(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch, kernel_size=3, padding="same"):
        super(DoubleConv3D, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, mid_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm3d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # x: (batch_size, in_ch, D, H, W)
        x = self.conv(x)
        return x


double_conv2d_layer = DoubleConv2D(1, 64, 64)
input_ = torch.randn(1, 1, 320, 320)
out_ = double_conv2d_layer(input_)
print(f"doube_conv2d_layer: {out_.shape}")

double_conv3d_layer = DoubleConv3D(1, 64, 64)
input_ = torch.randn(1, 1, 92, 320, 320)
out_ = double_conv3d_layer(input_)
print(f"doube_conv3d_layer: {out_.shape}")

In [37]:
class UpConv2D(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=3, padding="same"):
        super(UpConv2D, self).__init__()
        self.upsample_layer = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.BatchNorm2d(in_ch),
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_ch),
        )

    def forward(self, x):
        x = self.upsample_layer(x)
        return x


class UpConv3D(nn.Module):
    def __init__(
        self, in_ch, out_ch, kernel_size=3, padding="same", scale_factor=2, size=None
    ):
        super(UpConv3D, self).__init__()

        if size is not None:
            self.upsample_layer = nn.Sequential(
                nn.Upsample(size=size, mode="trilinear", align_corners=True),
                nn.BatchNorm3d(in_ch),
                nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
                nn.BatchNorm3d(out_ch),
            )
        else:
            self.upsample_layer = nn.Sequential(
                nn.Upsample(
                    scale_factor=scale_factor, mode="trilinear", align_corners=True
                ),
                nn.BatchNorm3d(in_ch),
                nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
                nn.BatchNorm3d(out_ch),
            )

    def forward(self, x):
        x = self.upsample_layer(x)
        return x


upconv2d_layer = UpConv2D(64, 64)
input_ = torch.randn(1, 64, 160, 160)
out_ = upconv2d_layer(input_)
print(f"upconv2d_layer: {out_.shape}")

upconv3d_layer = UpConv3D(64, 64)
input_ = torch.randn(1, 64, 46, 160, 160)
out_ = upconv3d_layer(input_)
print(f"upconv3d_layer: {out_.shape}")

upconv2d_layer: torch.Size([1, 64, 320, 320])
upconv3d_layer: torch.Size([1, 64, 92, 320, 320])


In [None]:
class Unet25D(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet25D, self).__init__()

        down1_ch = 64
        down2_ch = 64
        down3_ch = 64
        down4_ch = 64
        down5_ch = 64

        self.down1_factor = 4
        self.down2_factor = 4
        self.down3_factor = 4
        self.down4_factor = 4

        self.down1 = DoubleConv3D(in_ch, down1_ch, down1_ch)
        self.down2 = DoubleConv3D(down1_ch, down2_ch, down2_ch)
        self.down3 = DoubleConv3D(down2_ch, down3_ch, down3_ch)
        self.down4 = DoubleConv3D(down3_ch, down4_ch, down4_ch)
        self.down5 = DoubleConv3D(down4_ch, down5_ch, down5_ch)

        middle1_ch = 64
        middle2_ch = 64
        middle3_ch = 64

        self.middle1 = DoubleConv2D(down5_ch, middle1_ch, middle2_ch)
        self.middle2 = DoubleConv2D(middle2_ch, middle2_ch, middle3_ch)
        self.middle3 = DoubleConv2D(middle3_ch + middle3_ch, middle3_ch, middle3_ch)

        self.up_middle_2d = UpConv2D(middle3_ch, middle3_ch)

        self.up1 = UpConv3D(middle3_ch + down5_ch, down4_ch, scale_factor=4)
        self.up1_conv = DoubleConv3D(down4_ch, down4_ch, down4_ch)
        self.up2 = UpConv3D(
            down4_ch + down4_ch, down3_ch, scale_factor=2, size=(25, 80, 80)
        )
        self.up2_conv = DoubleConv3D(down3_ch, down3_ch, down3_ch)
        self.up3 = UpConv3D(down3_ch + down3_ch, down2_ch, scale_factor=2)
        self.up3_conv = DoubleConv3D(down2_ch, down2_ch, down2_ch)
        self.up4 = UpConv3D(down2_ch + down2_ch, down1_ch, scale_factor=2)
        self.up4_conv = DoubleConv3D(down1_ch, down1_ch, down1_ch)
        self.up5 = UpConv3D(down1_ch + down1_ch, down1_ch, scale_factor=2)
        self.up5_conv = DoubleConv3D(down1_ch, down1_ch, down1_ch)

        self.outconv = nn.Conv3d(64, out_ch, kernel_size=1)

    def forward(self, x):
        # x: (batch_size, in_ch, D, H, W)

        ##################### down conv #####################
        x1 = self.down1(x)
        ic(x1.shape)
        pooled_x1 = torch.nn.functional.max_pool3d(x1, self.down1_factor)
        ic(pooled_x1.shape)
        x2 = self.down2(pooled_x1)
        ic(x2.shape)
        pooled_x2 = torch.nn.functional.max_pool3d(x2, self.down2_factor)
        ic(pooled_x2.shape)
        x3 = self.down3(pooled_x2)
        ic(x3.shape)
        pooled_x3 = torch.nn.functional.max_pool3d(x3, self.down3_factor)
        ic(pooled_x3.shape)
        x4 = self.down4(pooled_x3)
        ic(x4.shape)
        pooled_x4 = torch.nn.functional.max_pool3d(x4, self.down4_factor)
        ic(pooled_x4.shape)
        x5 = self.down5(pooled_x4)
        ic(x5.shape)

        ##################### middle conv #####################
        channel1_2d = x5[:, :, 0, :, :]
        channel2_2d = x5[:, :, 1, :, :]
        channel3_2d = x5[:, :, 2, :, :]

        # middle1
        m1_pooled_channel1_2d = torch.nn.functional.max_pool2d(channel1_2d, 2)
        #  ic(m1_pooled_channel1_2d.shape)
        m1_pooled_channel2_2d = torch.nn.functional.max_pool2d(channel2_2d, 2)
        #  ic(m1_pooled_channel2_2d.shape)
        m1_pooled_channel3_2d = torch.nn.functional.max_pool2d(channel3_2d, 2)
        #  ic(m1_pooled_channel3_2d.shape)

        m1_conv_channel1_2d = self.middle1(m1_pooled_channel1_2d)
        #  ic(m1_conv_channel1_2d.shape)
        m1_conv_channel2_2d = self.middle1(m1_pooled_channel2_2d)
        #  ic(m1_conv_channel2_2d.shape)
        m1_conv_channel3_2d = self.middle1(m1_pooled_channel3_2d)
        #  ic(m1_conv_channel3_2d.shape)

        # middle2
        m2_pooled_channel1_2d = torch.nn.functional.max_pool2d(m1_conv_channel1_2d, 2)
        #  ic(m2_pooled_channel1_2d.shape)
        m2_pooled_channel2_2d = torch.nn.functional.max_pool2d(m1_conv_channel2_2d, 2)
        #  ic(m2_pooled_channel2_2d.shape)
        m2_pooled_channel3_2d = torch.nn.functional.max_pool2d(m1_conv_channel3_2d, 2)
        #  ic(m2_pooled_channel3_2d.shape)

        m2_conv_channel1_2d = self.middle2(m2_pooled_channel1_2d)
        #  ic(m2_conv_channel1_2d.shape)
        m2_conv_channel2_2d = self.middle2(m2_pooled_channel2_2d)
        #  ic(m2_conv_channel2_2d.shape)
        m2_conv_channel3_2d = self.middle2(m2_pooled_channel3_2d)
        #  ic(m2_conv_channel3_2d.shape)

        up_m2_channel1_2d = torch.nn.functional.interpolate(
            m2_conv_channel1_2d, scale_factor=2.5, mode="bilinear", align_corners=True
        )
        #  ic(up_m2_channel1_2d.shape)
        up_m2_channel2_2d = torch.nn.functional.interpolate(
            m2_conv_channel2_2d, scale_factor=2.5, mode="bilinear", align_corners=True
        )
        #  ic(up_m2_channel2_2d.shape)
        up_m2_channel3_2d = torch.nn.functional.interpolate(
            m2_conv_channel3_2d, scale_factor=2.5, mode="bilinear", align_corners=True
        )
        #  ic(up_m2_channel3_2d.shape)

        # middle3(merge)
        ch1_merge = torch.cat([m1_conv_channel1_2d, up_m2_channel1_2d], dim=1)
        #  ic(ch1_merge.shape)
        ch2_merge = torch.cat([m1_conv_channel2_2d, up_m2_channel2_2d], dim=1)
        #  ic(ch2_merge.shape)
        ch3_merge = torch.cat([m1_conv_channel3_2d, up_m2_channel3_2d], dim=1)
        #  ic(ch3_merge.shape)

        m3_conv_channel1_2d = self.middle3(ch1_merge)
        #  ic(m3_conv_channel1_2d.shape)  # torch.Size([1, 256, 5, 5])
        m3_up_channel1_2d = self.up_middle_2d(m3_conv_channel1_2d).unsqueeze(2)
        #  ic(m3_up_channel1_2d.shape)
        m3_conv_channel2_2d = self.middle3(ch2_merge)
        #  ic(m3_conv_channel2_2d.shape)
        m3_up_channel2_2d = self.up_middle_2d(m3_conv_channel2_2d).unsqueeze(2)
        #  ic(m3_up_channel2_2d.shape)
        m3_conv_channel3_2d = self.middle3(ch3_merge)
        #  ic(m3_conv_channel3_2d.shape)
        m3_up_channel3_2d = self.up_middle_2d(m3_conv_channel3_2d).unsqueeze(2)
        #  ic(m3_up_channel3_2d.shape)

        unet_out2d = torch.cat(
            [m3_up_channel1_2d, m3_up_channel2_2d, m3_up_channel3_2d], dim=2
        )
        ic(unet_out2d.shape)

        ##################### up conv #####################
        up1_input = torch.cat([unet_out2d, x5], dim=1)
        ic(up1_input.shape)
        up1 = self.up1(up1_input)
        ic(up1.shape)
        up1_conv = self.up1_conv(up1)
        ic(up1_conv.shape)

        up2_input = torch.cat([up1_conv, x4], dim=1)
        ic(up2_input.shape)
        up2 = self.up2(up2_input)
        ic(up2.shape)
        up2_conv = self.up2_conv(up2)
        ic(up2_conv.shape)

        up3_input = torch.cat([up2_conv, x3], dim=1)
        ic(up3_input.shape)
        up3 = self.up3(up3_input)
        ic(up3.shape)
        up3_conv = self.up3_conv(up3)
        ic(up3_conv.shape)

        up4_input = torch.cat([up3_conv, x2], dim=1)
        ic(up4_input.shape)
        up4 = self.up4(up4_input)
        ic(up4.shape)
        up4_conv = self.up4_conv(up4)
        ic(up4_conv.shape)

        up5_input = torch.cat([up4_conv, x1], dim=1)
        ic(up5_input.shape)
        up5 = self.up5(up5_input)
        ic(up5.shape)
        up5_conv = self.up5_conv(up5)
        ic(up5_conv.shape)


unet25d = Unet25D(1, 7)
input_ = torch.randn(1, 1, 100, 320, 320)  # (batch_size, in_ch, D, H, W)

out_ = unet25d(input_)

ic| x1.shape: torch.Size([1, 64, 100, 320, 320])
ic| pooled_x1.shape: torch.Size([1, 64, 50, 160, 160])
ic| x2.shape: torch.Size([1, 64, 50, 160, 160])
ic| pooled_x2.shape: torch.Size([1, 64, 25, 80, 80])
ic| x3.shape: torch.Size([1, 64, 25, 80, 80])
ic| pooled_x3.shape: torch.Size([1, 64, 12, 40, 40])
ic| x4.shape: torch.Size([1, 64, 12, 40, 40])
ic| pooled_x4.shape: torch.Size([1, 64, 3, 10, 10])
ic| x5.shape: torch.Size([1, 64, 3, 10, 10])
ic| unet_out2d.shape: torch.Size([1, 64, 3, 10, 10])
ic| up1_input.shape: torch.Size([1, 128, 3, 10, 10])
ic| up1.shape: torch.Size([1, 64, 12, 40, 40])
ic| up1_conv.shape: torch.Size([1, 64, 12, 40, 40])
ic| up2_input.shape: torch.Size([1, 128, 12, 40, 40])
ic| up2.shape: torch.Size([1, 64, 25, 80, 80])
ic| up2_conv.shape: torch.Size([1, 64, 25, 80, 80])
ic| up3_input.shape: torch.Size([1, 128, 25, 80, 80])
ic| up3.shape: torch.Size([1, 64, 50, 160, 160])
ic| up3_conv.shape: torch.Size([1, 64, 50, 160, 160])
ic| up4_input.shape: torch.Size([1, 12

In [7]:
x = torch.randn(1, 512, 2, 2)
# (1,512,2,2)を(1,512,5,5)にアップサンプリング
x = torch.nn.functional.interpolate(x, size=(5, 5), mode="bilinear", align_corners=True)
x.shape

torch.Size([1, 512, 5, 5])

In [9]:
x = torch.randn(1, 512, 2, 2)
# (1,512,2,2)を(1,512,5,5)にアップサンプリング
x = torch.nn.functional.interpolate(
    x, scale_factor=2.5, mode="bilinear", align_corners=True
)
x.shape

torch.Size([1, 512, 5, 5])

In [None]:
train_dataset = EziiDataset(
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/static",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.train_zarr_types,
)

valid_dataset = EziiDataset(
    exp_names=CFG.valid_exp_names,
    base_dir="../../inputs/train/static",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.valid_zarr_types,
)

from tqdm import tqdm

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
train_nshuffle_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

for data in tqdm(train_dataset):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

# normalized_tomogram =
normalized_tomogram.shape[0]

In [3]:
import torch
import torchvision.transforms.functional as F
import random

In [None]:
model = UNet_2D().to("cuda")
model.eval()


optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
criterion = nn.CrossEntropyLoss(
    # weight=torch.tensor([0, 0.1000, 14.4163, 14.1303, 1.0000, 2.2055, 4.4967]).to(
    #     "cuda"
    # )
    weight=torch.tensor([0.5, 32, 32, 32, 32, 32, 32]).to("cuda")
)
# criterion = DiceLoss()

best_model = None
best_score = 0
batch_size = 4

for epoch in range(100):
    train_loss = []
    valid_loss = []
    train_pred_tomogram = defaultdict(list)
    train_gt_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)
    model.train()
    tq = tqdm(range(len(train_loader) * normalized_tomogram.shape[0] // batch_size))
    for data in train_loader:
        exp_name = data["exp_name"][0]
        tomogram = data["normalized_tomogram"]
        segmentation_map = data["segmentation_map"].long()

        for i in range(batch_size, tomogram.shape[1], batch_size):
            optimizer.zero_grad()
            from_, to_ = 0, tomogram.shape[1]
            random_index = random.sample(range(from_, to_), batch_size)
            input_ = tomogram[:, random_index]
            input_ = input_.permute(1, 0, 2, 3)  # (batch_size, 1, 160, 160)
            gt = segmentation_map[:, random_index].squeeze()  # (batch_size, 160, 160)

            # input_, gt = aug(input_, gt)

            input_ = input_.to("cuda")
            gt = gt.to("cuda")
            output = model(input_)
            loss = criterion(output, gt)
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())
            tq.set_description(f"Train-Epoch: {epoch}, Loss: {np.mean(train_loss)}")
            tq.update(1)

    tq.close()

    ############################################# train-nshuffle #############################################
    model.eval()
    train_loss = []
    tq = tqdm(range(len(train_nshuffle_loader) * normalized_tomogram.shape[0]))
    for data in train_nshuffle_loader:
        exp_name = data["exp_name"][0]
        tomogram = data["normalized_tomogram"].to("cuda")
        segmentation_map = data["segmentation_map"].to("cuda").long()

        for i in range(tomogram.shape[1]):
            input_ = tomogram[:, i].unsqueeze(0)
            gt = segmentation_map[:, i]

            output = model(input_)
            loss = criterion(output, gt)

            train_loss.append(loss.item())
            tq.set_description(
                f"Train-nshuffle-Epoch: {epoch}, Loss: {np.mean(train_loss)}"
            )
            tq.update(1)

            train_pred_tomogram[exp_name].append(output.cpu().detach().numpy())
            train_gt_tomogram[exp_name].append(gt.cpu().detach().numpy())
    tq.close()

    train_score_ = visualize_epoch_results(
        train_pred_tomogram,
        train_gt_tomogram,
        sikii_dict=CFG.initial_sikii,
    )

    print(f"EPOCH: {epoch}, TRAIN_SCORE: {train_score_}")

    # 可視化
    index = 20

    plt.figure(figsize=(10, 5))

    ax = plt.subplot(1, 4, 1)
    ax.imshow(train_pred_tomogram[exp_name][index].squeeze(0).argmax(0))
    ax.set_title("Train-Prediction")
    ax.axis("off")

    ax = plt.subplot(1, 4, 2)
    ax.imshow(train_gt_tomogram[exp_name][index].squeeze(0))
    ax.set_title("Train-Ground Truth")
    ax.axis("off")

    ############################################# valid #############################################

    model.eval()
    tq = tqdm(range(len(valid_loader) * normalized_tomogram.shape[0]))
    for data in valid_loader:
        exp_name = data["exp_name"][0]
        tomogram = data["normalized_tomogram"].to("cuda")
        segmentation_map = data["segmentation_map"].to("cuda").long()

        for i in range(tomogram.shape[1]):
            input_ = tomogram[:, i].unsqueeze(0)
            gt = segmentation_map[:, i]

            output = model(input_)
            loss = criterion(output, gt)

            valid_loss.append(loss.item())
            tq.set_description(f"Valid-Epoch: {epoch}, Loss: {np.mean(valid_loss)}")
            tq.update(1)

            valid_pred_tomogram[exp_name].append(output.cpu().detach().numpy())
            valid_gt_tomogram[exp_name].append(gt.cpu().detach().numpy())
    tq.close()

    valid_score_ = visualize_epoch_results(
        valid_pred_tomogram,
        valid_gt_tomogram,
        sikii_dict=CFG.initial_sikii,
    )

    print(f"EPOCH: {epoch}, VALID_SCORE: {valid_score_}")

    if valid_score_ > best_score:
        best_score = valid_score_
        best_model = model
        torch.save(model.state_dict(), f"best_model.pth")

    ax = plt.subplot(1, 4, 3)
    ax.imshow(valid_pred_tomogram[exp_name][index].argmax(1).squeeze(0))
    ax.set_title("Valid-Prediction")
    ax.axis("off")

    ax = plt.subplot(1, 4, 4)
    ax.imshow(valid_gt_tomogram[exp_name][index].squeeze(0))
    ax.set_title("Valid-Ground Truth")
    ax.axis("off")

    plt.tight_layout()

    plt.show()

    # save_images(
    #     train_gt_tomogram=train_gt_tomogram,
    #     train_pred_tomogram=train_pred_tomogram,
    #     valid_gt_tomogram=valid_gt_tomogram,
    #     valid_pred_tomogram=valid_pred_tomogram,
    #     save_dir="images",
    #     epoch=epoch,
    # )

In [None]:
train_pred_tomogram["TS_5_4"][index].squeeze(0).argmax(0).shape

In [None]:
train_gt_tomogram["TS_5_4"][index].squeeze(0).shape

In [None]:
# 可視化
index = 50

plt.figure(figsize=(10, 5))

# ax = plt.subplot(1, 4, 1)
# ax.imshow(train_pred_tomogram[exp_name][index].argmax(0))
# ax.set_title("Train-Prediction")
# ax.axis("off")

# ax = plt.subplot(1, 4, 2)
# ax.imshow(train_gt_tomogram[exp_name][index])
# ax.set_title("Train-Ground Truth")
# ax.axis("off")

ax = plt.subplot(1, 2, 1)
ax.imshow(valid_pred_tomogram[exp_name][index].argmax(1).squeeze(0))
ax.set_title("Valid-Prediction")
ax.axis("off")

ax = plt.subplot(1, 2, 2)
ax.imshow(valid_gt_tomogram[exp_name][index].squeeze(0))
ax.set_title("Valid-Ground Truth")
ax.axis("off")

plt.tight_layout()

plt.show()

In [None]:
train_score_ = visualize_epoch_results(
    train_pred_tomogram,
    train_gt_tomogram,
    sikii_dict=CFG.initial_sikii,
)

print(f"EPOCH: {epoch}, VALID_SCORE: {train_score_}")