In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import scienceplots as _
import seaborn as sns
import torch

from checkpoints import CHECKPOINT_DIR
from figures import FIGURES_DIR
from hubmap.visualization import visualize_result

In [None]:
# Publication styling from scienceplots: base "science" style, then the
# "nature" overrides (styles further right override those to their left).
plt.style.use(style=["science", "nature"])

In [None]:
# Load the two TransResU-Net checkpoints that are compared below.
# Paths are resolved from the project's CHECKPOINT_DIR instead of a hard-coded
# home directory, so the notebook is portable across machines.
# NOTE(review): assumes CHECKPOINT_DIR is the ".../HuBMAP/checkpoints" directory
# the original absolute paths pointed into — confirm.
# map_location="cpu" lets checkpoints saved from GPU runs load on CPU-only hosts.
TRANS_RES_UNET_CKPT_DIR = os.path.join(CHECKPOINT_DIR, "TransResUNet")

ckpt1 = torch.load(
    os.path.join(TRANS_RES_UNET_CKPT_DIR, "wide_resnet50_2_x256.pt"),
    map_location="cpu",
)
ckpt2 = torch.load(
    os.path.join(TRANS_RES_UNET_CKPT_DIR, "resnext101_32x8d_x512.pt"),
    map_location="cpu",
)

In [None]:
# Loss histories stored in each checkpoint; the outer index is treated as the
# epoch by prepare_data below. Names carry the backbone/resolution of the model.
training_loss_history_wide_resnet50_2_x256, validation_loss_history_wide_resnet50_2_x256 = (
    ckpt1["training_loss_history"],
    ckpt1["validation_loss_history"],
)

training_loss_history_resnext101_32x8_x512, validation_loss_history_resnext101_32x8_x512 = (
    ckpt2["training_loss_history"],
    ckpt2["validation_loss_history"],
)

In [None]:
def prepare_data(data):
    """Flatten a nested history into a long-format DataFrame.

    Args:
        data: Iterable whose i-th element is an iterable of values recorded
            for epoch ``i`` (presumably one loss value per batch).

    Returns:
        DataFrame with columns ``"epoch"`` (0-based outer index) and
        ``"value"``, one row per inner value, in input order. Epochs with no
        values contribute no rows.
    """
    records = []
    for epoch, epoch_values in enumerate(data):
        for value in epoch_values:
            records.append((epoch, value))
    return pd.DataFrame(records, columns=["epoch", "value"])

In [None]:
# Long-format (epoch, value) frames for plotting.
# NOTE(review): these names are misleading — the "resnet_152_x256" frames hold
# the Wide ResNet-50-2 (x256) histories and the "wide_resnet_101_2_x512" frames
# hold the ResNeXt-101 32x8d (x512) histories. Names are kept because later
# cells reference them.
training_loss_history_resnet_152_x256_data, validation_loss_history_resnet_152_x256_data = (
    prepare_data(training_loss_history_wide_resnet50_2_x256),
    prepare_data(validation_loss_history_wide_resnet50_2_x256),
)

training_loss_history_wide_resnet_101_2_x512_data, validation_loss_history_wide_resnet_101_2_x512_data = (
    prepare_data(training_loss_history_resnext101_32x8_x512),
    prepare_data(validation_loss_history_resnext101_32x8_x512),
)

In [None]:
# Qualitative "Set2" colour palette.
# NOTE(review): the name is misspelled ("pallette"); it is kept as-is because
# other (currently commented-out) plotting code refers to it by this name.
pallette = sns.color_palette(palette="Set2")

In [None]:
fig, axs = plt.subplots(nrows=1, ncols=1)

# Compare the validation-loss curves of the two TransResU-Net variants.
# Each entry is (long-format epoch/value frame, legend label). The frame names
# are historical misnomers: the first holds the Wide ResNet-50-2 (x256) history,
# the second the ResNeXt-101 32x8d (x512) history.
# The \textbf{...} labels presumably rely on LaTeX text rendering from the
# "science" style — TODO confirm usetex is active if labels render literally.
validation_curves = [
    (
        validation_loss_history_resnet_152_x256_data,
        "\\textbf{TransResU-Net-4x256} (Wide ResNet-50-2)",
    ),
    (
        validation_loss_history_wide_resnet_101_2_x512_data,
        "\\textbf{TransResU-Net-4x512} (ResNeXt-101 32x8d)",
    ),
]

