In [1]:
import os
import zarr
import timm
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import sys
import torch

# import torchvision.transforms.functional as F
import random

warnings.filterwarnings("ignore")
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import Unet3D
from src.utils import save_images, PadToSize
from src.metric import (
    score,
    create_cls_pos,
    create_cls_pos_sikii,
    create_df,
    SegmentationLoss,
    DiceLoss,
)
from metric import visualize_epoch_results

In [2]:
train_dataset = EziiDataset(
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.train_zarr_types,
    train=True,
    augmentation=True,
    slice=True,
    pre_read=True,
)

# train_nshuffle_dataset = EziiDataset(
#     exp_names=CFG.train_exp_names,
#     base_dir="../../inputs/train/",
#     particles_name=CFG.particles_name,
#     resolution=CFG.resolution,
#     zarr_type=CFG.train_zarr_types,
#     augmentation=False,
#     train=True,
# )

# valid_dataset = EziiDataset(
#     exp_names=CFG.valid_exp_names,
#     base_dir="../../inputs/train/",
#     particles_name=CFG.particles_name,
#     resolution=CFG.resolution,
#     zarr_type=CFG.valid_zarr_types,
#     augmentation=False,
#     train=True,
# )

from tqdm import tqdm

train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=CFG.num_workers,
)
# train_nshuffle_loader = DataLoader(
#     train_nshuffle_dataset,
#     batch_size=1,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
# valid_loader = DataLoader(
#     valid_dataset,
#     batch_size=1,
#     shuffle=False,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )

for data in tqdm(train_loader):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

normalized_tomogram.shape

100%|██████████| 1600/1600 [01:00<00:00, 26.50it/s]
  0%|          | 0/800 [00:01<?, ?it/s]


torch.Size([2, 16, 315, 315])

In [3]:
from tqdm import tqdm

train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=CFG.num_workers,
)
# train_nshuffle_loader = DataLoader(
#     train_nshuffle_dataset,
#     batch_size=1,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
# valid_loader = DataLoader(
#     valid_dataset,
#     batch_size=1,
#     shuffle=False,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )

for data in tqdm(train_loader):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

normalized_tomogram.shape

  0%|          | 0/800 [00:00<?, ?it/s]


torch.Size([2, 16, 315, 315])

In [4]:
encoder = timm.create_model(
    model_name=CFG.model_name,
    pretrained=True,
    in_chans=3,
    num_classes=0,
    global_pool="",
    features_only=True,
)
model = Unet3D(encoder=encoder).to("cuda")
model.load_state_dict(torch.load("./pretrained_model.pth"))

<All keys matched successfully>

In [5]:
# # "encoder"と名のつくパラメータは学習しない
# for layer, param in model.named_parameters():
#     if "encoder" in layer:
#         param.requires_grad = False

In [6]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# サンプルデータ
num_classes = len(CFG.particles_name)  # クラス数
colors = plt.cm.tab10(
    np.arange(len(CFG.particles_name))
)  # "tab10" カラーマップから色を取得

# ListedColormap を作成
class_colormap = ListedColormap(colors)


# カラーバー付きプロット
def plot_with_colormap(data, title, original_tomogram):
    masked_data = np.ma.masked_where(data <= 0, data)  # クラス0をマスク
    plt.imshow(original_tomogram, cmap="gray")
    im = plt.imshow(masked_data, cmap=class_colormap)
    plt.title(title)
    plt.axis("off")
    return im

In [7]:
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.Adam(
    model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
)
# criterion = nn.CrossEntropyLoss(
#     #  weight=torch.tensor([2.0, 32, 32, 32, 32, 32, 32]).to("cuda")
# )
criterion = DiceLoss()
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=CFG.epochs * len(train_loader),
    # * batch_size,
)
seg_loss = SegmentationLoss(criterion)
padf = PadToSize(CFG.resolution)

In [8]:
# b, c, d, h, w = CFG.batch_size, 1, 96, 320, 320

