In [1]:
import os
import zarr
import timm
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import sys
import torch

# import torchvision.transforms.functional as F
import random

warnings.filterwarnings("ignore")
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import Unet3D
from src.utils import save_images, PadToSize
from src.metric import (
    score,
    create_cls_pos,
    create_cls_pos_sikii,
    create_df,
    SegmentationLoss,
    DiceLoss,
)
from src.inference import inference, inference2pos, create_gt_df
from metric import visualize_epoch_results

In [2]:
train_dataset = EziiDataset(
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.train_zarr_types,
    train=True,
    augmentation=True,
    slice=True,
    pre_read=True,
)

# train_nshuffle_dataset = EziiDataset(
#     exp_names=CFG.train_exp_names,
#     base_dir="../../inputs/train/",
#     particles_name=CFG.particles_name,
#     resolution=CFG.resolution,
#     zarr_type=CFG.train_zarr_types,
#     augmentation=False,
#     train=True,
# )

valid_dataset = EziiDataset(
    exp_names=CFG.valid_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.valid_zarr_types,
    augmentation=False,
    train=True,
    slice=True,
    pre_read=True,
)

from tqdm import tqdm

train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=CFG.num_workers,
)
# train_nshuffle_loader = DataLoader(
#     train_nshuffle_dataset,
#     batch_size=1,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
valid_loader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=CFG.num_workers,
)

for data in tqdm(train_loader):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

normalized_tomogram.shape

100%|██████████| 600/600 [00:01<00:00, 319.01it/s]
100%|██████████| 2/2 [00:00<00:00,  2.89it/s]
  0%|          | 0/100 [00:01<?, ?it/s]


torch.Size([6, 16, 315, 315])

In [3]:
# from tqdm import tqdm

# train_loader = DataLoader(
#     train_dataset,
#     batch_size=CFG.batch_size,
#     shuffle=True,
#     drop_last=True,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )
# # train_nshuffle_loader = DataLoader(
# #     train_nshuffle_dataset,
# #     batch_size=1,
# #     shuffle=True,
# #     drop_last=True,
# #     pin_memory=True,
# #     num_workers=CFG.num_workers,
# # )
# valid_loader = DataLoader(
#     valid_dataset,
#     batch_size=1,
#     shuffle=False,
#     pin_memory=True,
#     num_workers=CFG.num_workers,
# )

# for data in tqdm(train_loader):
#     normalized_tomogram = data["normalized_tomogram"]
#     segmentation_map = data["segmentation_map"]
#     break

# normalized_tomogram.shape

In [4]:
encoder = timm.create_model(
    model_name=CFG.model_name,
    pretrained=True,
    in_chans=3,
    num_classes=0,
    global_pool="",
    features_only=True,
)
model = Unet3D(encoder=encoder).to("cuda")
# model.load_state_dict(torch.load("./pretrained_model.pth"))
# model.load_state_dict(torch.load("./best_model.pth"))

In [5]:
# # "encoder"と名のつくパラメータは学習しない
# for layer, param in model.named_parameters():
#     if "encoder" in layer:
#         param.requires_grad = False

In [6]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# サンプルデータ
num_classes = len(CFG.particles_name)  # クラス数
colors = plt.cm.tab10(
    np.arange(len(CFG.particles_name))
)  # "tab10" カラーマップから色を取得

# ListedColormap を作成
class_colormap = ListedColormap(colors)


# カラーバー付きプロット
def plot_with_colormap(data, title, original_tomogram):
    masked_data = np.ma.masked_where(data <= 0, data)  # クラス0をマスク
    plt.imshow(original_tomogram, cmap="gray")
    im = plt.imshow(masked_data, cmap=class_colormap)
    plt.title(title)
    plt.axis("off")
    return im

In [7]:
import torch
import random
import torchvision.transforms.functional as TF


