In [1]:
# Echo the installed zarr version so the environment used for this run is
# recorded in the notebook output.
import zarr

zarr.__version__

'2.18.2'

In [1]:
import os
import zarr
import timm
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import sys
import torch

# import torchvision.transforms.functional as F
import random

# NOTE(review): this silences every warning for the rest of the notebook,
# including deprecation notices raised by the libraries imported here.
warnings.filterwarnings("ignore")
# Put ./src/ on the import path so its modules can also be imported by bare
# name (presumably what the `from metric import ...` line below relies on).
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import Unet3D
from src.utils import save_images, PadToSize
from src.metric import (
    score,
    create_cls_pos,
    create_cls_pos_sikii,
    create_df,
    SegmentationLoss,
    DiceLoss,
)
from metric import visualize_epoch_results
from src.utils import save_images
from src.metric import score, create_cls_pos, create_cls_pos_sikii, create_df
from src.inference import inference, inference2pos

# Load the sample submission CSV from the shared inputs directory.
# NOTE(review): not referenced again in the visible cells of this notebook.
sample_submission = pd.read_csv("../../inputs/sample_submission.csv")

In [2]:
def create_gt_df(base_dir, exp_names):
    """Collect ground-truth particle coordinates into a single DataFrame.

    For every experiment in ``exp_names`` and every particle type in
    ``CFG.particles_name``, coordinates are read with ``read_info_json`` and
    tagged with the experiment and particle type they belong to.

    Args:
        base_dir: Directory forwarded to ``read_info_json`` as ``base_dir``.
        exp_names: Iterable of experiment names to load.

    Returns:
        DataFrame with the columns ``index, experiment, particle_type, x, y, z``
        where ``index`` is a running row number starting at 0. If nothing is
        loaded (``exp_names`` or ``CFG.particles_name`` is empty) an empty
        frame with those columns is returned instead of failing.
    """
    output_columns = ["index", "experiment", "particle_type", "x", "y", "z"]

    # Collect the per-(experiment, particle) frames and concatenate once;
    # concatenating inside the loop re-copies all previous rows every time.
    frames = []
    for exp_name in exp_names:
        for particle in CFG.particles_name:
            coords = read_info_json(
                base_dir=base_dir, exp_name=exp_name, particle_name=particle
            )  # (n, 3), columns labelled z, y, x below
            # Tag every row with the experiment and particle type it came from.
            particle_df = pd.DataFrame(coords, columns=["z", "y", "x"])
            particle_df["experiment"] = exp_name
            particle_df["particle_type"] = particle
            frames.append(particle_df)

    if not frames:
        return pd.DataFrame(columns=output_columns)

    # ignore_index renumbers rows 0..n-1; reset_index then exposes that
    # numbering as the "index" column.
    result_df = pd.concat(frames, axis=0, ignore_index=True).reset_index()
    return result_df[output_columns]


# Ground truth for every training experiment; the bare name on the last line
# renders the frame as the cell output.
gt_df = create_gt_df(
    base_dir="../../inputs/train/overlay/ExperimentRuns/",
    exp_names=CFG.train_exp_names,
)
gt_df

Unnamed: 0,index,experiment,particle_type,x,y,z
0,0,TS_4,apo-ferritin,3045.036742,919.139280,421.270403
1,1,TS_4,apo-ferritin,2969.078552,1027.114255,440.085721
2,2,TS_4,apo-ferritin,2839.792769,1069.080767,425.839468
3,3,TS_4,apo-ferritin,2875.180486,1077.907940,298.254286
4,4,TS_4,apo-ferritin,2765.950544,1019.336833,322.072039
...,...,...,...,...,...,...
14628,14628,TS_6_6,virus-like-particle,2609.876000,4569.876000,1169.759000
14629,14629,TS_6_6,virus-like-particle,2213.287000,4135.017000,1286.851000
14630,14630,TS_6_6,virus-like-particle,3303.905000,5697.825000,789.744000
14631,14631,TS_6_6,virus-like-particle,1008.748000,5949.213000,1077.303000


In [3]:
# Build the timm feature-extractor backbone (no classifier head, no pooling)
# and wrap it in the project's Unet3D.
# NOTE(review): `model` is rebuilt in a later cell (with num_domains=5) before
# the visible inference call, so this instance appears to be unused.
encoder = timm.create_model(
    model_name=CFG.model_name,
    pretrained=True,
    features_only=True,
    in_chans=3,
    num_classes=0,
    global_pool="",
)
model = Unet3D(encoder=encoder)
model = model.to("cuda")
# Kept as the last expression so the key-matching report is shown as output.
model.load_state_dict(torch.load("./best_model.pth"))

