In [1]:
import os
import zarr
import timm
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import sys
import torch

# import torchvision.transforms.functional as F
import random

warnings.filterwarnings("ignore")
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import Unet3D
from src.utils import save_images, PadToSize
from src.metric import (
    score,
    create_cls_pos,
    create_cls_pos_sikii,
    create_df,
    SegmentationLoss,
    DiceLoss,
)
from src.inference import inference, inference2pos, create_gt_df
from metric import visualize_epoch_results

In [2]:
train_dataset = EziiDataset(
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.train_zarr_types,
    train=True,
    augmentation=True,
    slice=True,
    pre_read=True,
)

# train_nshuffle_dataset = EziiDataset(
#     exp_names=CFG.train_exp_names,
#     base_dir="../../inputs/train/",
#     particles_name=CFG.particles_name,
#     resolution=CFG.resolution,
#     zarr_type=CFG.train_zarr_types,
#     augmentation=False,
#     train=True,
# )

valid_dataset = EziiDataset(
    exp_names=CFG.valid_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.valid_zarr_types,
    augmentation=False,
    train=True,
    slice=True,
    pre_read=True,
)

from tqdm import tqdm

train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=CFG.num_workers,
)
# train_nshuffle_loader = DataLoader(
#     train_nshuffle_dataset,
#     batch_size=1,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
valid_loader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=CFG.num_workers,
)

for data in tqdm(train_loader):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

normalized_tomogram.shape

100%|██████████| 320/320 [00:59<00:00,  5.37it/s]
100%|██████████| 2/2 [00:00<00:00,  2.78it/s]
  0%|          | 0/160 [00:01<?, ?it/s]


torch.Size([2, 16, 315, 315])

In [3]:
# from tqdm import tqdm

# train_loader = DataLoader(
#     train_dataset,
#     batch_size=CFG.batch_size,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
# # train_nshuffle_loader = DataLoader(
# #     train_nshuffle_dataset,
# #     batch_size=1,
# #     shuffle=True,
# #     drop_last=True,
# #     pin_memory=True,
# #     num_workers=CFG.num_workers,
# # )
# valid_loader = DataLoader(
#     valid_dataset,
#     batch_size=1,
#     shuffle=False,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )

# for data in tqdm(train_loader):
#     normalized_tomogram = data["normalized_tomogram"]
#     segmentation_map = data["segmentation_map"]
#     break

# normalized_tomogram.shape

In [4]:
encoder = timm.create_model(
    model_name=CFG.model_name,
    pretrained=True,
    in_chans=3,
    num_classes=0,
    global_pool="",
    features_only=True,
)
model = Unet3D(encoder=encoder).to("cuda")
model.load_state_dict(torch.load("./pretrained_model.pth"))

<All keys matched successfully>

In [5]:
# # "encoder"と名のつくパラメータは学習しない
# for layer, param in model.named_parameters():
#     if "encoder" in layer:
#         param.requires_grad = False

In [6]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# サンプルデータ
num_classes = len(CFG.particles_name)  # クラス数
colors = plt.cm.tab10(
    np.arange(len(CFG.particles_name))
)  # "tab10" カラーマップから色を取得

# ListedColormap を作成
class_colormap = ListedColormap(colors)


# カラーバー付きプロット
def plot_with_colormap(data, title, original_tomogram):
    masked_data = np.ma.masked_where(data <= 0, data)  # クラス0をマスク
    plt.imshow(original_tomogram, cmap="gray")
    im = plt.imshow(masked_data, cmap=class_colormap)
    plt.title(title)
    plt.axis("off")
    return im

In [7]:
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.Adam(
    model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
)
# criterion = nn.CrossEntropyLoss(
#     #  weight=torch.tensor([2.0, 32, 32, 32, 32, 32, 32]).to("cuda")
# )
criterion = DiceLoss()
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=CFG.epochs * len(train_loader),
    # * batch_size,
)
seg_loss = SegmentationLoss(criterion)
padf = PadToSize(CFG.resolution)

In [8]:
# b, c, d, h, w = CFG.batch_size, 1, 96, 320, 320

In [9]:
def preprocess_tensor(tensor):
    batch_size, depth, height, width = tensor.shape
    tensor = tensor.unsqueeze(2)  # (b, d, h, w) -> (b, d, 1, h, w)
    return tensor

In [10]:
padf = PadToSize(CFG.resolution)
padf(normalized_tomogram).shape

torch.Size([2, 16, 320, 320])

In [None]:
best_model = None
best_score = -100

