In [1]:
import os
import zarr
import timm
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import sys
import torch

# import torchvision.transforms.functional as F
import random

warnings.filterwarnings("ignore")
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import Unet3D
from src.utils import save_images, PadToSize
from src.metric import (
    score,
    create_cls_pos,
    create_cls_pos_sikii,
    create_df,
    SegmentationLoss,
    DiceLoss,
)
from metric import visualize_epoch_results

In [2]:
train_dataset = EziiDataset(
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.train_zarr_types,
    train=True,
    augmentation=True,
    slice=True,
    pre_read=True,
)

# train_nshuffle_dataset = EziiDataset(
#     exp_names=CFG.train_exp_names,
#     base_dir="../../inputs/train/",
#     particles_name=CFG.particles_name,
#     resolution=CFG.resolution,
#     zarr_type=CFG.train_zarr_types,
#     augmentation=False,
#     train=True,
# )

# valid_dataset = EziiDataset(
#     exp_names=CFG.valid_exp_names,
#     base_dir="../../inputs/train/",
#     particles_name=CFG.particles_name,
#     resolution=CFG.resolution,
#     zarr_type=CFG.valid_zarr_types,
#     augmentation=False,
#     train=True,
# )

from tqdm import tqdm

train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=CFG.num_workers,
)
# train_nshuffle_loader = DataLoader(
#     train_nshuffle_dataset,
#     batch_size=1,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
# valid_loader = DataLoader(
#     valid_dataset,
#     batch_size=1,
#     shuffle=False,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )

for data in tqdm(train_loader):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

normalized_tomogram.shape

100%|██████████| 250/250 [00:02<00:00, 103.45it/s]
  0%|          | 0/125 [00:01<?, ?it/s]


torch.Size([2, 16, 315, 315])

In [3]:
from tqdm import tqdm

train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=CFG.num_workers,
)
# train_nshuffle_loader = DataLoader(
#     train_nshuffle_dataset,
#     batch_size=1,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
# valid_loader = DataLoader(
#     valid_dataset,
#     batch_size=1,
#     shuffle=False,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )

for data in tqdm(train_loader):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

normalized_tomogram.shape

  0%|          | 0/125 [00:00<?, ?it/s]


torch.Size([2, 16, 315, 315])

In [4]:
encoder = timm.create_model(
    model_name=CFG.model_name,
    pretrained=True,
    in_chans=3,
    num_classes=0,
    global_pool="",
    features_only=True,
)
model = Unet3D(encoder=encoder).to("cuda")
model.load_state_dict(torch.load("./pretrained_model.pth"))

<All keys matched successfully>

In [5]:
# # "encoder"と名のつくパラメータは学習しない
# for layer, param in model.named_parameters():
#     if "encoder" in layer:
#         param.requires_grad = False

In [6]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# サンプルデータ
num_classes = len(CFG.particles_name)  # クラス数
colors = plt.cm.tab10(
    np.arange(len(CFG.particles_name))
)  # "tab10" カラーマップから色を取得

# ListedColormap を作成
class_colormap = ListedColormap(colors)


# カラーバー付きプロット
def plot_with_colormap(data, title, original_tomogram):
    masked_data = np.ma.masked_where(data <= 0, data)  # クラス0をマスク
    plt.imshow(original_tomogram, cmap="gray")
    im = plt.imshow(masked_data, cmap=class_colormap)
    plt.title(title)
    plt.axis("off")
    return im

In [7]:
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.Adam(
    model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
)
# criterion = nn.CrossEntropyLoss(
#     #  weight=torch.tensor([2.0, 32, 32, 32, 32, 32, 32]).to("cuda")
# )
criterion = DiceLoss()
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=CFG.epochs * len(train_loader),
    # * batch_size,
)
seg_loss = SegmentationLoss(criterion)
padf = PadToSize(CFG.resolution)

In [8]:
# b, c, d, h, w = CFG.batch_size, 1, 96, 320, 320