# 回転
# 3Dテンソルの各軸に対して指定した角度で回転する関数
def rotate_3d(tomogram, segmentation_map, angle):
    """Rotates the 3D tensors tomogram and segmentation_map around the Z-axis."""
    rotated_tomogram = TF.rotate(tomogram, angle, expand=False)
    rotated_segmentation_map = TF.rotate(segmentation_map, angle, expand=False)
    return rotated_tomogram, rotated_segmentation_map


# 平行移動
# 指定された範囲でランダムに平行移動
def translate_3d(tomogram, segmentation_map, max_shift):
    """Translates the 3D tensors by a random shift within max_shift."""
    shift_x = random.randint(-max_shift, max_shift)
    shift_y = random.randint(-max_shift, max_shift)
    translated_tomogram = TF.affine(
        tomogram, angle=0, translate=(shift_x, shift_y), scale=1, shear=0
    )
    translated_segmentation_map = TF.affine(
        segmentation_map, angle=0, translate=(shift_x, shift_y), scale=1, shear=0
    )
    return translated_tomogram, translated_segmentation_map


# フリップ
# 縦横（上下左右）ランダムフリップ
def flip_3d(tomogram, segmentation_map):
    """Randomly flips the 3D tensors along height or width."""
    if random.random() > 0.5:  # Horizontal flip
        tomogram = torch.flip(tomogram, dims=[-1])
        segmentation_map = torch.flip(segmentation_map, dims=[-1])
    if random.random() > 0.5:  # Vertical flip
        tomogram = torch.flip(tomogram, dims=[-2])
        segmentation_map = torch.flip(segmentation_map, dims=[-2])
    return tomogram, segmentation_map


# クロッピング
# 入力テンソルを中心またはランダムクロップで切り取る
def crop_3d(tomogram, segmentation_map, crop_size):
    """Crops the 3D tensors to the specified crop_size."""
    _, depth, height, width = tomogram.size()
    crop_d, crop_h, crop_w = crop_size

    if crop_h > height or crop_w > width:
        raise ValueError("Crop size cannot be larger than the original size.")

    start_h = random.randint(0, height - crop_h)  # Random starting position for height
    start_w = random.randint(0, width - crop_w)  # Random starting position for width

    cropped_tomogram = tomogram[
        :, :, start_h : start_h + crop_h, start_w : start_w + crop_w
    ]
    cropped_segmentation_map = segmentation_map[
        :, :, start_h : start_h + crop_h, start_w : start_w + crop_w
    ]

    return cropped_tomogram, cropped_segmentation_map


# Mixup
# 2つのサンプルを線形補間して混合
def mixup(tomogram, segmentation_map, alpha=0.4):
    """Applies mixup augmentation to the batch."""
    lam = random.betavariate(alpha, alpha)
    batch_size = tomogram.size(0)
    index = torch.randperm(batch_size)

    mixed_tomogram = lam * tomogram + (1 - lam) * tomogram[index, :]
    mixed_segmentation_map = (
        lam * segmentation_map + (1 - lam) * segmentation_map[index, :]
    )

    return mixed_tomogram, mixed_segmentation_map