for curve_data, curve_label in validation_curves:
    # When an epoch has several values, seaborn draws their mean plus a
    # bootstrapped confidence band; a fixed seed keeps that band reproducible.
    sns.lineplot(
        curve_data,
        x="epoch",
        y="value",
        ax=axs,
        label=curve_label,
        seed=0,
    )

axs.set_xlabel("Epoch")
axs.set_ylabel("Loss");

In [None]:
# Write the comparison figure into the project's figures directory rather than
# wherever the kernel's current working directory happens to be.
fig.savefig(os.path.join(FIGURES_DIR, "trans_res_u-net_best_model_comparison.svg"))

In [None]:
from hubmap.dataset import TrainDataset
import hubmap.dataset.transforms as T
from hubmap.data import DATA_DIR
import numpy as np

In [None]:
# The training set loaded twice, with transforms resizing to 256x256 and
# 512x512, so class-pixel statistics can be compared across resolutions.
# NOTE(review): if T.Resize interpolates masks (rather than nearest-neighbour),
# mask values may be fractional and the "pixel counts" below become soft sums.
compose_256 = T.Compose([T.ToTensor(), T.Resize((256,) * 2)])
tset_256 = TrainDataset(DATA_DIR, transform=compose_256, with_background=True)

compose_512 = T.Compose([T.ToTensor(), T.Resize((512,) * 2)])
tset_512 = TrainDataset(DATA_DIR, transform=compose_512, with_background=True)

In [None]:
# Per-image sums of each of the first four mask channels at 256x256.
# Channel order assumed from the name prefixes: 0 = bv (blood vessel),
# 1 = gl (glomerulus), 2 = uns (unsure), 3 = bg (background) — confirm
# against TrainDataset.
_channel_sums_256 = ([], [], [], [])
for _, target in tset_256:
    for channel_index, channel_sums in enumerate(_channel_sums_256):
        channel_sums.append((target[channel_index, :, :]).sum())

bv_pixels_256, gl_pixels_256, uns_pixels_256, bg_pixels_256 = (
    np.array(sums) for sums in _channel_sums_256
)

In [None]:
# Number of pixels in one 256x256 mask channel.
total_pixels_per_mask_256 = 256 ** 2

In [None]:
# Per-image fraction of the 256x256 mask covered by each class.
(
    bv_per_image_ratio_256,
    gl_per_image_ratio_256,
    uns_per_image_ratio_256,
    bg_per_image_ratio_256,
) = (
    pixel_sums / total_pixels_per_mask_256
    for pixel_sums in (bv_pixels_256, gl_pixels_256, uns_pixels_256, bg_pixels_256)
)

In [None]:
# Dataset-wide mean coverage fraction per class at 256x256.
(
    bv_per_image_ratio_256_mean,
    gl_per_image_ratio_256_mean,
    uns_per_image_ratio_256_mean,
    bg_per_image_ratio_256_mean,
) = map(
    np.mean,
    (bv_per_image_ratio_256, gl_per_image_ratio_256, uns_per_image_ratio_256, bg_per_image_ratio_256),
)

In [None]:
# Report the mean per-class coverage at 256x256.
for label, value in (
    ("bv_per_image_ratio_256_mean: ", bv_per_image_ratio_256_mean),
    ("gl_per_image_ratio_256_mean: ", gl_per_image_ratio_256_mean),
    ("uns_per_image_ratio_256_mean: ", uns_per_image_ratio_256_mean),
    ("bg_per_image_ratio_256_mean: ", bg_per_image_ratio_256_mean),
):
    print(label, value)

In [None]:
# Sanity check: close to 1 if the four channels partition every mask pixel
# (assumption — overlapping annotations would push it above 1).
(
    (bv_per_image_ratio_256_mean + gl_per_image_ratio_256_mean)
    + uns_per_image_ratio_256_mean
) + bg_per_image_ratio_256_mean

In [None]:
# Inverse-frequency style weights: rarer classes (smaller mean coverage at
# 256x256) get weights closer to 1.
(
    bv_per_image_weight,
    gl_per_image_weight,
    uns_per_image_weight,
    bg_per_image_weight,
) = (
    1 - mean_ratio
    for mean_ratio in (
        bv_per_image_ratio_256_mean,
        gl_per_image_ratio_256_mean,
        uns_per_image_ratio_256_mean,
        bg_per_image_ratio_256_mean,
    )
)

