In [1]:
import os
import zarr
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import sys

warnings.filterwarnings("ignore")
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import UNet_2D, aug
from src.utils import save_images
from src.metric import score, create_cls_pos, create_cls_pos_sikii, create_df
from metric import visualize_epoch_results

In [2]:
train_dataset = EziiDataset(
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.train_zarr_types,
    train=True,
)

valid_dataset = EziiDataset(
    exp_names=CFG.valid_exp_names,
    base_dir="../../inputs/train/",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.valid_zarr_types,
    train=True,
)

from tqdm import tqdm

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

for data in tqdm(train_dataset):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

# normalized_tomogram =
normalized_tomogram.shape[0]

[('TS_5_4', 'denoised'), ('TS_73_6', 'denoised'), ('TS_99_9', 'denoised'), ('TS_6_4', 'denoised'), ('TS_69_2', 'denoised')]
[('TS_86_3', 'denoised'), ('TS_6_6', 'denoised')]


  0%|          | 0/5 [00:00<?, ?it/s]


46

In [3]:
import torch
import torchvision.transforms.functional as F
import random

In [None]:
model = UNet_2D().to("cuda")
model.eval()


optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor([0.5, 32, 32, 32, 32, 32, 32]).to("cuda")
)
# criterion = DiceLoss()

best_model = None
best_score = 0
batch_size = 4