In [9]:
def preprocess_tensor(tensor):
    batch_size, depth, height, width = tensor.shape
    tensor = tensor.unsqueeze(2)  # (b, d, h, w) -> (b, d, 1, h, w)
    return tensor

In [10]:
padf = PadToSize(CFG.resolution)
padf(normalized_tomogram).shape

torch.Size([2, 16, 320, 320])

In [11]:
def last_padding(tomogram, slice_size):
    # tomogram: (tensor)
    b, d, h, w = tomogram.shape
    last_padding = slice_size - d % slice_size
    if last_padding == slice_size:
        return tomogram
    else:
        return torch.cat(
            [tomogram, torch.zeros(b, last_padding, h, w).to(tomogram.device)], dim=1
        )


def inference(model, exp_name, train=True):
    dataset = EziiDataset(
        exp_names=[exp_name],
        base_dir="../../inputs/train/",
        particles_name=CFG.particles_name,
        resolution=CFG.resolution,
        zarr_type=["denoised"],
        train=train,
        slice=False,
    )
    res_array = CFG.original_img_shape[CFG.resolution]
    pred_array = np.zeros(
        (len(CFG.particles_name) + 1, res_array[0], res_array[1], res_array[2])
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=True)
    model.eval()
    # tq = tqdm(loader)
    for data in loader:  # 実験データ1つを取り出す
        for i in range(0, data["normalized_tomogram"].shape[1], CFG.slice_):
            normalized_tomogram = data["normalized_tomogram"][:, i : i + CFG.slice_]
            normalized_tomogram = last_padding(normalized_tomogram, CFG.slice_)
            normalized_tomogram = padf(normalized_tomogram)
            normalized_tomogram = preprocess_tensor(normalized_tomogram).to("cuda")
            pred = model(normalized_tomogram)
            prob_pred = (
                torch.softmax(pred, dim=1).detach().cpu().numpy()
            )  # torch.Size([1, 7, 32, 320, 320])
            range_ = min(i + CFG.slice_, res_array[0])
            hw_pad_diff = prob_pred.shape[-1] - res_array[-1]

            if i >= res_array[0]:
                continue

            if range_ == res_array[0]:
                pred_array[:, i:range_] += prob_pred[
                    0, :, : res_array[0] - i, :-hw_pad_diff, :-hw_pad_diff
                ]
            else:
                pred_array[:, i:range_] += prob_pred[
                    0, :, :range_, :-hw_pad_diff, :-hw_pad_diff
                ]
            # tq.update()
    # tq.close()

    return pred_array  # (7, 92, 315, 315)


inferenced_array = inference(model, "TS_6_6", train=False)

In [12]:
best_model = None
best_score = -100

for epoch in range(CFG.epochs):
    model.train()
    train_loss = []
    tq = tqdm(train_loader)
    for data in train_loader:
        normalized_tomogram = data["normalized_tomogram"].cuda()
        segmentation_map = data["segmentation_map"].long().cuda()

        normalized_tomogram = padf(normalized_tomogram)
        segmentation_map = padf(segmentation_map)

        optimizer.zero_grad()
        pred = model(preprocess_tensor(normalized_tomogram))
        loss = seg_loss(pred, segmentation_map)
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss.append(loss.item())

        # 確率予測
        prob_pred = torch.softmax(pred, dim=1)
        tq.update()
        tq.set_description(f"train-loss: {np.mean(train_loss)}")
    tq.close()

    print(f"Epoch {epoch} Loss: {np.mean(train_loss)}")

    # # ############### validation ################
    train_nshuffle_original_tomogram = defaultdict(list)
    train_nshuffle_pred_tomogram = defaultdict(list)
    train_nshuffle_gt_tomogram = defaultdict(list)

    valid_original_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)

    train_mean_scores = []
    valid_mean_scores = []

    # # for exp_name in tqdm(CFG.train_exp_names):
    # for exp_name in tqdm(CFG.train_exp_names[:5]):  # 5つのデータで試す
    #     inferenced_array = inference(model, exp_name, train=False)
    #     train_nshuffle_pred_tomogram[exp_name] = inferenced_array

    #     mean_score, scores = visualize_epoch_results(
    #         train_nshuffle_pred_tomogram,
    #         base_dir="../../inputs/train/overlay/ExperimentRuns/",
    #         sikii_dict=CFG.initial_sikii,
    #     )
    #     train_mean_scores.append(mean_score)
    # print("train_mean_scores", np.mean(train_mean_scores))

    # for exp_name in tqdm(CFG.valid_exp_names):
    #     inferenced_array = inference(model, exp_name, train=False)
    #     valid_pred_tomogram[exp_name] = inferenced_array

    #     mean_score, scores = visualize_epoch_results(
    #         valid_pred_tomogram,
    #         base_dir="../../inputs/train/overlay/ExperimentRuns/",
    #         sikii_dict=CFG.initial_sikii,
    #     )
    #     valid_mean_scores.append(mean_score)
    # print("valid_mean_scores", np.mean(valid_mean_scores))