In [9]:
def preprocess_tensor(tensor):
    batch_size, depth, height, width = tensor.shape
    tensor = tensor.unsqueeze(2)  # (b, d, h, w) -> (b, d, 1, h, w)
    return tensor

In [10]:
padf = PadToSize(CFG.resolution)
padf(normalized_tomogram).shape

torch.Size([2, 16, 320, 320])

In [11]:
def last_padding(tomogram, slice_size):
    # tomogram: (tensor)
    b, d, h, w = tomogram.shape
    last_padding = slice_size - d % slice_size
    if last_padding == slice_size:
        return tomogram
    else:
        return torch.cat(
            [tomogram, torch.zeros(b, last_padding, h, w).to(tomogram.device)], dim=1
        )


def inference(model, exp_name, train=True):
    dataset = EziiDataset(
        exp_names=[exp_name],
        base_dir="../../inputs/train/",
        particles_name=CFG.particles_name,
        resolution=CFG.resolution,
        zarr_type=["denoised"],
        train=train,
        slice=False,
    )
    res_array = CFG.original_img_shape[CFG.resolution]
    pred_array = np.zeros(
        (len(CFG.particles_name) + 1, res_array[0], res_array[1], res_array[2])
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=True)
    model.eval()
    # tq = tqdm(loader)
    for data in loader:  # 実験データ1つを取り出す
        for i in range(0, data["normalized_tomogram"].shape[1], CFG.slice_):
            normalized_tomogram = data["normalized_tomogram"][:, i : i + CFG.slice_]
            normalized_tomogram = last_padding(normalized_tomogram, CFG.slice_)
            normalized_tomogram = padf(normalized_tomogram)
            normalized_tomogram = preprocess_tensor(normalized_tomogram).to("cuda")
            pred = model(normalized_tomogram)
            prob_pred = (
                torch.softmax(pred, dim=1).detach().cpu().numpy()
            )  # torch.Size([1, 7, 32, 320, 320])
            range_ = min(i + CFG.slice_, res_array[0])
            hw_pad_diff = prob_pred.shape[-1] - res_array[-1]

            if i >= res_array[0]:
                continue

            if range_ == res_array[0]:
                pred_array[:, i:range_] += prob_pred[
                    0, :, : res_array[0] - i, :-hw_pad_diff, :-hw_pad_diff
                ]
            else:
                pred_array[:, i:range_] += prob_pred[
                    0, :, :range_, :-hw_pad_diff, :-hw_pad_diff
                ]
            # tq.update()
    # tq.close()

    return pred_array  # (7, 92, 315, 315)


inferenced_array = inference(model, "TS_6_6", train=False)

In [12]:
best_model = None
best_score = -100

for epoch in range(CFG.epochs):
    model.train()
    train_loss = []
    tq = tqdm(train_loader)
    for data in train_loader:
        normalized_tomogram = data["normalized_tomogram"].cuda()
        segmentation_map = data["segmentation_map"].long().cuda()

        normalized_tomogram = padf(normalized_tomogram)
        segmentation_map = padf(segmentation_map)

        optimizer.zero_grad()
        pred = model(preprocess_tensor(normalized_tomogram))
        loss = seg_loss(pred, segmentation_map)
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss.append(loss.item())

        # 確率予測
        prob_pred = torch.softmax(pred, dim=1)
        tq.update()
        tq.set_description(f"train-loss: {np.mean(train_loss)}")
    tq.close()

    print(f"Epoch {epoch} Loss: {np.mean(train_loss)}")

    # # ############### validation ################
    train_nshuffle_original_tomogram = defaultdict(list)
    train_nshuffle_pred_tomogram = defaultdict(list)
    train_nshuffle_gt_tomogram = defaultdict(list)

    valid_original_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)

    train_mean_scores = []
    valid_mean_scores = []

    # # for exp_name in tqdm(CFG.train_exp_names):
    # for exp_name in tqdm(CFG.train_exp_names[:5]):  # 5つのデータで試す
    #     inferenced_array = inference(model, exp_name, train=False)
    #     train_nshuffle_pred_tomogram[exp_name] = inferenced_array

    #     mean_score, scores = visualize_epoch_results(
    #         train_nshuffle_pred_tomogram,
    #         base_dir="../../inputs/train/overlay/ExperimentRuns/",
    #         sikii_dict=CFG.initial_sikii,
    #     )
    #     train_mean_scores.append(mean_score)
    # print("train_mean_scores", np.mean(train_mean_scores))

    # for exp_name in tqdm(CFG.valid_exp_names):
    #     inferenced_array = inference(model, exp_name, train=False)
    #     valid_pred_tomogram[exp_name] = inferenced_array

    #     mean_score, scores = visualize_epoch_results(
    #         valid_pred_tomogram,
    #         base_dir="../../inputs/train/overlay/ExperimentRuns/",
    #         sikii_dict=CFG.initial_sikii,
    #     )
    #     valid_mean_scores.append(mean_score)
    # print("valid_mean_scores", np.mean(valid_mean_scores))

