In [1]:
import os
import zarr
import timm
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import sys
import torch

# import torchvision.transforms.functional as F
import random

warnings.filterwarnings("ignore")
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import Unet3D
from src.utils import save_images, PadToSize
from src.metric import (
    score,
    create_cls_pos,
    create_cls_pos_sikii,
    create_df,
    SegmentationLoss,
    DiceLoss,
)
from src.inference import inference, inference2pos, create_gt_df
from metric import visualize_epoch_results

In [2]:
train_dataset = EziiDataset(
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.train_zarr_types,
    train=True,
    augmentation=True,
    slice=True,
    pre_read=True,
)

# train_nshuffle_dataset = EziiDataset(
#     exp_names=CFG.train_exp_names,
#     base_dir="../../inputs/train/",
#     particles_name=CFG.particles_name,
#     resolution=CFG.resolution,
#     zarr_type=CFG.train_zarr_types,
#     augmentation=False,
#     train=True,
# )

valid_dataset = EziiDataset(
    exp_names=CFG.valid_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.valid_zarr_types,
    augmentation=False,
    train=True,
    slice=True,
    pre_read=True,
)

from tqdm import tqdm

train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=CFG.num_workers,
)
# train_nshuffle_loader = DataLoader(
#     train_nshuffle_dataset,
#     batch_size=1,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
valid_loader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=CFG.num_workers,
)

for data in tqdm(train_loader):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

normalized_tomogram.shape

100%|██████████| 1000/1000 [00:01<00:00, 527.74it/s]
100%|██████████| 2/2 [00:00<00:00,  2.76it/s]
  0%|          | 0/500 [00:01<?, ?it/s]


torch.Size([2, 16, 315, 315])

In [3]:
# from tqdm import tqdm

# train_loader = DataLoader(
#     train_dataset,
#     batch_size=CFG.batch_size,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
# # train_nshuffle_loader = DataLoader(
# #     train_nshuffle_dataset,
# #     batch_size=1,
# #     shuffle=True,
# #     drop_last=True,
# #     pin_memory=True,
# #     num_workers=CFG.num_workers,
# # )
# valid_loader = DataLoader(
#     valid_dataset,
#     batch_size=1,
#     shuffle=False,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )

# for data in tqdm(train_loader):
#     normalized_tomogram = data["normalized_tomogram"]
#     segmentation_map = data["segmentation_map"]
#     break

# normalized_tomogram.shape

In [4]:
encoder = timm.create_model(
    model_name=CFG.model_name,
    pretrained=True,
    in_chans=3,
    num_classes=0,
    global_pool="",
    features_only=True,
)
model = Unet3D(encoder=encoder).to("cuda")
# model.load_state_dict(torch.load("./pretrained_model.pth"))

In [5]:
# # "encoder"と名のつくパラメータは学習しない
# for layer, param in model.named_parameters():
#     if "encoder" in layer:
#         param.requires_grad = False

In [6]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# サンプルデータ
num_classes = len(CFG.particles_name)  # クラス数
colors = plt.cm.tab10(
    np.arange(len(CFG.particles_name))
)  # "tab10" カラーマップから色を取得

# ListedColormap を作成
class_colormap = ListedColormap(colors)


# カラーバー付きプロット
def plot_with_colormap(data, title, original_tomogram):
    masked_data = np.ma.masked_where(data <= 0, data)  # クラス0をマスク
    plt.imshow(original_tomogram, cmap="gray")
    im = plt.imshow(masked_data, cmap=class_colormap)
    plt.title(title)
    plt.axis("off")
    return im

In [7]:
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.Adam(
    model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
)
# criterion = nn.CrossEntropyLoss(
#     #  weight=torch.tensor([2.0, 32, 32, 32, 32, 32, 32]).to("cuda")
# )
criterion = DiceLoss()
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=CFG.epochs * len(train_loader),
    # * batch_size,
)
seg_loss = SegmentationLoss(criterion)
padf = PadToSize(CFG.resolution)

In [8]:
# b, c, d, h, w = CFG.batch_size, 1, 96, 320, 320

In [9]:
def preprocess_tensor(tensor):
    batch_size, depth, height, width = tensor.shape
    tensor = tensor.unsqueeze(2)  # (b, d, h, w) -> (b, d, 1, h, w)
    return tensor

In [10]:
padf = PadToSize(CFG.resolution)
padf(normalized_tomogram).shape

torch.Size([2, 16, 320, 320])

In [None]:
best_model = None
best_score = -100