train-loss: 0.6624469801783561: 100%|██████████| 800/800 [02:17<00:00,  5.81it/s]


Epoch 0 Loss: 0.6624469801783561


train-loss: 0.6405818329751491: 100%|██████████| 800/800 [02:23<00:00,  5.57it/s]


Epoch 1 Loss: 0.6405818329751491


train-loss: 0.6321039402484894: 100%|██████████| 800/800 [02:18<00:00,  5.76it/s]


Epoch 2 Loss: 0.6321039402484894


train-loss: 0.611375813037157: 100%|██████████| 800/800 [02:23<00:00,  5.58it/s] 


Epoch 3 Loss: 0.611375813037157


train-loss: 0.6011395563185215: 100%|██████████| 800/800 [02:17<00:00,  5.80it/s]


Epoch 4 Loss: 0.6011395563185215


train-loss: 0.6090522456914187: 100%|██████████| 800/800 [02:18<00:00,  5.78it/s]


Epoch 5 Loss: 0.6090522456914187


train-loss: 0.6072986172884702: 100%|██████████| 800/800 [02:18<00:00,  5.77it/s]


Epoch 6 Loss: 0.6072986172884702


train-loss: 0.6048934152722358: 100%|██████████| 800/800 [02:20<00:00,  5.68it/s]


Epoch 7 Loss: 0.6048934152722358


train-loss: 0.5969574902951718: 100%|██████████| 800/800 [02:18<00:00,  5.76it/s]


Epoch 8 Loss: 0.5969574902951718


train-loss: 0.6002933179587125: 100%|██████████| 800/800 [02:20<00:00,  5.71it/s]


Epoch 9 Loss: 0.6002933179587125


train-loss: 0.5963222941756249: 100%|██████████| 800/800 [02:22<00:00,  5.61it/s]


Epoch 10 Loss: 0.5963222941756249


train-loss: 0.5943533488363028: 100%|██████████| 800/800 [02:24<00:00,  5.53it/s]


Epoch 11 Loss: 0.5943533488363028


train-loss: 0.5889377856999636: 100%|██████████| 800/800 [02:18<00:00,  5.77it/s]


Epoch 12 Loss: 0.5889377856999636


train-loss: 0.5827234892547131: 100%|██████████| 800/800 [02:18<00:00,  5.77it/s]


Epoch 13 Loss: 0.5827234892547131


train-loss: 0.5799073057621718: 100%|██████████| 800/800 [02:20<00:00,  5.69it/s]


Epoch 14 Loss: 0.5799073057621718


train-loss: 0.5910176834464074: 100%|██████████| 800/800 [02:23<00:00,  5.58it/s]


Epoch 15 Loss: 0.5910176834464074


train-loss: 0.5811186078190803: 100%|██████████| 800/800 [02:18<00:00,  5.76it/s]


Epoch 16 Loss: 0.5811186078190803


train-loss: 0.5769034108519554: 100%|██████████| 800/800 [02:19<00:00,  5.75it/s]