train-loss: 0.12080154132843017: 100%|██████████| 125/125 [00:32<00:00,  3.81it/s]


Epoch 0 Loss: 0.12080154132843017


train-loss: 0.13010614442825316: 100%|██████████| 125/125 [00:32<00:00,  3.79it/s]


Epoch 1 Loss: 0.13010614442825316


train-loss: 0.1383936710357666: 100%|██████████| 125/125 [00:34<00:00,  3.65it/s] 


Epoch 2 Loss: 0.1383936710357666


train-loss: 0.13116896390914917: 100%|██████████| 125/125 [00:32<00:00,  3.89it/s]


Epoch 3 Loss: 0.13116896390914917


train-loss: 0.12881325149536133: 100%|██████████| 125/125 [00:32<00:00,  3.86it/s]


Epoch 4 Loss: 0.12881325149536133


train-loss: 0.13108581113815307: 100%|██████████| 125/125 [00:33<00:00,  3.76it/s]


Epoch 5 Loss: 0.13108581113815307


train-loss: 0.12541315698623656: 100%|██████████| 125/125 [00:32<00:00,  3.89it/s]


Epoch 6 Loss: 0.12541315698623656


train-loss: 0.11915369939804077: 100%|██████████| 125/125 [00:32<00:00,  3.89it/s]


Epoch 7 Loss: 0.11915369939804077


train-loss: 0.12315077590942383: 100%|██████████| 125/125 [00:34<00:00,  3.65it/s]


Epoch 8 Loss: 0.12315077590942383


train-loss: 0.11211575651168823: 100%|██████████| 125/125 [00:33<00:00,  3.75it/s]


Epoch 9 Loss: 0.11211575651168823


train-loss: 0.10848549222946167: 100%|██████████| 125/125 [00:32<00:00,  3.89it/s]


Epoch 10 Loss: 0.10848549222946167


train-loss: 0.11447711563110352: 100%|██████████| 125/125 [00:34<00:00,  3.67it/s]


Epoch 11 Loss: 0.11447711563110352


train-loss: 0.11322955560684204: 100%|██████████| 125/125 [00:31<00:00,  3.94it/s]


Epoch 12 Loss: 0.11322955560684204


train-loss: 0.09843223905563354: 100%|██████████| 125/125 [00:31<00:00,  3.94it/s]


Epoch 13 Loss: 0.09843223905563354


train-loss: 0.11948630285263062: 100%|██████████| 125/125 [00:34<00:00,  3.66it/s]


Epoch 14 Loss: 0.11948630285263062


train-loss: 0.10657599020004273: 100%|██████████| 125/125 [00:32<00:00,  3.81it/s]


Epoch 15 Loss: 0.10657599020004273


train-loss: 0.10041883707046509: 100%|██████████| 125/125 [00:32<00:00,  3.85it/s]


Epoch 16 Loss: 0.10041883707046509


train-loss: 0.1069192419052124: 100%|██████████| 125/125 [00:34<00:00,  3.67it/s] 