# Cutmix
# ランダム領域を切り取って別のサンプルに貼り付け
def cutmix(tomogram, segmentation_map, alpha=1.0):
    """Applies cutmix augmentation to the batch."""
    lam = random.betavariate(alpha, alpha)
    batch_size, depth, height, width = tomogram.size()
    index = torch.randperm(batch_size)

    cx = random.randint(0, width)
    cy = random.randint(0, height)
    cw = int(width * (1 - lam))
    ch = int(height * (1 - lam))

    x1 = max(cx - cw // 2, 0)
    x2 = min(cx + cw // 2, width)
    y1 = max(cy - ch // 2, 0)
    y2 = min(cy + ch // 2, height)

    tomogram[:, :, y1:y2, x1:x2] = tomogram[index, :, y1:y2, x1:x2]
    segmentation_map[:, :, y1:y2, x1:x2] = segmentation_map[index, :, y1:y2, x1:x2]

    return tomogram, segmentation_map


# データ拡張の組み合わせ適用
def augment_data(
    tomogram,
    segmentation_map,
    crop_size=(16, 256, 256),
    max_shift=10,
    rotation_angle=30,
    p=0.5,
    mixup_alpha=0.4,
    cutmix_alpha=1.0,
):
    """Applies a combination of rotation, translation, flipping, cropping, mixup, and cutmix to the inputs with probabilities."""
    if random.random() < p:
        tomogram, segmentation_map = rotate_3d(
            tomogram,
            segmentation_map,
            angle=random.uniform(-rotation_angle, rotation_angle),
        )
    if random.random() < p:
        tomogram, segmentation_map = translate_3d(
            tomogram, segmentation_map, max_shift=max_shift
        )
    if random.random() < p:
        tomogram, segmentation_map = flip_3d(tomogram, segmentation_map)
    if random.random() < p:
        tomogram, segmentation_map = crop_3d(
            tomogram, segmentation_map, crop_size=crop_size
        )
    if random.random() < p:
        tomogram, segmentation_map = mixup(
            tomogram, segmentation_map, alpha=mixup_alpha
        )
    # if random.random() < p:
    #     tomogram, segmentation_map = cutmix(
    #         tomogram, segmentation_map, alpha=cutmix_alpha
    #     )
    return tomogram, segmentation_map


# 使用例
# バッチサイズ6, 深さ16, 高さ320, 幅320のランダムテンソル
tomogram = torch.rand((6, 16, 320, 320))
segmentation_map = torch.randint(0, 2, (6, 16, 320, 320))  # ラベルは0または1

# データ拡張の適用
aug_tomogram, aug_segmentation_map = augment_data(tomogram, segmentation_map, p=0.7)
print("Original shape:", tomogram.shape)
print("Augmented shape:", aug_tomogram.shape)

Original shape: torch.Size([6, 16, 320, 320])
Augmented shape: torch.Size([6, 16, 256, 256])


In [8]:
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.Adam(
    model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
)
# criterion = nn.CrossEntropyLoss(
#     #  weight=torch.tensor([2.0, 32, 32, 32, 32, 32, 32]).to("cuda")
# )
criterion = DiceLoss()
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=CFG.epochs * len(train_loader),
    # * batch_size,
)
seg_loss = SegmentationLoss(criterion)
padf = PadToSize(CFG.resolution)

In [9]:
# b, c, d, h, w = CFG.batch_size, 1, 96, 320, 320

In [10]:
def preprocess_tensor(tensor):
    batch_size, depth, height, width = tensor.shape
    tensor = tensor.unsqueeze(2)  # (b, d, h, w) -> (b, d, 1, h, w)
    return tensor

In [11]:
padf = PadToSize(CFG.resolution)
padf(normalized_tomogram).shape

torch.Size([6, 16, 320, 320])

In [None]:
best_model = None
best_score = -100

grand_train_loss = []
grand_valid_loss = []
grand_train_score = []
grand_valid_score = []

for epoch in range(CFG.epochs):
    model.train()
    train_loss = []
    valid_loss = []
    with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{CFG.epochs} [Training]") as tq:
        for data in tq:
            normalized_tomogram = data["normalized_tomogram"]
            segmentation_map = data["segmentation_map"]

            normalized_tomogram = padf(normalized_tomogram)
            segmentation_map = padf(segmentation_map)

            # データ拡張
            normalized_tomogram, segmentation_map = augment_data(
                normalized_tomogram, segmentation_map, p=CFG.augmentation_prob
            )
            normalized_tomogram = normalized_tomogram.cuda()
            segmentation_map = segmentation_map.long().cuda()

            optimizer.zero_grad()
            pred = model(preprocess_tensor(normalized_tomogram))
            loss = seg_loss(pred, segmentation_map)
            loss.backward()
            optimizer.step()
            scheduler.step()
            train_loss.append(loss.item())

            # 確率予測
            prob_pred = torch.softmax(pred, dim=1)
            tq.set_postfix({"loss": f"{np.mean(train_loss):.4f}"})

    with tqdm(valid_loader, desc=f"Epoch {epoch + 1}/{CFG.epochs} [Validation]") as tq:
        for data in tq:
            normalized_tomogram = data["normalized_tomogram"].cuda()
            segmentation_map = data["segmentation_map"].long().cuda()

            normalized_tomogram = padf(normalized_tomogram)
            segmentation_map = padf(segmentation_map)

            pred = model(preprocess_tensor(normalized_tomogram))
            loss = seg_loss(pred, segmentation_map)
            valid_loss.append(loss.item())

            # 確率予測
            prob_pred = torch.softmax(pred, dim=1)
            tq.set_postfix({"loss": f"{np.mean(valid_loss):.4f}"})

    # # ############### validation ################
    train_nshuffle_original_tomogram = defaultdict(list)
    train_nshuffle_pred_tomogram = defaultdict(list)
    train_nshuffle_gt_tomogram = defaultdict(list)

    valid_original_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)

    train_mean_scores = []
    valid_mean_scores = []

    # モデルの保存
    torch.save(model.state_dict(), "./pretrained_model.pth")

    # ############### validation ################
    train_nshuffle_original_tomogram = defaultdict(list)
    train_nshuffle_pred_tomogram = defaultdict(list)
    train_nshuffle_gt_tomogram = defaultdict(list)

    valid_original_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)

    train_mean_scores = []
    valid_mean_scores = []

    # for exp_name in tqdm(CFG.train_exp_names):
    for exp_name in CFG.train_exp_names[:5]:  # 5つのデータで試す
        # inferenced_array = inference(model, exp_name, train=False)
        inferenced_array, n_tomogram, segmentation_map = inference(
            model, exp_name, train=True
        )
        pred_df = inference2pos(
            pred_segmask=inferenced_array.argmax(0), exp_name=exp_name
        )
        base_dir = "../../inputs/train/overlay/ExperimentRuns/"
        gt_df = create_gt_df(base_dir, [exp_name])

        train_nshuffle_pred_tomogram[exp_name] = inferenced_array

        score_ = score(
            pred_df, gt_df, row_id_column_name="index", distance_multiplier=1.0, beta=4
        )
        train_mean_scores.append(score_)

    # print("train_mean_scores", np.mean(train_mean_scores))

    for exp_name in CFG.valid_exp_names:
        inferenced_array, n_tomogram, segmentation_map = inference(
            model, exp_name, train=True
        )
        pred_df = inference2pos(
            pred_segmask=inferenced_array.argmax(0), exp_name=exp_name
        )
        base_dir = "../../inputs/train/overlay/ExperimentRuns/"
        gt_df = create_gt_df(base_dir, [exp_name])

        valid_pred_tomogram[exp_name] = inferenced_array

        score_ = score(
            pred_df, gt_df, row_id_column_name="index", distance_multiplier=1.0, beta=4
        )
        valid_mean_scores.append(score_)

    # print("valid_mean_scores", np.mean(valid_mean_scores))

    if np.mean(valid_mean_scores) > best_score:
        best_score = np.mean(valid_mean_scores)
        best_model = model.state_dict()
        torch.save(best_model, f"./best_model.pth")

    print(
        f"train-epoch-loss:{np.mean(train_loss):.4f}",
        f"valid-epoch-loss:{np.mean(valid_loss):.4f}",
        f"train-beta4-score:{np.mean(train_mean_scores):.4f}",
        f"valid-beta4-score:{np.mean(valid_mean_scores):.4f}",
    )

    grand_train_loss.append(np.mean(train_loss))
    grand_valid_loss.append(np.mean(valid_loss))
    grand_train_score.append(np.mean(train_mean_scores))
    grand_valid_score.append(np.mean(valid_mean_scores))

Epoch 1/80 [Training]: 100%|██████████| 100/100 [01:10<00:00,  1.41it/s, loss=0.9160]
Epoch 1/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  6.42it/s, loss=0.8809]


