In [1]:
import os
import zarr
import random
import json
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import torch
import torchvision.transforms.functional as F
import random
import sys
from collections import defaultdict

# Silence every warning for the rest of the session. Note that this also hides
# NumPy RuntimeWarnings such as the one raised when averaging an empty list.
warnings.filterwarnings("ignore")
# Add the local ./src/ directory to the module search path before the project
# imports that follow.
sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import UNet_2D, aug
from src.utils import save_images
from src.metric import score, create_cls_pos, create_cls_pos_sikii, create_df

# Load the sample submission CSV from the shared inputs directory.
sample_submission = pd.read_csv("../../inputs/sample_submission.csv")

In [2]:
# Validation dataset built from the experiments listed in CFG.valid_exp_names.
# NOTE(review): train=True is passed even though this is the validation split;
# confirm that this is what EziiDataset expects for evaluation.
valid_dataset = EziiDataset(
    exp_names=CFG.valid_exp_names,
    base_dir="../../inputs/train",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.valid_zarr_types,
    train=True,
)

# Imported here (not removed) because later cells rely on `tqdm` being in scope.
from tqdm import tqdm

# One tomogram per batch, in a fixed order, so results are reproducible.
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

# Fetch only the first batch; its tomogram is used later to size the progress
# bar. `next(iter(...))` replaces a tqdm-wrapped loop with an immediate `break`,
# which left an unfinished progress bar in the output. An empty loader now
# fails loudly here (StopIteration) instead of leaving the name undefined.
row = next(iter(valid_loader))
normalized_tomogram = row["normalized_tomogram"]

[('TS_86_3', 'denoised'), ('TS_6_6', 'denoised')]


  0%|          | 0/2 [00:00<?, ?it/s]


In [3]:
class PadToSize(nn.Module):
    """Pad the last two dimensions of a tensor up to a fixed square size.

    The target edge length is selected by the resolution level:
    "0" -> 640, "1" -> 320, "2" -> 160.

    Args:
        resolution: Resolution level, "0", "1" or "2". The integers 0, 1 and 2
            are accepted as well and treated like their string form.

    Raises:
        ValueError: If ``resolution`` is not one of the supported levels.
            (Previously an unsupported level left ``size`` unset and only
            failed later, inside ``forward``, with an AttributeError.)
    """

    # Target edge length for each supported resolution level.
    _SIZE_BY_RESOLUTION = {"0": 640, "1": 320, "2": 160}

    def __init__(self, resolution):
        super().__init__()
        level = str(resolution)
        if level not in self._SIZE_BY_RESOLUTION:
            supported = sorted(self._SIZE_BY_RESOLUTION)
            raise ValueError(
                f"Unsupported resolution {resolution!r}; expected one of {supported}"
            )
        self.size = self._SIZE_BY_RESOLUTION[level]

    def forward(self, x):
        """Return ``x`` padded on the right and bottom to ``size`` x ``size``."""
        # `F` is torchvision.transforms.functional in this notebook, whose
        # 4-element padding is ordered (left, top, right, bottom). This is NOT
        # the (left, right, top, bottom) order of torch.nn.functional.pad.
        pad_right = self.size - x.shape[-1]
        pad_bottom = self.size - x.shape[-2]
        return F.pad(x, (0, 0, pad_right, pad_bottom))

In [4]:
# Half-precision model used for inference only; weights come from the saved
# checkpoint. eval() is called once, after the weights are loaded.
model = UNet_2D().to("cuda").half()
model.load_state_dict(torch.load("./best_model.pth"))
model.eval()

# The objects below are not used by the inference loop in this cell. They are
# kept so that any other cell referring to these names keeps working.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor([0.5, 32, 32, 32, 32, 32, 32]).to("cuda")
)
best_model = None
best_loss = np.inf
batch_size = 4

valid_loss = []
# exp_name -> list of per-slice arrays, appended in slice order.
valid_pred_tomogram = defaultdict(list)
valid_gt_tomogram = defaultdict(list)

# Build the padding module once instead of twice per slice.
pad_to_size = PadToSize(CFG.resolution)

# Progress total = number of tomograms * slices per tomogram. Dimension 0 of
# the batch is the batch axis (batch_size=1); slices are indexed along
# dimension 1, exactly as the inner loop below does. Using shape[0] here gave
# a total of len(valid_loader), far fewer than the iterations actually run.
# Assumes every tomogram has as many slices as the first one.
tq = tqdm(total=len(valid_loader) * normalized_tomogram.shape[1])