Epoch 17 Loss: 0.1069192419052124


train-loss: 0.1033318281173706: 100%|██████████| 125/125 [00:32<00:00,  3.82it/s] 


Epoch 18 Loss: 0.1033318281173706


train-loss: 0.11023575353622436: 100%|██████████| 125/125 [00:33<00:00,  3.79it/s]


Epoch 19 Loss: 0.11023575353622436


train-loss: 0.104598699092865: 100%|██████████| 125/125 [00:32<00:00,  3.89it/s]  


Epoch 20 Loss: 0.104598699092865


train-loss: 0.10701517963409424: 100%|██████████| 125/125 [00:33<00:00,  3.77it/s]


Epoch 21 Loss: 0.10701517963409424


train-loss: 0.11581902027130127: 100%|██████████| 125/125 [00:31<00:00,  3.95it/s]


Epoch 22 Loss: 0.11581902027130127


train-loss: 0.12231084632873535: 100%|██████████| 125/125 [00:33<00:00,  3.78it/s]


Epoch 23 Loss: 0.12231084632873535


train-loss: 0.10904725122451782: 100%|██████████| 125/125 [00:34<00:00,  3.66it/s]


Epoch 24 Loss: 0.10904725122451782


train-loss: 0.10943134021759034: 100%|██████████| 125/125 [00:32<00:00,  3.90it/s]


Epoch 25 Loss: 0.10943134021759034


train-loss: 0.11105205726623535: 100%|██████████| 125/125 [00:33<00:00,  3.76it/s]


Epoch 26 Loss: 0.11105205726623535


train-loss: 0.10127256536483764: 100%|██████████| 125/125 [00:32<00:00,  3.86it/s]


Epoch 27 Loss: 0.10127256536483764


train-loss: 0.11806047534942626: 100%|██████████| 125/125 [00:32<00:00,  3.85it/s]


Epoch 28 Loss: 0.11806047534942626


train-loss: 0.10520601892471314: 100%|██████████| 125/125 [00:31<00:00,  3.97it/s]


Epoch 29 Loss: 0.10520601892471314


train-loss: 0.10184221935272217: 100%|██████████| 125/125 [00:32<00:00,  3.87it/s]


Epoch 30 Loss: 0.10184221935272217


train-loss: 0.10504046440124512: 100%|██████████| 125/125 [00:33<00:00,  3.72it/s]


Epoch 31 Loss: 0.10504046440124512


train-loss: 0.10448912143707276: 100%|██████████| 125/125 [00:32<00:00,  3.81it/s]


Epoch 32 Loss: 0.10448912143707276


train-loss: 0.1030221700668335: 100%|██████████| 125/125 [00:31<00:00,  3.92it/s] 


Epoch 33 Loss: 0.1030221700668335


train-loss: 0.10545083618164063: 100%|██████████| 125/125 [00:32<00:00,  3.86it/s]


Epoch 34 Loss: 0.10545083618164063


train-loss: 0.10106189680099488: 100%|██████████| 125/125 [00:33<00:00,  3.78it/s]


Epoch 35 Loss: 0.10106189680099488


train-loss: 0.10335549926757813: 100%|██████████| 125/125 [00:33<00:00,  3.71it/s]


Epoch 36 Loss: 0.10335549926757813


train-loss: 0.10049703693389893: 100%|██████████| 125/125 [00:32<00:00,  3.87it/s]


Epoch 37 Loss: 0.10049703693389893


train-loss: 0.09026354265213013: 100%|██████████| 125/125 [00:32<00:00,  3.87it/s]


Epoch 38 Loss: 0.09026354265213013


train-loss: 0.09331433725357055: 100%|██████████| 125/125 [00:32<00:00,  3.89it/s]


Epoch 39 Loss: 0.09331433725357055


train-loss: 0.09550884771347046: 100%|██████████| 125/125 [00:32<00:00,  3.84it/s]


Epoch 40 Loss: 0.09550884771347046


train-loss: 0.09225715112686157: 100%|██████████| 125/125 [00:33<00:00,  3.77it/s]