for epoch in range(CFG.epochs):
    model.train()
    train_loss = []
    valid_loss = []
    tq = tqdm(train_loader)
    for data in train_loader:
        normalized_tomogram = data["normalized_tomogram"].cuda()
        segmentation_map = data["segmentation_map"].long().cuda()

        normalized_tomogram = padf(normalized_tomogram)
        segmentation_map = padf(segmentation_map)

        optimizer.zero_grad()
        pred = model(preprocess_tensor(normalized_tomogram))
        loss = seg_loss(pred, segmentation_map)
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss.append(loss.item())

        # 確率予測
        prob_pred = torch.softmax(pred, dim=1)
        tq.update()
        tq.set_description(f"train-loss: {np.mean(train_loss)}")
    tq.close()

    print(f"Train-Epoch {epoch} Loss: {np.mean(train_loss)}")

    tq = tqdm(valid_loader)
    for data in valid_loader:
        normalized_tomogram = data["normalized_tomogram"].cuda()
        segmentation_map = data["segmentation_map"].long().cuda()

        normalized_tomogram = padf(normalized_tomogram)
        segmentation_map = padf(segmentation_map)

        pred = model(preprocess_tensor(normalized_tomogram))
        loss = seg_loss(pred, segmentation_map)
        valid_loss.append(loss.item())

        # 確率予測
        prob_pred = torch.softmax(pred, dim=1)
        tq.update()
        tq.set_description(f"valid-loss: {np.mean(valid_loss)}")
    tq.close()

    # # ############### validation ################
    train_nshuffle_original_tomogram = defaultdict(list)
    train_nshuffle_pred_tomogram = defaultdict(list)
    train_nshuffle_gt_tomogram = defaultdict(list)

    valid_original_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)

    train_mean_scores = []
    valid_mean_scores = []

    # モデルの保存
    torch.save(model.state_dict(), "./pretrained_model.pth")

    # ############### validation ################
    train_nshuffle_original_tomogram = defaultdict(list)
    train_nshuffle_pred_tomogram = defaultdict(list)
    train_nshuffle_gt_tomogram = defaultdict(list)

    valid_original_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)

    train_mean_scores = []
    valid_mean_scores = []

    # for exp_name in tqdm(CFG.train_exp_names):
    for exp_name in CFG.train_exp_names[:5]:  # 5つのデータで試す
        # inferenced_array = inference(model, exp_name, train=False)
        inferenced_array, n_tomogram, segmentation_map = inference(
            model, exp_name, train=True
        )
        pred_df = inference2pos(
            pred_segmask=inferenced_array.argmax(0), exp_name=exp_name
        )
        base_dir = "../../inputs/train/overlay/ExperimentRuns/"
        gt_df = create_gt_df(base_dir, [exp_name])

        train_nshuffle_pred_tomogram[exp_name] = inferenced_array

        score_ = score(
            pred_df, gt_df, row_id_column_name="index", distance_multiplier=1.0, beta=4
        )
        train_mean_scores.append(score_)

    print("train_mean_scores", np.mean(train_mean_scores))

    for exp_name in CFG.valid_exp_names:
        inferenced_array, n_tomogram, segmentation_map = inference(
            model, exp_name, train=True
        )
        pred_df = inference2pos(
            pred_segmask=inferenced_array.argmax(0), exp_name=exp_name
        )
        base_dir = "../../inputs/train/overlay/ExperimentRuns/"
        gt_df = create_gt_df(base_dir, [exp_name])

        valid_pred_tomogram[exp_name] = inferenced_array

        score_ = score(
            pred_df, gt_df, row_id_column_name="index", distance_multiplier=1.0, beta=4
        )
        valid_mean_scores.append(score_)

    print("valid_mean_scores", np.mean(valid_mean_scores))

    if np.mean(valid_mean_scores) > best_score:
        best_score = np.mean(valid_mean_scores)
        best_model = model.state_dict()
        torch.save(best_model, f"./best_model.pth")

train-loss: 0.8606905264854431: 100%|██████████| 500/500 [01:44<00:00,  4.80it/s]


Train-Epoch 0 Loss: 0.8606905264854431


valid-loss: 0.7577860057353973: 100%|██████████| 2/2 [00:00<00:00,  4.87it/s]


train_mean_scores 0.3179335659041008
valid_mean_scores 0.17423056838125558


train-loss: 0.6172601999044418: 100%|██████████| 500/500 [01:43<00:00,  4.83it/s]