# Inference only: disable autograd so no graph is built for each forward pass.
with torch.no_grad():
    for data in valid_loader:
        exp_name = data["exp_name"][0]
        tomogram = data["normalized_tomogram"].to("cuda")
        segmentation_map = data["segmentation_map"].to("cuda").long()

        for i in range(tomogram.shape[1]):
            # Slice i of the single tomogram in the batch, cast to match the
            # half-precision model.
            input_ = tomogram[:, i].unsqueeze(0).to("cuda", dtype=torch.float16)
            gt = segmentation_map[:, i]

            input_ = pad_to_size(input_)
            gt = pad_to_size(gt)
            output = model(input_)
            # Per-pixel class probabilities over the channel dimension.
            output = nn.functional.softmax(output, dim=1)

            # No loss is computed in this cell, so valid_loss stays empty;
            # only show a mean when there is something to average (avoids the
            # meaningless "Loss: nan" description).
            if valid_loss:
                tq.set_description(f"Loss: {np.mean(valid_loss)}")
            tq.update(1)

            output = drop_padding(output, CFG.resolution)

            # NOTE(review): the prediction has its padding dropped but `gt` is
            # stored still padded; confirm before comparing the two directly.
            valid_pred_tomogram[exp_name].append(output.cpu().numpy())
            valid_gt_tomogram[exp_name].append(gt.cpu().numpy())
tq.close()

Loss: nan: : 184it [00:05, 32.62it/s]                   


In [5]:
def create_gt_df(base_dir, exp_names, particle_names=None):
    """Collect ground-truth particle coordinates into one DataFrame.

    For every (experiment, particle type) pair the coordinates are read with
    ``read_info_json`` and tagged with the experiment and particle name.

    Args:
        base_dir: Directory passed through to ``read_info_json``.
        exp_names: Iterable of experiment names to load.
        particle_names: Particle types to load. Defaults to
            ``CFG.particles_name`` when omitted.

    Returns:
        DataFrame with columns
        ``["index", "experiment", "particle_type", "x", "y", "z"]`` where
        ``index`` is a running row number. If nothing is loaded (for example
        ``exp_names`` is empty) an empty frame with these columns is returned
        instead of failing on a ``None`` result.
    """
    if particle_names is None:
        particle_names = CFG.particles_name

    coord_columns = ["z", "y", "x"]
    frames = []
    for exp_name in exp_names:
        for particle in particle_names:
            coords = read_info_json(
                base_dir=base_dir, exp_name=exp_name, particle_name=particle
            )
            # Expected shape is (n, 3); the reshape also turns an empty result
            # into a (0, 3) array so the column names still fit.
            coords = np.asarray(coords).reshape(-1, 3)
            particle_df = pd.DataFrame(coords, columns=coord_columns)
            # Tag every row with its experiment and particle type.
            particle_df["experiment"] = exp_name
            particle_df["particle_type"] = particle
            frames.append(particle_df)

    # Concatenate once at the end rather than growing a frame inside the loop.
    if frames:
        result_df = pd.concat(frames, axis=0, ignore_index=True)
    else:
        result_df = pd.DataFrame(
            columns=coord_columns + ["experiment", "particle_type"]
        )

    # Materialise the running row number as an explicit "index" column.
    result_df = result_df.reset_index()
    return result_df[["index", "experiment", "particle_type", "x", "y", "z"]]


# Ground truth for the validation experiments, with the "beta-amylase" rows
# removed and the row labels renumbered.
gt_df = (
    create_gt_df("../../inputs/train/overlay/ExperimentRuns/", CFG.valid_exp_names)
    .loc[lambda frame: frame["particle_type"] != "beta-amylase"]
    .reset_index(drop=True)
)
gt_df

Unnamed: 0,index,experiment,particle_type,x,y,z
0,0,TS_86_3,apo-ferritin,3870.343,4952.714,1261.600
1,1,TS_86_3,apo-ferritin,4130.897,5422.292,501.860
2,2,TS_86_3,apo-ferritin,2735.000,4668.447,520.291
3,3,TS_86_3,apo-ferritin,2649.615,4690.615,600.923
4,4,TS_86_3,apo-ferritin,2665.353,4810.641,612.019
...,...,...,...,...,...,...
340,363,TS_6_6,virus-like-particle,2609.876,4569.876,1169.759
341,364,TS_6_6,virus-like-particle,2213.287,4135.017,1286.851
342,365,TS_6_6,virus-like-particle,3303.905,5697.825,789.744
343,366,TS_6_6,virus-like-particle,1008.748,5949.213,1077.303


In [15]:
# pred_df = pd.read_csv("../../inputs/train_submission.csv")