<All keys matched successfully>

In [4]:
def inference2pos(pred_segmask, exp_name, sikii_dict):
    """Turn per-class prediction maps into a DataFrame of particle positions.

    For each class id ``1..len(CFG.particles_name)`` the class map is
    thresholded, split into connected components with cc3d, and every
    component centroid is rescaled from the ``CFG.resolution`` scale to the
    ``"A"`` scale using ``CFG.resolution2ratio``.

    NOTE: this cell-level definition replaces the ``inference2pos`` that the
    import cell pulls in from ``src.inference``.

    Args:
        pred_segmask: Indexable by class id; ``pred_segmask[c]`` is the map
            compared against the threshold of class ``c``. Presumably a
            (classes, z, y, x) score volume -- TODO confirm against
            ``inference``.
        exp_name: Experiment name forwarded to ``create_df``.
        sikii_dict: Threshold ("sikii") per particle name, keyed by the names
            stored in ``CFG.cls2particles``.

    Returns:
        Whatever ``create_df`` builds from the rescaled
        ``[class_id, z, y, x]`` rows and ``exp_name``.
    """
    import cc3d

    res2ratio = CFG.resolution2ratio
    # The scale factors do not depend on the class or the component, so look
    # them up once. Multiplication and division stay separate (v * a / b) so
    # the floating-point results are unchanged.
    source_ratio = res2ratio[CFG.resolution]
    target_ratio = res2ratio["A"]

    Ascale_pos = []
    for pred_cls in range(1, len(CFG.particles_name) + 1):
        sikii = sikii_dict[CFG.cls2particles[pred_cls]]
        cc = cc3d.connected_components(pred_segmask[pred_cls] > sikii)
        stats = cc3d.statistics(cc)

        # Entry 0 of the centroid table belongs to label 0 (background).
        for z, y, x in stats["centroids"][1:]:
            Ascale_pos.append(
                [
                    pred_cls,
                    z * source_ratio / target_ratio,
                    y * source_ratio / target_ratio,
                    x * source_ratio / target_ratio,
                ]
            )

    return create_df(Ascale_pos, exp_name)

In [5]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# Number of particle classes.
num_classes = len(CFG.particles_name)

# Take one colour per class from the "tab10" colormap.
colors = plt.cm.tab10(np.arange(num_classes))

# Discrete colormap used for the class overlays.
class_colormap = ListedColormap(colors)


def plot_with_colormap(data, title, original_tomogram):
    """Overlay a class map on a grayscale tomogram slice in the current axes.

    Args:
        data: 2-D array of class values; entries ``<= 0`` are masked so the
            tomogram shows through. Assumed to hold integer class ids 1..N,
            with N the number of colours in ``class_colormap`` -- TODO confirm
            against the callers.
        title: Title set on the current axes.
        original_tomogram: Image drawn underneath in grayscale.

    Returns:
        The image object returned by the overlay ``plt.imshow`` call.
    """
    masked_data = np.ma.masked_where(data <= 0, data)  # hide class 0
    plt.imshow(original_tomogram, cmap="gray")
    # Pin the colour limits. Without them imshow rescales to the min/max of
    # each slice, so the colour of a class changes with which classes happen
    # to be present. Limits of 0.5 .. N + 0.5 give each integer id 1..N its
    # own colormap bin.
    im = plt.imshow(
        masked_data,
        cmap=class_colormap,
        vmin=0.5,
        vmax=class_colormap.N + 0.5,
    )
    plt.title(title)
    plt.axis("off")
    return im


def imshow_result(pred, gt, original, index):
    """Show prediction, ground truth and raw tomogram side by side for one slice.

    Draws a 1x3 row of panels on the current figure: the prediction overlay
    (titled "Train-Prediction"), the ground-truth overlay (titled "Gt") and
    the plain grayscale slice, then calls ``plt.show()``.

    Args:
        pred: Indexable of prediction class maps.
        gt: Indexable of ground-truth class maps.
        original: Indexable of tomogram slices.
        index: Position of the slice to display in all three inputs.
    """
    background = original[index]

    # Panels 1 and 2: class overlays on top of the tomogram slice.
    overlays = ((pred, "Train-Prediction"), (gt, "Gt"))
    for panel, (volume, title) in enumerate(overlays, start=1):
        plt.subplot(1, 3, panel)
        plot_with_colormap(volume[index], title, background)

    # Panel 3: the tomogram slice on its own.
    plt.subplot(1, 3, 3)
    plt.imshow(background, cmap="gray")
    plt.axis("off")

    plt.show()

In [6]:
gt_df