for epoch in range(CFG.epochs):
    model.train()
    train_loss = []
    valid_loss = []
    tq = tqdm(train_loader)
    for data in train_loader:
        normalized_tomogram = data["normalized_tomogram"].cuda()
        segmentation_map = data["segmentation_map"].long().cuda()

        normalized_tomogram = padf(normalized_tomogram)
        segmentation_map = padf(segmentation_map)

        optimizer.zero_grad()
        pred = model(preprocess_tensor(normalized_tomogram))
        loss = seg_loss(pred, segmentation_map)
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss.append(loss.item())

        # 確率予測
        prob_pred = torch.softmax(pred, dim=1)
        tq.update()
        tq.set_description(f"train-loss: {np.mean(train_loss)}")
    tq.close()

    print(f"Train-Epoch {epoch} Loss: {np.mean(train_loss)}")

    tq = tqdm(valid_loader)
    for data in valid_loader:
        normalized_tomogram = data["normalized_tomogram"].cuda()
        segmentation_map = data["segmentation_map"].long().cuda()

        normalized_tomogram = padf(normalized_tomogram)
        segmentation_map = padf(segmentation_map)

        pred = model(preprocess_tensor(normalized_tomogram))
        loss = seg_loss(pred, segmentation_map)
        valid_loss.append(loss.item())

        # 確率予測
        prob_pred = torch.softmax(pred, dim=1)
        tq.update()
        tq.set_description(f"valid-loss: {np.mean(valid_loss)}")
    tq.close()

    # # ############### validation ################
    train_nshuffle_original_tomogram = defaultdict(list)
    train_nshuffle_pred_tomogram = defaultdict(list)
    train_nshuffle_gt_tomogram = defaultdict(list)

    valid_original_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)

    train_mean_scores = []
    valid_mean_scores = []

    # モデルの保存
    torch.save(model.state_dict(), "./pretrained_model.pth")

    # ############### validation ################
    train_nshuffle_original_tomogram = defaultdict(list)
    train_nshuffle_pred_tomogram = defaultdict(list)
    train_nshuffle_gt_tomogram = defaultdict(list)

    valid_original_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)

    train_mean_scores = []
    valid_mean_scores = []

    # for exp_name in tqdm(CFG.train_exp_names):
    for exp_name in CFG.train_exp_names[:5]:  # 5つのデータで試す
        # inferenced_array = inference(model, exp_name, train=False)
        inferenced_array, n_tomogram, segmentation_map = inference(
            model, exp_name, train=True
        )
        pred_df = inference2pos(
            pred_segmask=inferenced_array.argmax(0), exp_name=exp_name
        )
        base_dir = "../../inputs/train/overlay/ExperimentRuns/"
        gt_df = create_gt_df(base_dir, [exp_name])

        train_nshuffle_pred_tomogram[exp_name] = inferenced_array

        score_ = score(
            pred_df, gt_df, row_id_column_name="index", distance_multiplier=1.0, beta=4
        )
        train_mean_scores.append(score_)

    print("train_mean_scores", np.mean(train_mean_scores))

    for exp_name in CFG.valid_exp_names:
        inferenced_array, n_tomogram, segmentation_map = inference(
            model, exp_name, train=True
        )
        pred_df = inference2pos(
            pred_segmask=inferenced_array.argmax(0), exp_name=exp_name
        )
        base_dir = "../../inputs/train/overlay/ExperimentRuns/"
        gt_df = create_gt_df(base_dir, [exp_name])

        valid_pred_tomogram[exp_name] = inferenced_array

        score_ = score(
            pred_df, gt_df, row_id_column_name="index", distance_multiplier=1.0, beta=4
        )
        valid_mean_scores.append(score_)

    print("valid_mean_scores", np.mean(valid_mean_scores))

    if np.mean(valid_mean_scores) > best_score:
        best_score = np.mean(valid_mean_scores)
        best_model = model.state_dict()
        torch.save(best_model, f"./best_model.pth")

train-loss: 0.6851808901876211: 100%|██████████| 160/160 [00:28<00:00,  5.66it/s]


Train-Epoch 0 Loss: 0.6851808901876211


valid-loss: 0.6962358057498932: 100%|██████████| 2/2 [00:00<00:00,  5.91it/s]


train_mean_scores 0.27378830802200815
valid_mean_scores 0.18431041163218764


train-loss: 0.6620063908398152: 100%|██████████| 160/160 [00:27<00:00,  5.89it/s]