Epoch 41 Loss: 0.09225715112686157


train-loss: 0.09215547227859497: 100%|██████████| 125/125 [00:32<00:00,  3.89it/s]


Epoch 42 Loss: 0.09215547227859497


train-loss: 0.09066497135162353: 100%|██████████| 125/125 [00:31<00:00,  3.91it/s]


Epoch 43 Loss: 0.09066497135162353


train-loss: 0.09448057270050049: 100%|██████████| 125/125 [00:31<00:00,  3.99it/s]


Epoch 44 Loss: 0.09448057270050049


train-loss: 0.08747357559204101: 100%|██████████| 125/125 [00:32<00:00,  3.85it/s]


Epoch 45 Loss: 0.08747357559204101


train-loss: 0.09309484195709229: 100%|██████████| 125/125 [00:32<00:00,  3.87it/s]


Epoch 46 Loss: 0.09309484195709229


train-loss: 0.08822898483276367: 100%|██████████| 125/125 [00:32<00:00,  3.81it/s]


Epoch 47 Loss: 0.08822898483276367


train-loss: 0.0914212555885315: 100%|██████████| 125/125 [00:34<00:00,  3.61it/s] 


Epoch 48 Loss: 0.0914212555885315


train-loss: 0.09125324630737305: 100%|██████████| 125/125 [00:31<00:00,  3.94it/s]


Epoch 49 Loss: 0.09125324630737305


train-loss: 0.09233773851394653: 100%|██████████| 125/125 [00:33<00:00,  3.78it/s]


Epoch 50 Loss: 0.09233773851394653


train-loss: 0.09318151903152466: 100%|██████████| 125/125 [00:32<00:00,  3.79it/s]


Epoch 51 Loss: 0.09318151903152466


train-loss: 0.08908094120025635: 100%|██████████| 125/125 [00:32<00:00,  3.84it/s]


Epoch 52 Loss: 0.08908094120025635


train-loss: 0.0885435996055603: 100%|██████████| 125/125 [00:32<00:00,  3.80it/s] 


Epoch 53 Loss: 0.0885435996055603


train-loss: 0.09535950136184693: 100%|██████████| 125/125 [00:33<00:00,  3.78it/s]


Epoch 54 Loss: 0.09535950136184693


train-loss: 0.089511962890625: 100%|██████████| 125/125 [00:32<00:00,  3.90it/s]  


Epoch 55 Loss: 0.089511962890625


train-loss: 0.08681473398208618: 100%|██████████| 125/125 [00:33<00:00,  3.77it/s]


Epoch 56 Loss: 0.08681473398208618


train-loss: 0.08686972188949585: 100%|██████████| 125/125 [00:33<00:00,  3.77it/s]


Epoch 57 Loss: 0.08686972188949585


train-loss: 0.08726723480224609: 100%|██████████| 125/125 [00:32<00:00,  3.81it/s]


Epoch 58 Loss: 0.08726723480224609


train-loss: 0.08195358753204346: 100%|██████████| 125/125 [00:33<00:00,  3.68it/s]


Epoch 59 Loss: 0.08195358753204346


train-loss: 0.08491843938827515: 100%|██████████| 125/125 [00:31<00:00,  3.93it/s]


Epoch 60 Loss: 0.08491843938827515


train-loss: 0.0832944655418396: 100%|██████████| 125/125 [00:32<00:00,  3.86it/s] 


Epoch 61 Loss: 0.0832944655418396


train-loss: 0.08637713718414307: 100%|██████████| 125/125 [00:33<00:00,  3.78it/s]


Epoch 62 Loss: 0.08637713718414307


train-loss: 0.08406267499923706: 100%|██████████| 125/125 [00:33<00:00,  3.77it/s]


Epoch 63 Loss: 0.08406267499923706


train-loss: 0.08402868223190307: 100%|██████████| 125/125 [00:32<00:00,  3.85it/s]


Epoch 64 Loss: 0.08402868223190307