Train-Epoch 1 Loss: 0.6172601999044418


valid-loss: 0.7309835851192474: 100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


train_mean_scores 0.4954239003759898
valid_mean_scores 0.2019440665712497


train-loss: 0.5845060131549835: 100%|██████████| 500/500 [01:42<00:00,  4.87it/s]


Train-Epoch 2 Loss: 0.5845060131549835


valid-loss: 0.611615777015686: 100%|██████████| 2/2 [00:00<00:00,  3.28it/s] 


train_mean_scores 0.46926526795780205
valid_mean_scores 0.27381716260378597


train-loss: 0.4549642549753189: 100%|██████████| 500/500 [01:42<00:00,  4.87it/s] 


Train-Epoch 3 Loss: 0.4549642549753189


valid-loss: 0.5861867964267731: 100%|██████████| 2/2 [00:00<00:00,  3.30it/s]


train_mean_scores 0.8377513513061281
valid_mean_scores 0.384294547358236


train-loss: 0.36806325829029085: 100%|██████████| 500/500 [01:43<00:00,  4.83it/s]


Train-Epoch 4 Loss: 0.36806325829029085


valid-loss: 0.6724430024623871: 100%|██████████| 2/2 [00:00<00:00,  3.29it/s]


train_mean_scores 0.782200972595888
valid_mean_scores 0.29600330051056584


train-loss: 0.33501707124710084: 100%|██████████| 500/500 [01:47<00:00,  4.63it/s]


Train-Epoch 5 Loss: 0.33501707124710084


valid-loss: 0.7766251862049103: 100%|██████████| 2/2 [00:00<00:00,  3.21it/s]


train_mean_scores 0.8800393497597669
valid_mean_scores 0.34657184105414474


train-loss: 0.3209351167678833: 100%|██████████| 500/500 [01:41<00:00,  4.92it/s] 


Train-Epoch 6 Loss: 0.3209351167678833


valid-loss: 0.5220563411712646: 100%|██████████| 2/2 [00:00<00:00,  3.08it/s]


train_mean_scores 0.8341668083595037
valid_mean_scores 0.3055884116087696


train-loss: 0.3347017543315887: 100%|██████████| 500/500 [01:48<00:00,  4.63it/s] 


Train-Epoch 7 Loss: 0.3347017543315887


valid-loss: 0.7350813448429108: 100%|██████████| 2/2 [00:00<00:00,  3.01it/s]


train_mean_scores 0.8137772488648227
valid_mean_scores 0.3528339417064978


train-loss: 0.2886935865879059: 100%|██████████| 500/500 [01:42<00:00,  4.86it/s] 


Train-Epoch 8 Loss: 0.2886935865879059


valid-loss: 0.6589553356170654: 100%|██████████| 2/2 [00:00<00:00,  3.25it/s]


train_mean_scores 0.7382213458002356
valid_mean_scores 0.3579868174620685


train-loss: 0.29180422580242155: 100%|██████████| 500/500 [01:42<00:00,  4.87it/s]


Train-Epoch 9 Loss: 0.29180422580242155


valid-loss: 0.514196515083313: 100%|██████████| 2/2 [00:00<00:00,  3.05it/s] 


train_mean_scores 0.85179358613197
valid_mean_scores 0.3902499223052711


train-loss: 0.2309711172580719: 100%|██████████| 500/500 [01:44<00:00,  4.77it/s] 


Train-Epoch 10 Loss: 0.2309711172580719


valid-loss: 0.4869721531867981: 100%|██████████| 2/2 [00:00<00:00,  2.98it/s]


train_mean_scores 0.8827629643329031
valid_mean_scores 0.43380435821857694


train-loss: 0.1413713308572769: 100%|██████████| 500/500 [01:48<00:00,  4.61it/s] 


Train-Epoch 11 Loss: 0.1413713308572769


valid-loss: 0.5017892122268677: 100%|██████████| 2/2 [00:00<00:00,  2.91it/s]


train_mean_scores 0.8953899430383387
valid_mean_scores 0.39006621428157956


train-loss: 0.07938541138172149: 100%|██████████| 500/500 [01:45<00:00,  4.75it/s]


Train-Epoch 12 Loss: 0.07938541138172149


valid-loss: 0.3906286954879761: 100%|██████████| 2/2 [00:00<00:00,  2.88it/s] 


train_mean_scores 0.9182221200724163
valid_mean_scores 0.4314873407808785