Train-Epoch 1 Loss: 0.6620063908398152


valid-loss: 0.5831620693206787: 100%|██████████| 2/2 [00:00<00:00,  4.70it/s]


train_mean_scores 0.2909719168174995
valid_mean_scores 0.17630802161184894


train-loss: 0.6503431484103203: 100%|██████████| 160/160 [00:27<00:00,  5.88it/s]


Train-Epoch 2 Loss: 0.6503431484103203


valid-loss: 0.6584068238735199: 100%|██████████| 2/2 [00:00<00:00,  4.60it/s]


train_mean_scores 0.413177536147729
valid_mean_scores 0.2714358091317792


train-loss: 0.641777117177844: 100%|██████████| 160/160 [00:27<00:00,  5.80it/s] 


Train-Epoch 3 Loss: 0.641777117177844


valid-loss: 0.5857354700565338: 100%|██████████| 2/2 [00:00<00:00,  4.35it/s]


train_mean_scores 0.38089948448953503
valid_mean_scores 0.3044104559188261


train-loss: 0.6460278309881687: 100%|██████████| 160/160 [00:27<00:00,  5.84it/s]


Train-Epoch 4 Loss: 0.6460278309881687


valid-loss: 0.5874885022640228: 100%|██████████| 2/2 [00:00<00:00,  4.40it/s]


train_mean_scores 0.4419841403148806
valid_mean_scores 0.3094361182814216


train-loss: 0.6269399546086788: 100%|██████████| 160/160 [00:27<00:00,  5.85it/s]


Train-Epoch 5 Loss: 0.6269399546086788


valid-loss: 0.8076568841934204: 100%|██████████| 2/2 [00:00<00:00,  4.37it/s]


train_mean_scores 0.35327628211892736
valid_mean_scores 0.23474530645313757


train-loss: 0.6313175696879625: 100%|██████████| 160/160 [00:27<00:00,  5.86it/s]


Train-Epoch 6 Loss: 0.6313175696879625


valid-loss: 0.7543404996395111: 100%|██████████| 2/2 [00:00<00:00,  4.15it/s]


train_mean_scores 0.3878271188690568
valid_mean_scores 0.2650371019882693


train-loss: 0.626746841520071: 100%|██████████| 160/160 [00:27<00:00,  5.84it/s] 


Train-Epoch 7 Loss: 0.626746841520071


valid-loss: 0.7125252783298492: 100%|██████████| 2/2 [00:00<00:00,  4.17it/s]


train_mean_scores 0.36607471581480444
valid_mean_scores 0.17852933776053742


train-loss: 0.6243808541446925: 100%|██████████| 160/160 [00:27<00:00,  5.89it/s]


Train-Epoch 8 Loss: 0.6243808541446925


valid-loss: 0.6196556985378265: 100%|██████████| 2/2 [00:00<00:00,  3.98it/s]


train_mean_scores 0.43505534459901635
valid_mean_scores 0.2755592905490136


train-loss: 0.6164589840918779: 100%|██████████| 160/160 [00:27<00:00,  5.88it/s]


Train-Epoch 9 Loss: 0.6164589840918779


valid-loss: 0.702509880065918: 100%|██████████| 2/2 [00:00<00:00,  4.13it/s]


train_mean_scores 0.468114774754151
valid_mean_scores 0.267991914912889


train-loss: 0.6169571381062269: 100%|██████████| 160/160 [00:27<00:00,  5.88it/s]


Train-Epoch 10 Loss: 0.6169571381062269


valid-loss: 0.6272159516811371: 100%|██████████| 2/2 [00:00<00:00,  3.95it/s]


train_mean_scores 0.3815324864018538
valid_mean_scores 0.2838592789537585


train-loss: 0.6126565534621478: 100%|██████████| 160/160 [00:27<00:00,  5.86it/s]


Train-Epoch 11 Loss: 0.6126565534621478


valid-loss: 0.7232973277568817: 100%|██████████| 2/2 [00:00<00:00,  3.99it/s]


train_mean_scores 0.4497109411980954
valid_mean_scores 0.2505639468236194


train-loss: 0.6019140668213367: 100%|██████████| 160/160 [00:27<00:00,  5.88it/s]


Train-Epoch 12 Loss: 0.6019140668213367


valid-loss: 0.8209357857704163: 100%|██████████| 2/2 [00:00<00:00,  3.97it/s]


train_mean_scores 0.45557478106324006
valid_mean_scores 0.2417485390573002


