In [1]:
import os
import zarr
import random
import json
import numpy as np
import pandas as pd
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import torch
import torchvision.transforms.functional as F
import random
import sys
from collections import defaultdict

sys.path.append("./src/")

from src.config import CFG
from src.dataloader import (
    read_zarr,
    read_info_json,
    scale_coordinates,
    create_dataset,
    create_segmentation_map,
    EziiDataset,
    drop_padding,
)
from src.network import UNet_2D, aug
from src.utils import save_images
from src.metric import score

sample_submission = pd.read_csv("../../inputs/sample_submission.csv")

In [19]:
valid_dataset = EziiDataset(
    # exp_names=CFG.valid_exp_names,
    exp_names=CFG.train_exp_names,
    base_dir="../../inputs/train/static",
    particles_name=CFG.particles_name,
    resolution=CFG.resolution,
    zarr_type=CFG.valid_zarr_types,
)

from tqdm import tqdm

# train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

for row in tqdm(valid_loader):
    normalized_tomogram = row["normalized_tomogram"]
    break

[('TS_5_4', 'denoised'), ('TS_73_6', 'denoised'), ('TS_99_9', 'denoised'), ('TS_6_4', 'denoised'), ('TS_69_2', 'denoised')]


  0%|          | 0/5 [00:00<?, ?it/s]


In [20]:
model = UNet_2D().to("cuda")
model.eval()
model.load_state_dict(torch.load("./best_model.pth"))


optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor([0.5, 32, 32, 32, 32, 32, 32]).to("cuda")
)
# criterion = DiceLoss()

best_model = None
best_loss = np.inf
batch_size = 4

valid_loss = []
valid_pred_tomogram = defaultdict(list)
valid_gt_tomogram = defaultdict(list)
model.eval()
tq = tqdm(range(len(valid_loader) * normalized_tomogram.shape[0]))
for data in valid_loader:
    exp_name = data["exp_name"][0]
    tomogram = data["normalized_tomogram"].to("cuda")
    segmentation_map = data["segmentation_map"].to("cuda").long()

    for i in range(tomogram.shape[1]):
        input_ = tomogram[:, i].unsqueeze(0)
        gt = segmentation_map[:, i]

        output = model(input_)
        loss = criterion(output, gt)

        valid_loss.append(loss.item())
        tq.set_description(f"Loss: {np.mean(valid_loss)}")
        tq.update(1)

        valid_pred_tomogram[exp_name].append(output.cpu().detach().numpy())
        valid_gt_tomogram[exp_name].append(gt.cpu().detach().numpy())
tq.close()

Loss: 0.22641571319621542: : 460it [00:20, 22.55it/s]                   


In [21]:
valid_pred_tomogram.keys()

dict_keys(['TS_5_4', 'TS_73_6', 'TS_99_9', 'TS_6_4', 'TS_69_2'])

In [22]:
import cc3d


def create_cls_pos(pred_tomogram):
    cls_pos = []
    Ascale_pos = []
    resolution_info = CFG.resolution2ratio

    for pred_cls in range(1, len(CFG.particles_name) + 1):
        cc = cc3d.connected_components(pred_tomogram == pred_cls)
        stats = cc3d.statistics(cc)

        for z, x, y in stats["centroids"]:
            Ascale_z = z * resolution_info[CFG.resolution] / resolution_info["A"]
            Ascale_x = x * resolution_info[CFG.resolution] / resolution_info["A"]
            Ascale_y = y * resolution_info[CFG.resolution] / resolution_info["A"]

            cls_pos.append([pred_cls, z, y, x])
            Ascale_pos.append([pred_cls, Ascale_z, Ascale_y, Ascale_x])

    return cls_pos, Ascale_pos


def create_df(pos, exp_name):
    results = []
    for cls, z, y, x in pos:
        results.append(
            {
                "experiment": exp_name,
                "particle_type": CFG.cls2particles[cls],
                "x": x,
                "y": y,
                "z": z,
            }
        )

    return pd.DataFrame(results)


# exp_name = "TS_86_3"
# exp_name = "TS_6_6"
exp_name = "TS_99_9"


# pred
pred_tomogram = np.array(valid_pred_tomogram[exp_name]).argmax(2).squeeze(1)
pred_tomogram = drop_padding(pred_tomogram, CFG.resolution)
pred_cls_pos, pred_Ascale_pos = create_cls_pos(pred_tomogram)
pred_df = create_df(pred_Ascale_pos, exp_name)