Unnamed: 0,index,experiment,particle_type,x,y,z
0,0,TS_4,apo-ferritin,3045.036742,919.139280,421.270403
1,1,TS_4,apo-ferritin,2969.078552,1027.114255,440.085721
2,2,TS_4,apo-ferritin,2839.792769,1069.080767,425.839468
3,3,TS_4,apo-ferritin,2875.180486,1077.907940,298.254286
4,4,TS_4,apo-ferritin,2765.950544,1019.336833,322.072039
...,...,...,...,...,...,...
14628,14628,TS_6_6,virus-like-particle,2609.876000,4569.876000,1169.759000
14629,14629,TS_6_6,virus-like-particle,2213.287000,4135.017000,1286.851000
14630,14630,TS_6_6,virus-like-particle,3303.905000,5697.825000,789.744000
14631,14631,TS_6_6,virus-like-particle,1008.748000,5949.213000,1077.303000


In [7]:
def create_gt_df(base_dir, exp_names):
    """Collect ground-truth particle coordinates into a single DataFrame.

    NOTE: this notebook also defines ``create_gt_df`` in an earlier cell; from
    this cell onward, this later definition is the one in effect.

    For every experiment in ``exp_names`` and every particle type in
    ``CFG.particles_name``, coordinates are read with ``read_info_json`` and
    tagged with the experiment and particle type they belong to.

    Args:
        base_dir: Directory forwarded to ``read_info_json`` as ``base_dir``.
        exp_names: Iterable of experiment names to load.

    Returns:
        DataFrame with the columns ``index, experiment, particle_type, x, y, z``
        where ``index`` is a running row number starting at 0. If nothing is
        loaded (``exp_names`` or ``CFG.particles_name`` is empty) an empty
        frame with those columns is returned instead of failing.
    """
    output_columns = ["index", "experiment", "particle_type", "x", "y", "z"]

    # Collect the per-(experiment, particle) frames and concatenate once;
    # concatenating inside the loop re-copies all previous rows every time.
    frames = []
    for exp_name in exp_names:
        for particle in CFG.particles_name:
            coords = read_info_json(
                base_dir=base_dir, exp_name=exp_name, particle_name=particle
            )  # (n, 3), columns labelled z, y, x below
            # Tag every row with the experiment and particle type it came from.
            particle_df = pd.DataFrame(coords, columns=["z", "y", "x"])
            particle_df["experiment"] = exp_name
            particle_df["particle_type"] = particle
            frames.append(particle_df)

    if not frames:
        return pd.DataFrame(columns=output_columns)

    # ignore_index renumbers rows 0..n-1; reset_index then exposes that
    # numbering as the "index" column.
    result_df = pd.concat(frames, axis=0, ignore_index=True).reset_index()
    return result_df[output_columns]

In [8]:
# Experiment to evaluate: the first validation experiment from the config.
exp_name = CFG.valid_exp_names[0]
# Alternative kept for reference: evaluate on the first training experiment.
# exp_name = CFG.train_exp_names[0]
exp_name

'TS_5_4'

In [9]:
import timm

# Rebuild the backbone and the Unet3D (this time with num_domains=5) and load
# the checkpoint that the inference cell below uses.
encoder = timm.create_model(
    model_name=CFG.model_name,
    pretrained=True,
    features_only=True,
    in_chans=3,
    num_classes=0,
    global_pool="",
)
model = Unet3D(encoder=encoder, num_domains=5)
model = model.to("cuda")
# model.load_state_dict(torch.load("./pretrained_model.pth"))
# Kept as the last statement so the key-matching report is shown as output.
model.load_state_dict(torch.load("./best_model.pth"))

# inferenced_array = inference(model, exp_name, train=False)
# 0.7303962244998289

<All keys matched successfully>

In [10]:
# Run the model once on the selected experiment; every threshold below reuses
# this prediction.
inferenced_array, n_tomogram, segmentation_map = inference(model, exp_name, train=False)

# The ground truth does not depend on the threshold, so read it once instead
# of re-reading the annotation files on each of the 20 iterations.
# NOTE: this rebinds `gt_df` (built from the training experiments in an
# earlier cell) to the ground truth of `exp_name` only.
gt_df = create_gt_df(
    base_dir="../../inputs/train/overlay/ExperimentRuns/", exp_names=[exp_name]
)

# Sweep a single threshold shared by all particle classes and print the score
# obtained for each value.
for constant in np.linspace(0.2, 0.9, 20):
    initial_sikii = {
        "apo-ferritin": constant,
        "beta-amylase": constant,
        "beta-galactosidase": constant,
        "ribosome": constant,
        "thyroglobulin": constant,
        "virus-like-particle": constant,
    }

    pred_original_df = inference2pos(
        pred_segmask=inferenced_array,
        exp_name=exp_name,
        sikii_dict=initial_sikii,
    )

    s = score(
        pred_original_df,
        # score() is opaque from this notebook; passing a copy keeps gt_df
        # reusable across iterations even if score() modifies its inputs.
        gt_df.copy(),
        row_id_column_name="index",
        distance_multiplier=1.0,
        beta=4,
    )
    print(constant, s)

