In [1]:
import os
import zarr
import timm
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import sys
import torch

# import torchvision.transforms.functional as F
import random

warnings.filterwarnings("ignore")
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import Unet3D
from src.utils import save_images, PadToSize
from src.metric import (
    score,
    create_cls_pos,
    create_cls_pos_sikii,
    create_df,
    SegmentationLoss,
    DiceLoss,
)
from src.inference import inference, inference2pos, create_gt_df
from metric import visualize_epoch_results

In [2]:
train_dataset = EziiDataset(
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.train_zarr_types,
    train=True,
    augmentation=True,
    slice=True,
    pre_read=True,
)

# train_nshuffle_dataset = EziiDataset(
#     exp_names=CFG.train_exp_names,
#     base_dir="../../inputs/train/",
#     particles_name=CFG.particles_name,
#     resolution=CFG.resolution,
#     zarr_type=CFG.train_zarr_types,
#     augmentation=False,
#     train=True,
# )

valid_dataset = EziiDataset(
    exp_names=CFG.valid_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.valid_zarr_types,
    augmentation=False,
    train=True,
    slice=True,
    pre_read=True,
)

from tqdm import tqdm

train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=CFG.num_workers,
)
# train_nshuffle_loader = DataLoader(
#     train_nshuffle_dataset,
#     batch_size=1,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
valid_loader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=CFG.num_workers,
)

for data in tqdm(train_loader):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

normalized_tomogram.shape

100%|██████████| 1000/1000 [00:14<00:00, 70.99it/s]
100%|██████████| 2/2 [00:05<00:00,  2.72s/it]
  0%|          | 0/500 [00:01<?, ?it/s]


torch.Size([2, 16, 630, 630])

In [3]:
# from tqdm import tqdm

# train_loader = DataLoader(
#     train_dataset,
#     batch_size=CFG.batch_size,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
# # train_nshuffle_loader = DataLoader(
# #     train_nshuffle_dataset,
# #     batch_size=1,
# #     shuffle=True,
# #     drop_last=True,
# #     pin_memory=True,
# #     num_workers=CFG.num_workers,
# # )
# valid_loader = DataLoader(
#     valid_dataset,
#     batch_size=1,
#     shuffle=False,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )

# for data in tqdm(train_loader):
#     normalized_tomogram = data["normalized_tomogram"]
#     segmentation_map = data["segmentation_map"]
#     break

# normalized_tomogram.shape

In [4]:
encoder = timm.create_model(
    model_name=CFG.model_name,
    pretrained=True,
    in_chans=3,
    num_classes=0,
    global_pool="",
    features_only=True,
)
model = Unet3D(encoder=encoder).to("cuda")
# model.load_state_dict(torch.load("./pretrained_model.pth"))

In [5]:
# # "encoder"と名のつくパラメータは学習しない
# for layer, param in model.named_parameters():
#     if "encoder" in layer:
#         param.requires_grad = False

In [6]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# サンプルデータ
num_classes = len(CFG.particles_name)  # クラス数
colors = plt.cm.tab10(
    np.arange(len(CFG.particles_name))
)  # "tab10" カラーマップから色を取得

# ListedColormap を作成
class_colormap = ListedColormap(colors)


# カラーバー付きプロット
def plot_with_colormap(data, title, original_tomogram):
    masked_data = np.ma.masked_where(data <= 0, data)  # クラス0をマスク
    plt.imshow(original_tomogram, cmap="gray")
    im = plt.imshow(masked_data, cmap=class_colormap)
    plt.title(title)
    plt.axis("off")
    return im

In [7]:
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.Adam(
    model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
)
# criterion = nn.CrossEntropyLoss(
#     #  weight=torch.tensor([2.0, 32, 32, 32, 32, 32, 32]).to("cuda")
# )
criterion = DiceLoss()
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=CFG.epochs * len(train_loader),
    # * batch_size,
)
seg_loss = SegmentationLoss(criterion)
padf = PadToSize(CFG.resolution)

In [8]:
# b, c, d, h, w = CFG.batch_size, 1, 96, 320, 320

In [9]:
def preprocess_tensor(tensor):
    batch_size, depth, height, width = tensor.shape
    tensor = tensor.unsqueeze(2)  # (b, d, h, w) -> (b, d, 1, h, w)
    return tensor

In [10]:
padf = PadToSize(CFG.resolution)
padf(normalized_tomogram).shape

torch.Size([2, 16, 640, 640])

In [None]:
best_model = None
best_score = -100