train-loss: 0.06799339652061462: 100%|██████████| 500/500 [01:46<00:00,  4.70it/s]


Train-Epoch 13 Loss: 0.06799339652061462


valid-loss: 0.4366742670536041: 100%|██████████| 2/2 [00:00<00:00,  2.92it/s]


train_mean_scores 0.8461414127948756
valid_mean_scores 0.4286667746294582


train-loss: 0.06112874817848206: 100%|██████████| 500/500 [01:44<00:00,  4.78it/s] 


Train-Epoch 14 Loss: 0.06112874817848206


valid-loss: 0.5063241422176361: 100%|██████████| 2/2 [00:00<00:00,  3.05it/s]


train_mean_scores 0.8823111483257678
valid_mean_scores 0.39441796485651737


train-loss: 0.06073605787754059: 100%|██████████| 500/500 [01:43<00:00,  4.81it/s] 


Train-Epoch 15 Loss: 0.06073605787754059


valid-loss: 0.49334457516670227: 100%|██████████| 2/2 [00:00<00:00,  3.06it/s]


train_mean_scores 0.865581285424633
valid_mean_scores 0.4125096267783527


train-loss: 0.05986571955680847: 100%|██████████| 500/500 [01:43<00:00,  4.83it/s] 


Train-Epoch 16 Loss: 0.05986571955680847


valid-loss: 0.4942372441291809: 100%|██████████| 2/2 [00:00<00:00,  2.80it/s]


train_mean_scores 0.8642183140452712
valid_mean_scores 0.39869112453608824


train-loss: 0.051771758198738096: 100%|██████████| 500/500 [01:42<00:00,  4.89it/s]


Train-Epoch 17 Loss: 0.051771758198738096


valid-loss: 0.43918293714523315: 100%|██████████| 2/2 [00:00<00:00,  2.91it/s]


train_mean_scores 0.8187610146023049
valid_mean_scores 0.4056761615223155


train-loss: 0.04739393723011017: 100%|██████████| 500/500 [01:45<00:00,  4.76it/s] 


Train-Epoch 18 Loss: 0.04739393723011017


valid-loss: 0.5849719047546387: 100%|██████████| 2/2 [00:00<00:00,  2.73it/s]


train_mean_scores 0.8763223486810379
valid_mean_scores 0.4275444792007422


train-loss: 0.04855063033103943: 100%|██████████| 500/500 [01:44<00:00,  4.78it/s] 


Train-Epoch 19 Loss: 0.04855063033103943


valid-loss: 0.5822555124759674: 100%|██████████| 2/2 [00:00<00:00,  2.85it/s]


train_mean_scores 0.8944256125927176
valid_mean_scores 0.4070483192381875


train-loss: 0.044132370591163636: 100%|██████████| 500/500 [01:44<00:00,  4.80it/s]


Train-Epoch 20 Loss: 0.044132370591163636


valid-loss: 0.49203333258628845: 100%|██████████| 2/2 [00:00<00:00,  2.82it/s]


train_mean_scores 0.8717993982871857
valid_mean_scores 0.45336838889379094


train-loss: 0.04575019383430481: 100%|██████████| 500/500 [01:44<00:00,  4.78it/s] 


Train-Epoch 21 Loss: 0.04575019383430481


valid-loss: 0.4415571391582489: 100%|██████████| 2/2 [00:00<00:00,  2.82it/s] 


train_mean_scores 0.8946774923835875
valid_mean_scores 0.41629815954342897


train-loss: 0.04305964875221253: 100%|██████████| 500/500 [01:52<00:00,  4.46it/s] 


Train-Epoch 22 Loss: 0.04305964875221253


valid-loss: 0.40283089876174927: 100%|██████████| 2/2 [00:00<00:00,  2.80it/s]


train_mean_scores 0.8954227239187397
valid_mean_scores 0.4444002165636186


train-loss: 0.05155762243270874: 100%|██████████| 500/500 [01:43<00:00,  4.85it/s] 


Train-Epoch 23 Loss: 0.05155762243270874


valid-loss: 0.24846705794334412: 100%|██████████| 2/2 [00:00<00:00,  2.83it/s] 


train_mean_scores 0.8896207890269601
valid_mean_scores 0.414933024073044


train-loss: 0.04288215839862824: 100%|██████████| 500/500 [01:49<00:00,  4.58it/s] 


Train-Epoch 24 Loss: 0.04288215839862824


valid-loss: 0.4581364393234253: 100%|██████████| 2/2 [00:00<00:00,  2.41it/s]
