In [1]:
import os
import zarr
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import sys

warnings.filterwarnings("ignore")
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import UNet_2D, aug
from src.utils import save_images
from src.metric import score, create_cls_pos, create_cls_pos_sikii, create_df
from metric import visualize_epoch_results

In [2]:
train_dataset = EziiDataset(
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/static",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.train_zarr_types,
)

valid_dataset = EziiDataset(
    exp_names=CFG.valid_exp_names,
    base_dir="../../inputs/train/static",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.valid_zarr_types,
)

from tqdm import tqdm

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

for data in tqdm(train_dataset):
    normalized_tomogram = data["normalized_tomogram"]
    segmentation_map = data["segmentation_map"]
    break

# normalized_tomogram =
normalized_tomogram.shape[0]

[('TS_5_4', 'denoised'), ('TS_5_4', 'ctfdeconvolved'), ('TS_5_4', 'wbp'), ('TS_5_4', 'isonetcorrected'), ('TS_73_6', 'denoised'), ('TS_73_6', 'ctfdeconvolved'), ('TS_73_6', 'wbp'), ('TS_73_6', 'isonetcorrected'), ('TS_99_9', 'denoised'), ('TS_99_9', 'ctfdeconvolved'), ('TS_99_9', 'wbp'), ('TS_99_9', 'isonetcorrected'), ('TS_6_4', 'denoised'), ('TS_6_4', 'ctfdeconvolved'), ('TS_6_4', 'wbp'), ('TS_6_4', 'isonetcorrected'), ('TS_69_2', 'denoised'), ('TS_69_2', 'ctfdeconvolved'), ('TS_69_2', 'wbp'), ('TS_69_2', 'isonetcorrected')]
[('TS_86_3', 'denoised'), ('TS_6_6', 'denoised')]


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]


92

In [3]:
import torch
import torchvision.transforms.functional as F
import random

In [None]:
model = UNet_2D().to("cuda")
model.eval()


optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor([0.5, 32, 32, 32, 32, 32, 32]).to("cuda")
)
# criterion = DiceLoss()

best_model = None
best_score = 0
batch_size = 4