for epoch in range(CFG.epochs):
    model.train()
    train_loss = []
    valid_loss = []
    tq = tqdm(train_loader)
    for data in train_loader:
        normalized_tomogram = data["normalized_tomogram"].cuda()
        segmentation_map = data["segmentation_map"].long().cuda()

        normalized_tomogram = padf(normalized_tomogram)
        segmentation_map = padf(segmentation_map)

        optimizer.zero_grad()
        pred = model(preprocess_tensor(normalized_tomogram))
        loss = seg_loss(pred, segmentation_map)
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss.append(loss.item())

        # 確率予測
        prob_pred = torch.softmax(pred, dim=1)
        tq.update()
        tq.set_description(f"train-loss: {np.mean(train_loss)}")
    tq.close()

    print(f"Train-Epoch {epoch} Loss: {np.mean(train_loss)}")

    tq = tqdm(valid_loader)
    for data in valid_loader:
        normalized_tomogram = data["normalized_tomogram"].cuda()
        segmentation_map = data["segmentation_map"].long().cuda()

        normalized_tomogram = padf(normalized_tomogram)
        segmentation_map = padf(segmentation_map)

        pred = model(preprocess_tensor(normalized_tomogram))
        loss = seg_loss(pred, segmentation_map)
        valid_loss.append(loss.item())

        # 確率予測
        prob_pred = torch.softmax(pred, dim=1)
        tq.update()
        tq.set_description(f"valid-loss: {np.mean(valid_loss)}")
    tq.close()

    # # ############### validation ################
    train_nshuffle_original_tomogram = defaultdict(list)
    train_nshuffle_pred_tomogram = defaultdict(list)
    train_nshuffle_gt_tomogram = defaultdict(list)

    valid_original_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)

    train_mean_scores = []
    valid_mean_scores = []

    # モデルの保存
    torch.save(model.state_dict(), "./pretrained_model.pth")

    # ############### validation ################
    train_nshuffle_original_tomogram = defaultdict(list)
    train_nshuffle_pred_tomogram = defaultdict(list)
    train_nshuffle_gt_tomogram = defaultdict(list)

    valid_original_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)

    train_mean_scores = []
    valid_mean_scores = []

    # for exp_name in tqdm(CFG.train_exp_names):
    for exp_name in CFG.train_exp_names[:5]:  # 5つのデータで試す
        # inferenced_array = inference(model, exp_name, train=False)
        inferenced_array, n_tomogram, segmentation_map = inference(
            model, exp_name, train=True
        )
        pred_df = inference2pos(
            pred_segmask=inferenced_array.argmax(0), exp_name=exp_name
        )
        base_dir = "../../inputs/train/overlay/ExperimentRuns/"
        gt_df = create_gt_df(base_dir, [exp_name])

        train_nshuffle_pred_tomogram[exp_name] = inferenced_array

        score_ = score(
            pred_df, gt_df, row_id_column_name="index", distance_multiplier=1.0, beta=4
        )
        train_mean_scores.append(score_)

    print("train_mean_scores", np.mean(train_mean_scores))

    for exp_name in CFG.valid_exp_names:
        inferenced_array, n_tomogram, segmentation_map = inference(
            model, exp_name, train=True
        )
        pred_df = inference2pos(
            pred_segmask=inferenced_array.argmax(0), exp_name=exp_name
        )
        base_dir = "../../inputs/train/overlay/ExperimentRuns/"
        gt_df = create_gt_df(base_dir, [exp_name])

        valid_pred_tomogram[exp_name] = inferenced_array

        score_ = score(
            pred_df, gt_df, row_id_column_name="index", distance_multiplier=1.0, beta=4
        )
        valid_mean_scores.append(score_)

    print("valid_mean_scores", np.mean(valid_mean_scores))

    if np.mean(valid_mean_scores) > best_score:
        best_score = np.mean(valid_mean_scores)
        best_model = model.state_dict()
        torch.save(best_model, f"./best_model.pth")

train-loss: 0.8483525946140289: 100%|██████████| 500/500 [06:58<00:00,  1.20it/s]


Train-Epoch 0 Loss: 0.8483525946140289


valid-loss: 0.8119079172611237: 100%|██████████| 2/2 [00:00<00:00,  4.19it/s]


train_mean_scores 0.13730131035477208
valid_mean_scores 0.09112389015854651


train-loss: 0.617191806435585: 100%|██████████| 500/500 [4:48:19<00:00, 34.60s/it]    


Train-Epoch 1 Loss: 0.617191806435585


valid-loss: 0.7200845777988434: 100%|██████████| 2/2 [00:03<00:00,  1.73s/it]


train_mean_scores 0.35795503439246007
valid_mean_scores 0.20938587259521776


train-loss: 0.5339929614067077: 100%|██████████| 500/500 [1:06:26<00:00,  7.97s/it] 


Train-Epoch 2 Loss: 0.5339929614067077


valid-loss: 0.5607078671455383: 100%|██████████| 2/2 [00:02<00:00,  1.34s/it]


train_mean_scores 0.5145757531650982
valid_mean_scores 0.3032311088948433


train-loss: 0.5103123804330826: 100%|██████████| 500/500 [9:09:52<00:00, 65.99s/it]    


Train-Epoch 3 Loss: 0.5103123804330826


valid-loss: 0.8021272420883179: 100%|██████████| 2/2 [01:24<00:00, 42.16s/it]


train_mean_scores 0.5269588702537318
valid_mean_scores 0.29876999202066035


  0%|          | 0/500 [00:00<?, ?it/s]