train-epoch-loss:0.9160 valid-epoch-loss:0.8809 train-beta4-score:0.0092 valid-beta4-score:0.0022


Epoch 2/80 [Training]: 100%|██████████| 100/100 [01:06<00:00,  1.50it/s, loss=0.8422]
Epoch 2/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.69it/s, loss=0.8619]


train-epoch-loss:0.8422 valid-epoch-loss:0.8619 train-beta4-score:0.1071 valid-beta4-score:0.0801


Epoch 3/80 [Training]: 100%|██████████| 100/100 [01:07<00:00,  1.48it/s, loss=0.8023]
Epoch 3/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.52it/s, loss=0.8312]


train-epoch-loss:0.8023 valid-epoch-loss:0.8312 train-beta4-score:0.1657 valid-beta4-score:0.1170


Epoch 4/80 [Training]: 100%|██████████| 100/100 [01:08<00:00,  1.47it/s, loss=0.7827]
Epoch 4/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.39it/s, loss=0.7093]


train-epoch-loss:0.7827 valid-epoch-loss:0.7093 train-beta4-score:0.0977 valid-beta4-score:0.0731


Epoch 5/80 [Training]: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s, loss=0.7454]
Epoch 5/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.88it/s, loss=0.6896]


train-epoch-loss:0.7454 valid-epoch-loss:0.6896 train-beta4-score:0.2465 valid-beta4-score:0.2161