for epoch in range(100):
    train_loss = []
    valid_loss = []
    train_pred_tomogram = defaultdict(list)
    train_gt_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)
    model.train()
    tq = tqdm(range(len(train_loader) * normalized_tomogram.shape[0] // batch_size))
    for data in train_loader:
        exp_name = data["exp_name"][0]
        tomogram = data["normalized_tomogram"]
        segmentation_map = data["segmentation_map"].long()

        for i in range(batch_size, tomogram.shape[1], batch_size):
            optimizer.zero_grad()
            from_, to_ = 0, tomogram.shape[1]
            random_index = random.sample(range(from_, to_), batch_size)
            input_ = tomogram[:, random_index]
            input_ = input_.permute(1, 0, 2, 3)  # (batch_size, 1, 160, 160)
            gt = segmentation_map[:, random_index].squeeze()  # (batch_size, 160, 160)

            input_, gt = aug(input_, gt)

            input_ = input_.to("cuda")
            gt = gt.to("cuda")
            output = model(input_)
            loss = criterion(output, gt)
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())
            tq.set_description(f"Train-Epoch: {epoch}, Loss: {np.mean(train_loss)}")
            tq.update(1)

            for i in range(batch_size):
                train_pred_tomogram[exp_name].append(output.cpu().detach().numpy())
                train_gt_tomogram[exp_name].append(gt.cpu().detach().numpy())
    tq.close()

    ############################################# valid #############################################

    model.eval()
    tq = tqdm(range(len(valid_loader) * normalized_tomogram.shape[0]))
    for data in valid_loader:
        exp_name = data["exp_name"][0]
        tomogram = data["normalized_tomogram"].to("cuda")
        segmentation_map = data["segmentation_map"].to("cuda").long()

        for i in range(tomogram.shape[1]):
            input_ = tomogram[:, i].unsqueeze(0)
            gt = segmentation_map[:, i]

            output = model(input_)
            loss = criterion(output, gt)

            valid_loss.append(loss.item())
            tq.set_description(f"Valid-Epoch: {epoch}, Loss: {np.mean(valid_loss)}")
            tq.update(1)

            valid_pred_tomogram[exp_name].append(output.cpu().detach().numpy())
            valid_gt_tomogram[exp_name].append(gt.cpu().detach().numpy())
    tq.close()

    valid_score_ = visualize_epoch_results(
        valid_pred_tomogram,
        valid_gt_tomogram,
        sikii_dict=CFG.initial_sikii,
    )

    print(f"EPOCH: {epoch}, VALID_SCORE: {valid_score_}")

    if valid_score_ > best_score:
        best_score = valid_score_
        best_model = model
        torch.save(model.state_dict(), f"best_model.pth")

    # 可視化
    index = 50

    # plt.figure(figsize=(10, 5))

    # ax = plt.subplot(1, 4, 1)
    # ax.imshow(train_pred_tomogram[exp_name][index].argmax(0))
    # ax.set_title("Train-Prediction")
    # ax.axis("off")

    # ax = plt.subplot(1, 4, 2)
    # ax.imshow(train_gt_tomogram[exp_name][index])
    # ax.set_title("Train-Ground Truth")
    # ax.axis("off")

    # ax = plt.subplot(1, 4, 3)
    # ax.imshow(valid_pred_tomogram[exp_name][index].argmax(1).squeeze(0))
    # ax.set_title("Valid-Prediction")
    # ax.axis("off")

    # ax = plt.subplot(1, 4, 4)
    # ax.imshow(valid_gt_tomogram[exp_name][index].squeeze(0))
    # ax.set_title("Valid-Ground Truth")
    # ax.axis("off")

    # plt.tight_layout()

    # plt.show()

    # save_images(
    #     train_gt_tomogram=train_gt_tomogram,
    #     train_pred_tomogram=train_pred_tomogram,
    #     valid_gt_tomogram=valid_gt_tomogram,
    #     valid_pred_tomogram=valid_pred_tomogram,
    #     save_dir="images",
    #     epoch=epoch,
    # )

Train-Epoch: 0, Loss: 1.8500074711712924:  96%|█████████▋| 55/57 [00:07<00:00,  6.94it/s]
Valid-Epoch: 0, Loss: 1.681821531575659: 100%|██████████| 92/92 [00:03<00:00, 23.64it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 0, VALID_SCORE: 0.010057250222467411


Train-Epoch: 1, Loss: 1.5110007762908935:  96%|█████████▋| 55/57 [00:06<00:00,  8.62it/s]
Valid-Epoch: 1, Loss: 1.408421454222306: 100%|██████████| 92/92 [00:03<00:00, 25.14it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 1, VALID_SCORE: 0.010704970871791893


Train-Epoch: 2, Loss: 1.2854537855495105:  96%|█████████▋| 55/57 [00:05<00:00,  9.19it/s]
Valid-Epoch: 2, Loss: 1.159705377791239: 100%|██████████| 92/92 [00:03<00:00, 25.93it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 2, VALID_SCORE: 0.018876793840390316


Train-Epoch: 3, Loss: 1.1246337619694797:  96%|█████████▋| 55/57 [00:06<00:00,  9.15it/s]
Valid-Epoch: 3, Loss: 0.9241614613844: 100%|██████████| 92/92 [00:03<00:00, 26.44it/s]   


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 3, VALID_SCORE: 0.030901457656294748


Train-Epoch: 4, Loss: 0.963408742167733:  96%|█████████▋| 55/57 [00:06<00:00,  9.00it/s] 
Valid-Epoch: 4, Loss: 0.7327719619092734: 100%|██████████| 92/92 [00:03<00:00, 25.82it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 4, VALID_SCORE: 0.06678770519912888


Train-Epoch: 5, Loss: 0.8501280535351147:  96%|█████████▋| 55/57 [00:06<00:00,  8.71it/s]
Valid-Epoch: 5, Loss: 0.6418150775134563: 100%|██████████| 92/92 [00:03<00:00, 27.48it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 5, VALID_SCORE: 0.1162964659298001


Train-Epoch: 6, Loss: 0.7548423160206188:  96%|█████████▋| 55/57 [00:06<00:00,  7.91it/s]
Valid-Epoch: 6, Loss: 0.5604337444124015: 100%|██████████| 92/92 [00:03<00:00, 23.02it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 6, VALID_SCORE: 0.10382237270252768


Train-Epoch: 7, Loss: 0.6806444623253562:  96%|█████████▋| 55/57 [00:06<00:00,  8.34it/s]
Valid-Epoch: 7, Loss: 0.5069515248355658: 100%|██████████| 92/92 [00:03<00:00, 25.76it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 7, VALID_SCORE: 0.038775917900470885


Train-Epoch: 8, Loss: 0.7145792207934639:  96%|█████████▋| 55/57 [00:06<00:00,  8.83it/s]
Valid-Epoch: 8, Loss: 0.4606352863104447: 100%|██████████| 92/92 [00:03<00:00, 27.94it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 8, VALID_SCORE: 0.016944818729151948


Train-Epoch: 9, Loss: 0.6518902047113939:  96%|█████████▋| 55/57 [00:07<00:00,  7.35it/s]
Valid-Epoch: 9, Loss: 0.5282839426527852: 100%|██████████| 92/92 [00:04<00:00, 21.52it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 9, VALID_SCORE: 0.024555538425099235


Train-Epoch: 10, Loss: 0.5948602137742219:  95%|█████████▍| 54/57 [00:07<00:00,  7.68it/s]