train-loss: 0.611254358664155: 100%|██████████| 160/160 [00:27<00:00,  5.87it/s] 


Train-Epoch 13 Loss: 0.611254358664155


valid-loss: 0.5592108070850372: 100%|██████████| 2/2 [00:00<00:00,  4.15it/s]


train_mean_scores 0.4061351612135615
valid_mean_scores 0.27694564441150316


train-loss: 0.5896763946861029: 100%|██████████| 160/160 [00:27<00:00,  5.87it/s]


Train-Epoch 14 Loss: 0.5896763946861029


valid-loss: 0.6091294884681702: 100%|██████████| 2/2 [00:00<00:00,  2.81it/s]


train_mean_scores 0.47590192319931746
valid_mean_scores 0.24182470971860864


train-loss: 0.6011559259146452: 100%|██████████| 160/160 [00:27<00:00,  5.85it/s]


Train-Epoch 15 Loss: 0.6011559259146452


valid-loss: 0.5358023047447205: 100%|██████████| 2/2 [00:00<00:00,  4.02it/s]


train_mean_scores 0.3594451926594888
valid_mean_scores 0.23305494909998603


train-loss: 0.6066225290298461: 100%|██████████| 160/160 [00:28<00:00,  5.55it/s]


Train-Epoch 16 Loss: 0.6066225290298461


valid-loss: 0.6519322395324707: 100%|██████████| 2/2 [00:00<00:00,  3.77it/s]


train_mean_scores 0.4104861904982734
valid_mean_scores 0.2662498443219685


train-loss: 0.5865477509796619: 100%|██████████| 160/160 [00:28<00:00,  5.58it/s]


Train-Epoch 17 Loss: 0.5865477509796619


valid-loss: 0.6764684021472931: 100%|██████████| 2/2 [00:00<00:00,  3.65it/s]


train_mean_scores 0.4460090162018968
valid_mean_scores 0.2398921796596154


train-loss: 0.5882022596895695: 100%|██████████| 160/160 [00:28<00:00,  5.65it/s]


Train-Epoch 18 Loss: 0.5882022596895695


valid-loss: 0.6447301506996155: 100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


train_mean_scores 0.4183113761424676
valid_mean_scores 0.2521186962837463


train-loss: 0.5922363255172968: 100%|██████████| 160/160 [00:28<00:00,  5.66it/s]


Train-Epoch 19 Loss: 0.5922363255172968


valid-loss: 0.8068088591098785: 100%|██████████| 2/2 [00:00<00:00,  3.49it/s]


train_mean_scores 0.4525745765142123
valid_mean_scores 0.281620776026448


train-loss: 0.5900773901492358: 100%|██████████| 160/160 [00:28<00:00,  5.70it/s]


Train-Epoch 20 Loss: 0.5900773901492358


valid-loss: 0.662912130355835: 100%|██████████| 2/2 [00:00<00:00,  3.54it/s] 


train_mean_scores 0.39113149781804557
valid_mean_scores 0.23237735503065096


train-loss: 0.610027626529336: 100%|██████████| 160/160 [00:28<00:00,  5.64it/s] 


Train-Epoch 21 Loss: 0.610027626529336


valid-loss: 0.7234686613082886: 100%|██████████| 2/2 [00:00<00:00,  3.67it/s]


train_mean_scores 0.4086337794609494
valid_mean_scores 0.31785401888779996


train-loss: 0.6162837032228708: 100%|██████████| 160/160 [00:28<00:00,  5.62it/s]


Train-Epoch 22 Loss: 0.6162837032228708


valid-loss: 0.6718621253967285: 100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


train_mean_scores 0.36324485333557266
valid_mean_scores 0.26672896000176793


train-loss: 0.5905090026557446: 100%|██████████| 160/160 [00:28<00:00,  5.64it/s]


Train-Epoch 23 Loss: 0.5905090026557446


valid-loss: 0.5924438536167145: 100%|██████████| 2/2 [00:00<00:00,  3.24it/s]


train_mean_scores 0.3918294704674851
valid_mean_scores 0.32806256055585825


train-loss: 0.5923319164663553: 100%|██████████| 160/160 [00:28<00:00,  5.68it/s]


Train-Epoch 24 Loss: 0.5923319164663553


valid-loss: 0.6769067645072937: 100%|██████████| 2/2 [00:00<00:00,  3.48it/s]


train_mean_scores 0.4078045515414779
valid_mean_scores 0.2893914052576896


train-loss: 0.5796976141631603: 100%|██████████| 160/160 [00:28<00:00,  5.71it/s]