In [None]:
# Report the un-normalised class weights.
for label, value in (
    ("bv_per_image_weight: ", bv_per_image_weight),
    ("gl_per_image_weight: ", gl_per_image_weight),
    ("uns_per_image_weight: ", uns_per_image_weight),
    ("bg_per_image_weight: ", bg_per_image_weight),
):
    print(label, value)

In [None]:
# Total of the raw weights (mathematically 4 minus the sum of mean coverages).
(
    (bv_per_image_weight + gl_per_image_weight)
    + uns_per_image_weight
) + bg_per_image_weight

In [None]:
# Scale factor that makes the four class weights sum to NUM_CLASSES,
# i.e. average to 1.
NUM_CLASSES = 4
normalizer = NUM_CLASSES / (
    bv_per_image_weight + gl_per_image_weight + uns_per_image_weight + bg_per_image_weight
)

In [None]:
# Class weights rescaled so that they average to 1.
(
    bv_per_image_weight_normed,
    gl_per_image_weight_normed,
    uns_per_image_weight_normed,
    bg_per_image_weight_normed,
) = (
    normalizer * raw_weight
    for raw_weight in (
        bv_per_image_weight,
        gl_per_image_weight,
        uns_per_image_weight,
        bg_per_image_weight,
    )
)


for label, value in (
    ("bv_per_image_weight_normed: ", bv_per_image_weight_normed),
    ("gl_per_image_weight_normed: ", gl_per_image_weight_normed),
    ("uns_per_image_weight_normed: ", uns_per_image_weight_normed),
    ("bg_per_image_weight_normed: ", bg_per_image_weight_normed),
):
    print(label, value)

In [None]:
# Per-image sums of each of the first four mask channels at 512x512.
# Channel order assumed from the name prefixes: 0 = bv (blood vessel),
# 1 = gl (glomerulus), 2 = uns (unsure), 3 = bg (background) — confirm
# against TrainDataset.
_channel_sums_512 = ([], [], [], [])
for _, target in tset_512:
    for channel_index, channel_sums in enumerate(_channel_sums_512):
        channel_sums.append((target[channel_index, :, :]).sum())

bv_pixels_512, gl_pixels_512, uns_pixels_512, bg_pixels_512 = (
    np.array(sums) for sums in _channel_sums_512
)

In [None]:
# Number of pixels in one 512x512 mask channel.
total_pixels_per_mask_512 = 512 ** 2

In [None]:
# Per-image fraction of the 512x512 mask covered by each class ...
(
    bv_per_image_ratio_512,
    gl_per_image_ratio_512,
    uns_per_image_ratio_512,
    bg_per_image_ratio_512,
) = (
    pixel_sums / total_pixels_per_mask_512
    for pixel_sums in (bv_pixels_512, gl_pixels_512, uns_pixels_512, bg_pixels_512)
)

# ... and the dataset-wide mean of each.
(
    bv_per_image_ratio_512_mean,
    gl_per_image_ratio_512_mean,
    uns_per_image_ratio_512_mean,
    bg_per_image_ratio_512_mean,
) = map(
    np.mean,
    (bv_per_image_ratio_512, gl_per_image_ratio_512, uns_per_image_ratio_512, bg_per_image_ratio_512),
)

In [None]:
# Report the mean per-class coverage at 512x512.
for label, value in (
    ("bv_per_image_ratio_512_mean: ", bv_per_image_ratio_512_mean),
    ("gl_per_image_ratio_512_mean: ", gl_per_image_ratio_512_mean),
    ("uns_per_image_ratio_512_mean: ", uns_per_image_ratio_512_mean),
    ("bg_per_image_ratio_512_mean: ", bg_per_image_ratio_512_mean),
):
    print(label, value)

In [None]:
# Sanity check: close to 1 if the four channels partition every mask pixel
# (assumption — overlapping annotations would push it above 1).
(
    (bv_per_image_ratio_512_mean + gl_per_image_ratio_512_mean)
    + uns_per_image_ratio_512_mean
) + bg_per_image_ratio_512_mean