In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch

import skimage as skm

from kedro.extras.datasets.pickle import PickleDataSet

In [None]:
import os, sys
# Put the absolute path of ``../src`` (resolved against the current working
# directory) on the import path before the ``tagseg`` imports that follow.
# NOTE(review): presumably ``tagseg`` lives in ``../src`` — this only works
# when the kernel is started from the notebook's own directory.
sys.path.append(os.path.abspath('../src'))

from tagseg.pipelines.data_processing.nodes import prepare_input
from tagseg.data.acdc_dataset import AcdcDataSet
from tagseg.data.utils import SimulateTags

In [None]:
# Load the contents of ``acdc_train.pt`` through kedro's PickleDataSet.
# NOTE(review): presumably this unpickles the file — unpickling can execute
# arbitrary code, so only point this at trusted data.
acdc_train_store = PickleDataSet(filepath='../data/03_primary/acdc_train.pt')
dataset = acdc_train_store.load()

In [None]:
# Split the dataset's ``tensors`` attribute into its two members: the first is
# bound to ``examples`` and the second to ``labels`` (unpacking fails if there
# are not exactly two).
examples, labels = dataset.tensors

In [None]:
# Build the tag simulator from the second element of the first dataset entry,
# which is handed over as its ``label``.
# NOTE(review): this one tagger is reused by ``tag`` for every image, not just
# sample 0 — confirm that sharing sample 0's label is intended.
reference_label = dataset[0][1]
tagger = SimulateTags(label=reference_label, myo_index=1)

In [None]:
def tag(image: torch.Tensor) -> torch.Tensor:
    """Min-max rescale ``image`` to the 0-255 range and pass it to ``tagger``.

    Args:
        image: Tensor to convert; the minimum and maximum are taken over all
            of its elements.

    Returns:
        Whatever the module-level ``tagger`` returns for the rescaled image.
    """
    lowest = image.min()
    intensity_range = image.max() - lowest
    # A constant image has a zero intensity range; dividing by it would turn
    # the whole input into NaNs. Use a unit divisor instead, which yields an
    # all-zero image for that case and leaves every other input unchanged.
    divisor = intensity_range if intensity_range > 0 else 1
    return tagger((image - lowest) / divisor * 255)

In [None]:
# Figure: cine frames (top row) next to their simulated tagged versions
# (bottom row), with the label contours overlaid and each column cropped to a
# square window around one of those contours.
sample_indices = [12, 171, 1076, 75]
padding = 80  # half-width of the crop window, in image pixels

fig, ax = plt.subplots(2, len(sample_indices), figsize=(8, 4))

for n, i in enumerate(sample_indices):

    # find_contours returns arrays of (row, column) vertices.
    contours = skm.measure.find_contours(labels[i].numpy())

    ax[0, n].imshow(examples[i, 0], cmap='gray')
    ax[1, n].imshow(tag(examples[i, 0]), cmap='gray')

    for row in range(2):
        ax[row, n].get_xaxis().set_ticks([])
        ax[row, n].get_yaxis().set_ticks([])

        # Reverse the columns so the vertices are plotted as (x, y).
        for contour in contours:
            ax[row, n].plot(*contour[:, ::-1].T, c='b')

    if contours:
        # Centre the crop on the second contour when there is one; fall back
        # to the only contour instead of raising IndexError.
        anchor = contours[1] if len(contours) > 1 else contours[0]
        # Vertices are (row, column): the row mean is the y centre and the
        # column mean is the x centre.
        cy, cx = anchor.mean(axis=0)

        for row in range(2):
            ax[row, n].set_xlim(cx - padding, cx + padding)
            # Upper bound first keeps the image origin at the top.
            ax[row, n].set_ylim(cy + padding, cy - padding)

# Row labels only need to be set once, on the left-most column.
ax[0, 0].set_ylabel('cine')
ax[1, 0].set_ylabel('tagged')

plt.tight_layout()
plt.savefig('../../figures/cine2tag-physics.pdf', dpi=300, bbox_inches='tight')
plt.show()