Epoch 6/80 [Training]: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s, loss=0.7405]
Epoch 6/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.06it/s, loss=0.6855]


train-epoch-loss:0.7405 valid-epoch-loss:0.6855 train-beta4-score:0.2784 valid-beta4-score:0.2624


Epoch 7/80 [Training]: 100%|██████████| 100/100 [01:07<00:00,  1.49it/s, loss=0.7138]
Epoch 7/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.13it/s, loss=0.5367]


train-epoch-loss:0.7138 valid-epoch-loss:0.5367 train-beta4-score:0.3549 valid-beta4-score:0.2958


Epoch 8/80 [Training]: 100%|██████████| 100/100 [01:11<00:00,  1.39it/s, loss=0.7156]
Epoch 8/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.01it/s, loss=0.6720]


train-epoch-loss:0.7156 valid-epoch-loss:0.6720 train-beta4-score:0.3060 valid-beta4-score:0.2602


Epoch 9/80 [Training]: 100%|██████████| 100/100 [01:07<00:00,  1.48it/s, loss=0.6764]
Epoch 9/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.01it/s, loss=0.7035]


train-epoch-loss:0.6764 valid-epoch-loss:0.7035 train-beta4-score:0.4804 valid-beta4-score:0.3588


Epoch 10/80 [Training]: 100%|██████████| 100/100 [01:08<00:00,  1.45it/s, loss=0.6789]
Epoch 10/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.97it/s, loss=0.5041]


train-epoch-loss:0.6789 valid-epoch-loss:0.5041 train-beta4-score:0.5911 valid-beta4-score:0.3985


Epoch 11/80 [Training]: 100%|██████████| 100/100 [01:11<00:00,  1.40it/s, loss=0.6680]
Epoch 11/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.82it/s, loss=0.5715]


train-epoch-loss:0.6680 valid-epoch-loss:0.5715 train-beta4-score:0.4823 valid-beta4-score:0.3009


Epoch 12/80 [Training]: 100%|██████████| 100/100 [01:14<00:00,  1.33it/s, loss=0.6641]
Epoch 12/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.86it/s, loss=0.6701]


train-epoch-loss:0.6641 valid-epoch-loss:0.6701 train-beta4-score:0.5058 valid-beta4-score:0.3655


Epoch 13/80 [Training]: 100%|██████████| 100/100 [01:07<00:00,  1.49it/s, loss=0.6392]
Epoch 13/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.86it/s, loss=0.6365]


train-epoch-loss:0.6392 valid-epoch-loss:0.6365 train-beta4-score:0.6129 valid-beta4-score:0.4204


Epoch 14/80 [Training]: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s, loss=0.6441]
Epoch 14/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.82it/s, loss=0.6438]


train-epoch-loss:0.6441 valid-epoch-loss:0.6438 train-beta4-score:0.6621 valid-beta4-score:0.4320