Train-Epoch 25 Loss: 0.5796976141631603


valid-loss: 0.6296872794628143: 100%|██████████| 2/2 [00:00<00:00,  3.51it/s]


train_mean_scores 0.41334761386808394
valid_mean_scores 0.29020037930273646


train-loss: 0.5825201295316219: 100%|██████████| 160/160 [00:27<00:00,  5.72it/s]


Train-Epoch 26 Loss: 0.5825201295316219


valid-loss: 0.6584209501743317: 100%|██████████| 2/2 [00:00<00:00,  3.51it/s]


train_mean_scores 0.37549759255389514
valid_mean_scores 0.31460538455597853


train-loss: 0.5764151975512505: 100%|██████████| 160/160 [00:28<00:00,  5.52it/s]


Train-Epoch 27 Loss: 0.5764151975512505


valid-loss: 0.6779402792453766: 100%|██████████| 2/2 [00:00<00:00,  3.75it/s]


train_mean_scores 0.4660814378389194
valid_mean_scores 0.2882971950215917


train-loss: 0.5762191627174615: 100%|██████████| 160/160 [00:27<00:00,  5.72it/s]


Train-Epoch 28 Loss: 0.5762191627174615


valid-loss: 0.7152025401592255: 100%|██████████| 2/2 [00:00<00:00,  3.77it/s]


train_mean_scores 0.3178992190820953
valid_mean_scores 0.2623939748830546


train-loss: 0.5851696975529194: 100%|██████████| 160/160 [00:27<00:00,  5.85it/s]


Train-Epoch 29 Loss: 0.5851696975529194


valid-loss: 0.7009466886520386: 100%|██████████| 2/2 [00:00<00:00,  3.64it/s]


train_mean_scores 0.3895947596336544
valid_mean_scores 0.28095387107692577


train-loss: 0.5681707613170147: 100%|██████████| 160/160 [00:27<00:00,  5.87it/s]


Train-Epoch 30 Loss: 0.5681707613170147


valid-loss: 0.6967669725418091: 100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


train_mean_scores 0.46148443006532547
valid_mean_scores 0.34108619777023463


train-loss: 0.5693172518163919: 100%|██████████| 160/160 [00:27<00:00,  5.77it/s]


Train-Epoch 31 Loss: 0.5693172518163919


valid-loss: 0.6021804809570312: 100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


train_mean_scores 0.44429339715859034
valid_mean_scores 0.37156751370036845


train-loss: 0.5551506791263818: 100%|██████████| 160/160 [00:28<00:00,  5.70it/s]


Train-Epoch 32 Loss: 0.5551506791263818


valid-loss: 0.6431068778038025: 100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


train_mean_scores 0.4699852533912471
valid_mean_scores 0.37360419913158527


train-loss: 0.5475201938301325: 100%|██████████| 160/160 [00:27<00:00,  5.83it/s]


Train-Epoch 33 Loss: 0.5475201938301325


valid-loss: 0.660336434841156: 100%|██████████| 2/2 [00:00<00:00,  3.55it/s] 


train_mean_scores 0.4950616151453101
valid_mean_scores 0.3081555227752411


train-loss: 0.5457527417689562: 100%|██████████| 160/160 [00:27<00:00,  5.78it/s]


Train-Epoch 34 Loss: 0.5457527417689562


valid-loss: 0.5883988440036774: 100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


train_mean_scores 0.48178811815463957
valid_mean_scores 0.29598249809793


train-loss: 0.5383842434734106: 100%|██████████| 160/160 [00:27<00:00,  5.83it/s]


Train-Epoch 35 Loss: 0.5383842434734106


valid-loss: 0.7058611214160919: 100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


train_mean_scores 0.4074873186913889
valid_mean_scores 0.31240801912337046


train-loss: 0.5392710529267788: 100%|██████████| 160/160 [00:27<00:00,  5.81it/s]


Train-Epoch 36 Loss: 0.5392710529267788


valid-loss: 0.5891002118587494: 100%|██████████| 2/2 [00:01<00:00,  1.91it/s]


train_mean_scores 0.4437703093245798
valid_mean_scores 0.3491400201671331


train-loss: 0.5395819187164307: 100%|██████████| 160/160 [00:27<00:00,  5.84it/s]


Train-Epoch 37 Loss: 0.5395819187164307


valid-loss: 0.5709777474403381: 100%|██████████| 2/2 [00:00<00:00,  3.24it/s]