train-loss: 0.08815273809432983: 100%|██████████| 125/125 [00:32<00:00,  3.83it/s]


Epoch 65 Loss: 0.08815273809432983


train-loss: 0.08644706773757935: 100%|██████████| 125/125 [00:32<00:00,  3.86it/s]


Epoch 66 Loss: 0.08644706773757935


train-loss: 0.09046861314773559: 100%|██████████| 125/125 [00:32<00:00,  3.88it/s]


Epoch 67 Loss: 0.09046861314773559


train-loss: 0.08641238451004028: 100%|██████████| 125/125 [00:32<00:00,  3.87it/s]


Epoch 68 Loss: 0.08641238451004028


train-loss: 0.08621704006195069: 100%|██████████| 125/125 [00:33<00:00,  3.77it/s]


Epoch 69 Loss: 0.08621704006195069


train-loss: 0.08536011409759521: 100%|██████████| 125/125 [00:33<00:00,  3.72it/s]


Epoch 70 Loss: 0.08536011409759521


train-loss: 0.0856309027671814: 100%|██████████| 125/125 [00:32<00:00,  3.87it/s] 


Epoch 71 Loss: 0.0856309027671814


train-loss: 0.0847681450843811: 100%|██████████| 125/125 [00:32<00:00,  3.89it/s] 


Epoch 72 Loss: 0.0847681450843811


train-loss: 0.09012025499343872: 100%|██████████| 125/125 [00:32<00:00,  3.83it/s]


Epoch 73 Loss: 0.09012025499343872


train-loss: 0.08152238655090333: 100%|██████████| 125/125 [00:34<00:00,  3.60it/s]


Epoch 74 Loss: 0.08152238655090333


train-loss: 0.08520040988922119: 100%|██████████| 125/125 [00:28<00:00,  4.36it/s]


Epoch 75 Loss: 0.08520040988922119


train-loss: 0.08137386035919189: 100%|██████████| 125/125 [00:28<00:00,  4.40it/s]


Epoch 76 Loss: 0.08137386035919189


train-loss: 0.08322919225692749: 100%|██████████| 125/125 [00:28<00:00,  4.33it/s]


Epoch 77 Loss: 0.08322919225692749


train-loss: 0.08321789884567261: 100%|██████████| 125/125 [00:29<00:00,  4.19it/s]


Epoch 78 Loss: 0.08321789884567261


train-loss: 0.08061753559112549: 100%|██████████| 125/125 [00:31<00:00,  3.94it/s]


Epoch 79 Loss: 0.08061753559112549


train-loss: 0.08283480834960938: 100%|██████████| 125/125 [00:32<00:00,  3.88it/s]


Epoch 80 Loss: 0.08283480834960938


train-loss: 0.08436234521865844: 100%|██████████| 125/125 [00:31<00:00,  3.92it/s]


Epoch 81 Loss: 0.08436234521865844


train-loss: 0.08688085508346557: 100%|██████████| 125/125 [00:33<00:00,  3.78it/s]


Epoch 82 Loss: 0.08688085508346557


train-loss: 0.09075843191146851: 100%|██████████| 125/125 [00:33<00:00,  3.76it/s]


Epoch 83 Loss: 0.09075843191146851


train-loss: 0.07669441509246826: 100%|██████████| 125/125 [00:31<00:00,  3.91it/s]


Epoch 84 Loss: 0.07669441509246826


train-loss: 0.08557446813583373: 100%|██████████| 125/125 [00:32<00:00,  3.90it/s]


Epoch 85 Loss: 0.08557446813583373


train-loss: 0.07908355760574341: 100%|██████████| 125/125 [00:34<00:00,  3.67it/s]


Epoch 86 Loss: 0.07908355760574341


train-loss: 0.07860471487045288: 100%|██████████| 125/125 [00:32<00:00,  3.87it/s]


Epoch 87 Loss: 0.07860471487045288


train-loss: 0.0780164303779602: 100%|██████████| 125/125 [00:33<00:00,  3.78it/s] 