Epoch 15/80 [Training]: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s, loss=0.6188]
Epoch 15/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.84it/s, loss=0.5603]


train-epoch-loss:0.6188 valid-epoch-loss:0.5603 train-beta4-score:0.8093 valid-beta4-score:0.5262


Epoch 16/80 [Training]: 100%|██████████| 100/100 [01:11<00:00,  1.39it/s, loss=0.6406]
Epoch 16/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.81it/s, loss=0.7644]


train-epoch-loss:0.6406 valid-epoch-loss:0.7644 train-beta4-score:0.5701 valid-beta4-score:0.3745


Epoch 17/80 [Training]: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s, loss=0.6364]
Epoch 17/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.68it/s, loss=0.4622]


train-epoch-loss:0.6364 valid-epoch-loss:0.4622 train-beta4-score:0.5860 valid-beta4-score:0.3700


Epoch 18/80 [Training]: 100%|██████████| 100/100 [01:12<00:00,  1.39it/s, loss=0.6554]
Epoch 18/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.62it/s, loss=0.7391]


train-epoch-loss:0.6554 valid-epoch-loss:0.7391 train-beta4-score:0.4641 valid-beta4-score:0.3478


Epoch 19/80 [Training]: 100%|██████████| 100/100 [01:07<00:00,  1.47it/s, loss=0.6074]
Epoch 19/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.64it/s, loss=0.5927]


train-epoch-loss:0.6074 valid-epoch-loss:0.5927 train-beta4-score:0.7326 valid-beta4-score:0.4288


Epoch 20/80 [Training]: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s, loss=0.6125]
Epoch 20/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.76it/s, loss=0.5481]


train-epoch-loss:0.6125 valid-epoch-loss:0.5481 train-beta4-score:0.4579 valid-beta4-score:0.3051


Epoch 21/80 [Training]: 100%|██████████| 100/100 [01:10<00:00,  1.41it/s, loss=0.5926]
Epoch 21/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.77it/s, loss=0.4705]


train-epoch-loss:0.5926 valid-epoch-loss:0.4705 train-beta4-score:0.4207 valid-beta4-score:0.3150


Epoch 22/80 [Training]: 100%|██████████| 100/100 [01:12<00:00,  1.38it/s, loss=0.5971]
Epoch 22/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.61it/s, loss=0.6476]


train-epoch-loss:0.5971 valid-epoch-loss:0.6476 train-beta4-score:0.7801 valid-beta4-score:0.4778


Epoch 23/80 [Training]: 100%|██████████| 100/100 [01:12<00:00,  1.37it/s, loss=0.6171]
Epoch 23/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.62it/s, loss=0.4629]


train-epoch-loss:0.6171 valid-epoch-loss:0.4629 train-beta4-score:0.6167 valid-beta4-score:0.3765


Epoch 24/80 [Training]: 100%|██████████| 100/100 [01:11<00:00,  1.40it/s, loss=0.5996]
Epoch 24/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.64it/s, loss=0.7289]


train-epoch-loss:0.5996 valid-epoch-loss:0.7289 train-beta4-score:0.5043 valid-beta4-score:0.3421


Epoch 25/80 [Training]: 100%|██████████| 100/100 [01:07<00:00,  1.48it/s, loss=0.5698]
Epoch 25/80 [Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.78it/s, loss=0.6187]


train-epoch-loss:0.5698 valid-epoch-loss:0.6187 train-beta4-score:0.7397 valid-beta4-score:0.4542


Epoch 26/80 [Training]:  46%|████▌     | 46/100 [00:30<00:38,  1.40it/s, loss=0.5986]

In [None]:
# train_lossとvalid_lossのプロット

plt.plot(grand_train_loss, label="train_loss")
plt.plot(grand_valid_loss, label="valid_loss")
plt.legend()
plt.show()

In [None]:
# train_scoreとvalid_scoreのプロット
plt.plot(grand_train_score, label="train_score")
plt.plot(grand_valid_score, label="valid_score")
plt.legend()
plt.show()

In [None]:
random.random()