def calc_score(initial_sikii):
    """Score the cached validation predictions for a set of class thresholds.

    Reads the notebook-level ``valid_pred_tomogram`` (per-experiment lists of
    per-slice probability arrays), ``CFG.valid_exp_names`` and ``gt_df``.

    Args:
        initial_sikii: Mapping from particle type to threshold, forwarded to
            ``create_cls_pos_sikii`` as ``sikii_dict``.

    Returns:
        Tuple ``(score_, pred_df)``: the value returned by ``score`` and the
        prediction table that was scored, with columns including ``index``,
        ``experiment``, ``particle_type``, ``x``, ``y`` and ``z``.

    Raises:
        ValueError: If there are no validation experiments, or an experiment
            has no cached predictions.
    """
    per_exp_frames = []
    for exp_name in CFG.valid_exp_names:
        # .get() avoids silently inserting an empty list into a defaultdict.
        slices = valid_pred_tomogram.get(exp_name)
        if not slices:
            raise ValueError(f"No cached predictions for experiment {exp_name!r}")

        # Stack the per-slice arrays and drop the singleton batch axis,
        # e.g. (92, 1, 7, 315, 315) -> (92, 7, 315, 315).
        pred_tomogram = np.array(slices).squeeze(1)

        # Only the second return value is used for scoring.
        _, pred_Ascale_pos = create_cls_pos_sikii(
            pred_tomogram, sikii_dict=initial_sikii
        )
        per_exp_frames.append(create_df(pred_Ascale_pos, exp_name))

    if not per_exp_frames:
        raise ValueError("CFG.valid_exp_names is empty; nothing to score")

    # Concatenate once instead of growing a frame inside the loop.
    all_pred_df = pd.concat(per_exp_frames, axis=0, ignore_index=True)

    pred_df = all_pred_df[all_pred_df["particle_type"] != "beta-amylase"]
    # De-duplicate positions within each experiment. "experiment" is part of
    # the key so that identical coordinates found in two different tomograms
    # are not collapsed into a single prediction.
    pred_df = pred_df.drop_duplicates(
        subset=["experiment", "x", "y", "z"], keep="first"
    ).reset_index(drop=True)

    # Expose the running row number as the "index" id column used by score().
    pred_df = pred_df.reset_index()

    score_ = score(
        pred_df, gt_df, row_id_column_name="index", distance_multiplier=1, beta=4
    )

    return score_, pred_df

In [17]:
# One threshold shared by every particle class; this value equals the
# best_sikii reported by the grid-search cell in this notebook.
constant = 0.44393939393939397

initial_sikii = {
    particle: constant
    for particle in (
        "apo-ferritin",
        "beta-amylase",
        "beta-galactosidase",
        "ribosome",
        "thyroglobulin",
        "virus-like-particle",
    )
}

# Score the cached predictions with that threshold and display the score.
score_, pred_df = calc_score(initial_sikii)
score_

0.4541495695373164

In [18]:
# Display the prediction table returned by the last calc_score call.
pred_df

Unnamed: 0,index,experiment,particle_type,x,y,z
0,0,TS_86_3,apo-ferritin,3140.006683,3140.019503,910.009707
1,1,TS_86_3,apo-ferritin,5676.470588,1262.352941,438.039216
2,2,TS_86_3,apo-ferritin,4815.555556,2829.629630,430.370370
3,3,TS_86_3,apo-ferritin,3025.806452,2852.903226,443.870968
4,4,TS_86_3,apo-ferritin,1738.787879,5111.515152,444.242424
...,...,...,...,...,...,...
241,241,TS_6_6,virus-like-particle,408.645833,2708.576389,983.368056
242,242,TS_6_6,virus-like-particle,1002.240664,5900.000000,1079.419087
243,243,TS_6_6,virus-like-particle,2606.800000,4567.200000,1150.160000
244,244,TS_6_6,virus-like-particle,3548.362720,976.171285,1182.972292


In [8]:
# Grid search: try 100 evenly spaced thresholds in [0.25, 0.45], applying the
# same value to every particle class, and remember the best-scoring one.
best_sikii = 0
best_score = -np.inf

for sikii in np.linspace(0.25, 0.45, 100):
    initial_sikii = dict.fromkeys(
        (
            "apo-ferritin",
            "beta-amylase",
            "beta-galactosidase",
            "ribosome",
            "thyroglobulin",
            "virus-like-particle",
        ),
        sikii,
    )
    score_, _ = calc_score(initial_sikii)
    # Strict comparison: on a tie the earlier (smaller) threshold is kept.
    if score_ > best_score:
        best_score, best_sikii = score_, sikii
    print(sikii, score_)

0.25 0.38682866945959354
0.25202020202020203 0.38668278588332117
0.25404040404040407 0.3877645132122935
0.25606060606060604 0.38097700585065314
0.2580808080808081 0.38274652537848597
0.2601010101010101 0.3845955510430099
0.26212121212121214 0.3849902670502468
0.2641414141414141 0.38291954650216914
0.26616161616161615 0.39081793102254697
0.2681818181818182 0.3905656707217853
0.2702020202020202 0.38150743515229174
0.2722222222222222 0.3718219607317903
0.27424242424242423 0.37757185474381905
0.27626262626262627 0.38286677450962214
0.2782828282828283 0.38180047091829333
0.2803030303030303 0.36423550285581513
0.2823232323232323 0.36538713303792963
0.28434343434343434 0.3678559213996585
0.2863636363636364 0.3689485619578358
0.2883838383838384 0.37005586872572266
0.2904040404040404 0.3721058019022086
0.2924242424242424 0.3732759719210121
0.29444444444444445 0.3769052944495056
0.2964646464646465 0.3787540465239657
0.29848484848484846 0.38135226340363954
0.3005050505050505 0.38113506357871113
0

In [14]:
# Display the best threshold found by the grid search together with its score.
best_sikii, best_score

# Value recorded from an earlier run:
# (0.44393939393939397, 0.4541495695373164)

(0.44393939393939397, 0.4541495695373164)