for epoch in range(100):
    train_loss = []
    valid_loss = []
    train_pred_tomogram = defaultdict(list)
    train_gt_tomogram = defaultdict(list)
    valid_pred_tomogram = defaultdict(list)
    valid_gt_tomogram = defaultdict(list)
    model.train()
    tq = tqdm(range(len(train_loader) * normalized_tomogram.shape[0] // batch_size))
    for data in train_loader:
        exp_name = data["exp_name"][0]
        tomogram = data["normalized_tomogram"]
        segmentation_map = data["segmentation_map"].long()

        for i in range(batch_size, tomogram.shape[1], batch_size):
            optimizer.zero_grad()
            from_, to_ = 0, tomogram.shape[1]
            random_index = random.sample(range(from_, to_), batch_size)
            input_ = tomogram[:, random_index]
            input_ = input_.permute(1, 0, 2, 3)  # (batch_size, 1, 160, 160)
            gt = segmentation_map[:, random_index].squeeze()  # (batch_size, 160, 160)

            input_, gt = aug(input_, gt)

            input_ = input_.to("cuda")
            gt = gt.to("cuda")
            output = model(input_)
            loss = criterion(output, gt)
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())
            tq.set_description(f"Train-Epoch: {epoch}, Loss: {np.mean(train_loss)}")
            tq.update(1)

            for i in range(batch_size):
                train_pred_tomogram[exp_name].append(output.cpu().detach().numpy())
                train_gt_tomogram[exp_name].append(gt.cpu().detach().numpy())
    tq.close()

    ############################################# valid #############################################

    model.eval()
    tq = tqdm(range(len(valid_loader) * normalized_tomogram.shape[0]))
    for data in valid_loader:
        exp_name = data["exp_name"][0]
        tomogram = data["normalized_tomogram"].to("cuda")
        segmentation_map = data["segmentation_map"].to("cuda").long()

        for i in range(tomogram.shape[1]):
            input_ = tomogram[:, i].unsqueeze(0)
            gt = segmentation_map[:, i]

            output = model(input_)
            loss = criterion(output, gt)

            valid_loss.append(loss.item())
            tq.set_description(f"Valid-Epoch: {epoch}, Loss: {np.mean(valid_loss)}")
            tq.update(1)

            valid_pred_tomogram[exp_name].append(output.cpu().detach().numpy())
            valid_gt_tomogram[exp_name].append(gt.cpu().detach().numpy())
    tq.close()

    valid_score_ = visualize_epoch_results(
        valid_pred_tomogram,
        valid_gt_tomogram,
        sikii_dict=CFG.initial_sikii,
    )

    print(f"EPOCH: {epoch}, VALID_SCORE: {valid_score_}")

    if valid_score_ > best_score:
        best_score = valid_score_
        best_model = model
        torch.save(model.state_dict(), f"best_model.pth")

    # 可視化
    index = 50

    # plt.figure(figsize=(10, 5))

    # ax = plt.subplot(1, 4, 1)
    # ax.imshow(train_pred_tomogram[exp_name][index].argmax(0))
    # ax.set_title("Train-Prediction")
    # ax.axis("off")

    # ax = plt.subplot(1, 4, 2)
    # ax.imshow(train_gt_tomogram[exp_name][index])
    # ax.set_title("Train-Ground Truth")
    # ax.axis("off")

    # ax = plt.subplot(1, 4, 3)
    # ax.imshow(valid_pred_tomogram[exp_name][index].argmax(1).squeeze(0))
    # ax.set_title("Valid-Prediction")
    # ax.axis("off")

    # ax = plt.subplot(1, 4, 4)
    # ax.imshow(valid_gt_tomogram[exp_name][index].squeeze(0))
    # ax.set_title("Valid-Ground Truth")
    # ax.axis("off")

    # plt.tight_layout()

    # plt.show()

    # save_images(
    #     train_gt_tomogram=train_gt_tomogram,
    #     train_pred_tomogram=train_pred_tomogram,
    #     valid_gt_tomogram=valid_gt_tomogram,
    #     valid_pred_tomogram=valid_pred_tomogram,
    #     save_dir="images",
    #     epoch=epoch,
    # )

Train-Epoch: 0, Loss: 1.1286558195271275:  96%|█████████▌| 440/460 [02:02<00:05,  3.59it/s]
Valid-Epoch: 0, Loss: 0.8437045598807542: 100%|██████████| 184/184 [00:08<00:00, 21.44it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 0, VALID_SCORE: 0.16331919092049935


Train-Epoch: 1, Loss: 0.641396359523589:  96%|█████████▌| 440/460 [01:50<00:05,  3.98it/s] 
Valid-Epoch: 1, Loss: 0.5539978383432912: 100%|██████████| 184/184 [00:08<00:00, 22.35it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 1, VALID_SCORE: 0.1621112643299146


Train-Epoch: 2, Loss: 0.5470620284703644:  96%|█████████▌| 440/460 [01:51<00:05,  3.96it/s]
Valid-Epoch: 2, Loss: 0.4182375419844428: 100%|██████████| 184/184 [00:08<00:00, 21.05it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 2, VALID_SCORE: 0.1664375376599305


Train-Epoch: 3, Loss: 0.5265457897362384:  96%|█████████▌| 440/460 [01:51<00:05,  3.95it/s] 
Valid-Epoch: 3, Loss: 0.604196591080045: 100%|██████████| 184/184 [00:08<00:00, 22.13it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 3, VALID_SCORE: 0.16518094305247394


Train-Epoch: 4, Loss: 0.45657147964970635:  96%|█████████▌| 440/460 [01:54<00:05,  3.83it/s]
Valid-Epoch: 4, Loss: 0.41514947876820096: 100%|██████████| 184/184 [00:08<00:00, 22.96it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 4, VALID_SCORE: 0.17849448915719646


Train-Epoch: 5, Loss: 0.44601014510474424:  96%|█████████▌| 440/460 [01:51<00:05,  3.95it/s]
Valid-Epoch: 5, Loss: 0.6221511074463311: 100%|██████████| 184/184 [00:08<00:00, 22.61it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 5, VALID_SCORE: 0.1709761434394012


Train-Epoch: 6, Loss: 0.46159604887732053:  96%|█████████▌| 440/460 [01:48<00:04,  4.05it/s]
Valid-Epoch: 6, Loss: 0.5635336078703403: 100%|██████████| 184/184 [00:07<00:00, 23.01it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 6, VALID_SCORE: 0.1658937818753654


Train-Epoch: 7, Loss: 0.4013065948743712:  96%|█████████▌| 440/460 [01:50<00:05,  3.97it/s] 
Valid-Epoch: 7, Loss: 0.6943034791974756: 100%|██████████| 184/184 [00:08<00:00, 22.96it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 7, VALID_SCORE: 0.17770172897105793


Train-Epoch: 8, Loss: 0.3825694895095446:  96%|█████████▌| 440/460 [01:50<00:05,  4.00it/s] 
Valid-Epoch: 8, Loss: 0.5914158243281038: 100%|██████████| 184/184 [00:08<00:00, 22.57it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 8, VALID_SCORE: 0.16979571784650638


Train-Epoch: 9, Loss: 0.37806123508648437:  96%|█████████▌| 440/460 [01:53<00:05,  3.88it/s]
Valid-Epoch: 9, Loss: 0.677232059204708: 100%|██████████| 184/184 [00:08<00:00, 22.38it/s] 


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 9, VALID_SCORE: 0.06738158163935096


Train-Epoch: 10, Loss: 0.3311801333149726:  96%|█████████▌| 440/460 [01:49<00:04,  4.01it/s] 
Valid-Epoch: 10, Loss: 0.7223331125544222: 100%|██████████| 184/184 [00:07<00:00, 23.49it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 10, VALID_SCORE: 0.09327547986804231


Train-Epoch: 11, Loss: 0.3750706537880681:  96%|█████████▌| 440/460 [01:51<00:05,  3.96it/s] 
Valid-Epoch: 11, Loss: 0.6328563534745785: 100%|██████████| 184/184 [00:07<00:00, 23.22it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 11, VALID_SCORE: 0.04041313073330822


Train-Epoch: 12, Loss: 0.3526741882955486:  96%|█████████▌| 440/460 [01:49<00:04,  4.02it/s] 
Valid-Epoch: 12, Loss: 0.6614465270363523: 100%|██████████| 184/184 [00:07<00:00, 23.18it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 12, VALID_SCORE: 0.03666948575811785


Train-Epoch: 13, Loss: 0.3413356122560799:  96%|█████████▌| 440/460 [01:50<00:05,  3.99it/s] 
Valid-Epoch: 13, Loss: 0.7111055482401634: 100%|██████████| 184/184 [00:08<00:00, 22.59it/s]


####################### valid-experiments: TS_86_3 #######################
####################### valid-experiments: TS_6_6 #######################
EPOCH: 13, VALID_SCORE: 0.03025550539505384


Train-Epoch: 14, Loss: 0.30562740363025775:  94%|█████████▍| 433/460 [01:40<00:05,  4.92it/s]