0.2 0.43237133146151574
0.2368421052631579 0.5057183814274865
0.2736842105263158 0.5252491208342525
0.31052631578947365 0.5712345207139473
0.34736842105263155 0.6005693481346098
0.38421052631578945 0.591997131155094
0.42105263157894735 0.6241245889981034
0.45789473684210524 0.7300464068318507
0.49473684210526314 0.7459445618543482
0.531578947368421 0.7462344074448036
0.5684210526315789 0.8269033717689862
0.6052631578947368 0.8417531154095659
0.6421052631578947 0.8159518888388317
0.6789473684210525 0.9004345723347977
0.7157894736842105 0.8301338246858786
0.7526315789473683 0.6455796000467252
0.7894736842105263 0.6190859694533984
0.8263157894736841 0.4033785937400395
0.8631578947368421 0.39226075470591415
0.9 0.30808757399371867


In [11]:
# Scratch record of earlier threshold sweeps, kept as a bare string literal
# (so the cell simply echoes it). Each line is "<threshold> <score>" in the
# same format the sweep cell prints.
# NOTE(review): the "# 0" and "# 4" headings presumably identify two different
# runs/folds -- TODO confirm and label them explicitly.
"""

# 0
0.2 0.4378329374584124
0.2368421052631579 0.5191351374526408
0.2736842105263158 0.5624028914763121
0.31052631578947365 0.59879074721956
0.34736842105263155 0.5710671423750229
0.38421052631578945 0.5822544588975088
0.42105263157894735 0.6215029895463589
0.45789473684210524 0.7273128111105172
0.49473684210526314 0.7980716302890555
0.531578947368421 0.810720319371696
0.5684210526315789 0.8371835839001497
0.6052631578947368 0.8844482505852502
0.6421052631578947 0.8956330850645984
0.6789473684210525 0.8678871497232438
0.7157894736842105 0.8449238083076072
0.7526315789473683 0.41169530959444384
0.7894736842105263 0.3999160462733188
0.8263157894736841 0.39624270977541404
0.8631578947368421 0.3815673084524206
0.9 0.24427736006683373


# 4
0.2 0.5679466223272989
0.2368421052631579 0.5995843372438189
0.2736842105263158 0.6994273618959725
0.31052631578947365 0.6873817147539326
0.34736842105263155 0.7732137649875339
0.38421052631578945 0.8083663044186515
0.42105263157894735 0.7625602588443811
0.45789473684210524 0.7760536196438625
0.49473684210526314 0.8031105441822127
0.531578947368421 0.8502003203785099
0.5684210526315789 0.8738341581052514
0.6052631578947368 0.8317116054334563
0.6421052631578947 0.7548386333209339
0.6789473684210525 0.7522277874616529
0.7157894736842105 0.6223504410486821
0.7526315789473683 0.6084178345334992
0.7894736842105263 0.6319147195238717
0.8263157894736841 0.5156331157302112
0.8631578947368421 0.41285716422700375
0.9 0.360662956827096
"""

'\n\n# 0\n0.2 0.4378329374584124\n0.2368421052631579 0.5191351374526408\n0.2736842105263158 0.5624028914763121\n0.31052631578947365 0.59879074721956\n0.34736842105263155 0.5710671423750229\n0.38421052631578945 0.5822544588975088\n0.42105263157894735 0.6215029895463589\n0.45789473684210524 0.7273128111105172\n0.49473684210526314 0.7980716302890555\n0.531578947368421 0.810720319371696\n0.5684210526315789 0.8371835839001497\n0.6052631578947368 0.8844482505852502\n0.6421052631578947 0.8956330850645984\n0.6789473684210525 0.8678871497232438\n0.7157894736842105 0.8449238083076072\n0.7526315789473683 0.41169530959444384\n0.7894736842105263 0.3999160462733188\n0.8263157894736841 0.39624270977541404\n0.8631578947368421 0.3815673084524206\n0.9 0.24427736006683373\n\n\n# 4\n0.2 0.5679466223272989\n0.2368421052631579 0.5995843372438189\n0.2736842105263158 0.6994273618959725\n0.31052631578947365 0.6873817147539326\n0.34736842105263155 0.7732137649875339\n0.38421052631578945 0.8083663044186515\n0.42