pred_df = pred_df.reset_index()

# gt
gt_tomogram = np.array(valid_gt_tomogram[exp_name]).squeeze(1)
gt_tomogram = drop_padding(gt_tomogram, CFG.resolution)
gt_cls_pos, gt_Ascale_pos = create_cls_pos(gt_tomogram)
gt_df = create_df(gt_Ascale_pos, exp_name)

gt_df = gt_df.reset_index()

In [23]:
pred_df

Unnamed: 0,index,experiment,particle_type,x,y,z
0,0,TS_99_9,apo-ferritin,3139.702673,3140.820386,910.106360
1,1,TS_99_9,apo-ferritin,500.000000,3680.000000,60.000000
2,2,TS_99_9,apo-ferritin,2196.052632,3191.578947,145.526316
3,3,TS_99_9,apo-ferritin,1380.000000,320.000000,140.000000
4,4,TS_99_9,apo-ferritin,1863.333333,3340.000000,140.000000
...,...,...,...,...,...,...
1379,1379,TS_99_9,virus-like-particle,6143.157895,2890.877193,1417.543860
1380,1380,TS_99_9,virus-like-particle,4800.000000,1520.000000,1440.000000
1381,1381,TS_99_9,virus-like-particle,4840.000000,1540.000000,1440.000000
1382,1382,TS_99_9,virus-like-particle,5680.000000,1210.000000,1460.000000


In [24]:
gt_df

Unnamed: 0,index,experiment,particle_type,x,y,z
0,0,TS_99_9,apo-ferritin,3139.965349,3140.118628,910.124131
1,1,TS_99_9,apo-ferritin,1740.000000,1280.000000,240.000000
2,2,TS_99_9,apo-ferritin,2040.000000,1220.000000,240.000000
3,3,TS_99_9,apo-ferritin,660.000000,5600.000000,260.000000
4,4,TS_99_9,apo-ferritin,1900.000000,1180.000000,260.000000
...,...,...,...,...,...,...
168,168,TS_99_9,virus-like-particle,3317.152209,5677.152209,797.152209
169,169,TS_99_9,virus-like-particle,5517.152209,4177.152209,837.152209
170,170,TS_99_9,virus-like-particle,2457.152209,2997.152209,877.152209
171,171,TS_99_9,virus-like-particle,4297.152209,2237.152209,937.152209


In [39]:
score(
    solution=pred_df,
    submission=gt_df,
    row_id_column_name="index",
    distance_multiplier=1,
    beta=4,
)

0.05733982066288482

In [40]:
pred_df

Unnamed: 0,index,experiment,particle_type,x,y,z
0,0,TS_99_9,apo-ferritin,3139.702673,3140.820386,910.106360
1,1,TS_99_9,apo-ferritin,500.000000,3680.000000,60.000000
2,2,TS_99_9,apo-ferritin,2196.052632,3191.578947,145.526316
3,3,TS_99_9,apo-ferritin,1380.000000,320.000000,140.000000
4,4,TS_99_9,apo-ferritin,1863.333333,3340.000000,140.000000
...,...,...,...,...,...,...
1379,1379,TS_99_9,virus-like-particle,6143.157895,2890.877193,1417.543860
1380,1380,TS_99_9,virus-like-particle,4800.000000,1520.000000,1440.000000
1381,1381,TS_99_9,virus-like-particle,4840.000000,1540.000000,1440.000000
1382,1382,TS_99_9,virus-like-particle,5680.000000,1210.000000,1460.000000


In [41]:
gt_df

Unnamed: 0,index,experiment,particle_type,x,y,z
0,0,TS_99_9,apo-ferritin,3139.965349,3140.118628,910.124131
1,1,TS_99_9,apo-ferritin,1740.000000,1280.000000,240.000000
2,2,TS_99_9,apo-ferritin,2040.000000,1220.000000,240.000000
3,3,TS_99_9,apo-ferritin,660.000000,5600.000000,260.000000
4,4,TS_99_9,apo-ferritin,1900.000000,1180.000000,260.000000
...,...,...,...,...,...,...
168,168,TS_99_9,virus-like-particle,3317.152209,5677.152209,797.152209
169,169,TS_99_9,virus-like-particle,5517.152209,4177.152209,837.152209
170,170,TS_99_9,virus-like-particle,2457.152209,2997.152209,877.152209
171,171,TS_99_9,virus-like-particle,4297.152209,2237.152209,937.152209