Epoch 88 Loss: 0.0780164303779602


train-loss: 0.07291342973709107: 100%|██████████| 125/125 [00:32<00:00,  3.87it/s]


Epoch 89 Loss: 0.07291342973709107


train-loss: 0.0913273115158081: 100%|██████████| 125/125 [00:33<00:00,  3.75it/s] 


Epoch 90 Loss: 0.0913273115158081


train-loss: 0.08197513389587402: 100%|██████████| 125/125 [00:31<00:00,  3.91it/s]


Epoch 91 Loss: 0.08197513389587402


train-loss: 0.08220087766647338: 100%|██████████| 125/125 [00:31<00:00,  3.97it/s]


Epoch 92 Loss: 0.08220087766647338


train-loss: 0.08313231658935546: 100%|██████████| 125/125 [00:33<00:00,  3.74it/s]


Epoch 93 Loss: 0.08313231658935546


train-loss: 0.08175555753707886: 100%|██████████| 125/125 [00:31<00:00,  3.92it/s]


Epoch 94 Loss: 0.08175555753707886


train-loss: 0.08528525400161743: 100%|██████████| 125/125 [00:32<00:00,  3.87it/s]


Epoch 95 Loss: 0.08528525400161743


train-loss: 0.07733781862258911: 100%|██████████| 125/125 [00:33<00:00,  3.78it/s]


Epoch 96 Loss: 0.07733781862258911


train-loss: 0.08058542919158936: 100%|██████████| 125/125 [00:32<00:00,  3.84it/s]


Epoch 97 Loss: 0.08058542919158936


train-loss: 0.08752907848358155: 100%|██████████| 125/125 [00:31<00:00,  3.93it/s]


Epoch 98 Loss: 0.08752907848358155


train-loss: 0.08515898084640502: 100%|██████████| 125/125 [00:31<00:00,  3.98it/s]

Epoch 99 Loss: 0.08515898084640502





In [15]:
# モデルの保存
torch.save(model.state_dict(), "./pretrained_model.pth")

# ############### validation ################
train_nshuffle_original_tomogram = defaultdict(list)
train_nshuffle_pred_tomogram = defaultdict(list)
train_nshuffle_gt_tomogram = defaultdict(list)

valid_original_tomogram = defaultdict(list)
valid_pred_tomogram = defaultdict(list)
valid_gt_tomogram = defaultdict(list)

train_mean_scores = []
valid_mean_scores = []

# for exp_name in tqdm(CFG.train_exp_names):
for exp_name in tqdm(CFG.train_exp_names):  # 5つのデータで試す
    inferenced_array = inference(model, exp_name, train=False)
    train_nshuffle_pred_tomogram[exp_name] = inferenced_array

    mean_score, scores = visualize_epoch_results(
        train_nshuffle_pred_tomogram,
        base_dir="../../inputs/train/overlay/ExperimentRuns/",
        sikii_dict=CFG.initial_sikii,
    )
    train_mean_scores.append(mean_score)
print("train_mean_scores", np.mean(train_mean_scores))

for exp_name in tqdm(CFG.valid_exp_names):
    inferenced_array = inference(model, exp_name, train=False)
    valid_pred_tomogram[exp_name] = inferenced_array

    mean_score, scores = visualize_epoch_results(
        valid_pred_tomogram,
        base_dir="../../inputs/train/overlay/ExperimentRuns/",
        sikii_dict=CFG.initial_sikii,
    )
    valid_mean_scores.append(mean_score)
print("valid_mean_scores", np.mean(valid_mean_scores))

100%|██████████| 5/5 [00:20<00:00,  4.09s/it]


train_mean_scores 0.06426814960652333


100%|██████████| 2/2 [00:05<00:00,  2.52s/it]

valid_mean_scores 0.0





In [14]:
np.unique(train_nshuffle_pred_tomogram["TS_5_4"].argmax(0), return_counts=True)

(array([0, 1, 3, 4, 5, 6]),
 array([9021801,    4254,    3866,   52987,   32489,   13303]))