Epoch 17 Loss: 0.5769034108519554


train-loss: 0.5717871566861867: 100%|██████████| 800/800 [02:18<00:00,  5.76it/s]


Epoch 18 Loss: 0.5717871566861867


train-loss: 0.5672702055424452: 100%|██████████| 800/800 [02:18<00:00,  5.78it/s]


Epoch 19 Loss: 0.5672702055424452


train-loss: 0.5713388593494892: 100%|██████████| 800/800 [02:18<00:00,  5.77it/s]


Epoch 20 Loss: 0.5713388593494892


train-loss: 0.5661281104385852: 100%|██████████| 800/800 [02:18<00:00,  5.79it/s]


Epoch 21 Loss: 0.5661281104385852


train-loss: 0.5659131701290607: 100%|██████████| 800/800 [02:20<00:00,  5.70it/s]


Epoch 22 Loss: 0.5659131701290607


train-loss: 0.561773831397295: 100%|██████████| 800/800 [02:22<00:00,  5.61it/s] 


Epoch 23 Loss: 0.561773831397295


train-loss: 0.5598040994256734: 100%|██████████| 800/800 [02:19<00:00,  5.72it/s]


Epoch 24 Loss: 0.5598040994256734


train-loss: 0.5591699589043856: 100%|██████████| 800/800 [02:22<00:00,  5.60it/s]


Epoch 25 Loss: 0.5591699589043856


train-loss: 0.5502177231758832: 100%|██████████| 800/800 [02:18<00:00,  5.76it/s]


Epoch 26 Loss: 0.5502177231758832


train-loss: 0.5533889973163605: 100%|██████████| 800/800 [02:20<00:00,  5.71it/s]


Epoch 27 Loss: 0.5533889973163605


train-loss: 0.5558807776123286: 100%|██████████| 800/800 [02:18<00:00,  5.76it/s]


Epoch 28 Loss: 0.5558807776123286


train-loss: 0.557206136641437:  91%|█████████▏| 730/800 [02:13<00:23,  3.00it/s] 

KeyboardInterrupt: 

In [13]:
# モデルの保存
torch.save(model.state_dict(), "./pretrained_model.pth")

# ############### validation ################
train_nshuffle_original_tomogram = defaultdict(list)
train_nshuffle_pred_tomogram = defaultdict(list)
train_nshuffle_gt_tomogram = defaultdict(list)

valid_original_tomogram = defaultdict(list)
valid_pred_tomogram = defaultdict(list)
valid_gt_tomogram = defaultdict(list)

train_mean_scores = []
valid_mean_scores = []

# for exp_name in tqdm(CFG.train_exp_names):
for exp_name in tqdm(CFG.train_exp_names):  # 5つのデータで試す
    inferenced_array = inference(model, exp_name, train=False)
    train_nshuffle_pred_tomogram[exp_name] = inferenced_array

    mean_score, scores = visualize_epoch_results(
        train_nshuffle_pred_tomogram,
        base_dir="../../inputs/train/overlay/ExperimentRuns/",
        sikii_dict=CFG.initial_sikii,
    )
    train_mean_scores.append(mean_score)
print("train_mean_scores", np.mean(train_mean_scores))

for exp_name in tqdm(CFG.valid_exp_names):
    inferenced_array = inference(model, exp_name, train=False)
    valid_pred_tomogram[exp_name] = inferenced_array

    mean_score, scores = visualize_epoch_results(
        valid_pred_tomogram,
        base_dir="../../inputs/train/overlay/ExperimentRuns/",
        sikii_dict=CFG.initial_sikii,
    )
    valid_mean_scores.append(mean_score)
print("valid_mean_scores", np.mean(valid_mean_scores))

  0%|          | 0/32 [00:01<?, ?it/s]


ValueError: too many values to unpack (expected 2)

In [None]:
np.unique(train_nshuffle_pred_tomogram["TS_5_4"].argmax(0), return_counts=True)