In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch.utils.data import TensorDataset
from tqdm import tqdm

In [None]:
import os
import sys

# Make the project's `src/` directory (relative to the kernel's working
# directory) importable. Guarded so re-running this cell does not append a
# duplicate entry to sys.path every time.
_src_dir = os.path.abspath('../src')
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

from tagseg.data.utils import load_nii
from tagseg.data.mnm_dataset import MnmDataSet

### Dataset information and diagnoses from the CSV

In [None]:
# Folder holding the labeled M&Ms training data (relative to the notebook's working directory).
raw_path = Path('../data/01_raw/OpenDataset/Training/Labeled')

In [None]:
# Dataset information / diagnosis table shipped with the M&Ms open dataset (per the
# filename); the first CSV column is used as the DataFrame index.
df = pd.read_csv('../data/01_raw/OpenDataset/211230_M&Ms_Dataset_information_diagnosis_opendataset.csv', index_col=0)

In [None]:
# Preview the first rows of the information table.
df.head()

### Loading the dataset with the custom `MnmDataSet` loader

In [None]:
# Custom M&Ms dataset wrapper.
# - `filepath`: the .pt file this dataset object is tied to (presumably where
#   `save()` writes the processed tensors — confirm in MnmDataSet).
# - `filepath_raw`: the raw labeled-data folder; reuses `raw_path` so the raw
#   location is defined in a single place.
# - `only_myo=True`: presumably keeps only the myocardium label — TODO confirm.
mnm = MnmDataSet(
    filepath='../data/03_primary/mnm_train.pt',
    load_args=dict(
        filepath_raw=str(raw_path),
        only_myo=True,
    ),
)

In [None]:
# Delete the cached dataset file, presumably so `mnm.load()` rebuilds it from the
# raw data. Pure pathlib instead of `!rm` so it works cross-platform and does not
# emit an error when the file does not exist yet (e.g. on a fresh checkout).
cache_path = Path('../data/03_primary/mnm_train.pt')
try:
    cache_path.unlink()
except FileNotFoundError:
    pass  # nothing cached yet — deliberate no-op

In [None]:
# Load the dataset through the custom loader and persist it via `mnm.save`.
# NOTE(review): with the cached .pt deleted beforehand, `load()` presumably
# rebuilds from `filepath_raw` — confirm in MnmDataSet.
dataset = mnm.load()
mnm.save(dataset)

In [None]:
# Number of samples in the loaded dataset.
len(dataset)

In [None]:
# Single-sample view: image, label map, and label overlay side by side.
# Built from `dataset` (loaded above) so the cell runs on a fresh kernel.
# NOTE(review): viewing a specific slice/phase of a raw 4D volume would require
# loading that volume explicitly (e.g. via `load_nii`) — TODO if still needed.
sample_idx = 0

sample = dataset[sample_idx]
image = sample[0][0].cpu().numpy()  # first channel of the input, as in the grid below
label = sample[1].cpu().numpy()

fig, ax = plt.subplots(1, 3, figsize=(20, 7))

ax[0].imshow(image, cmap='gray')
ax[0].set_title('image')

ax[1].imshow(label, cmap='viridis')
ax[1].set_title('label')

# Mask out zero-valued pixels so only labelled pixels are tinted on top of the image.
masked = np.ma.masked_where(label == 0, label)
ax[2].imshow(image, cmap='gray')
ax[2].imshow(masked, cmap='jet', interpolation='nearest', alpha=0.3)
ax[2].set_title('overlay')

plt.show()

In [None]:
# Grid of the first M * N samples with their masks overlaid in red.
# Panels are filled column by column: sample i goes to row i % M, column i // M.
# The loop is capped at len(dataset), so a dataset smaller than the grid no longer
# raises IndexError; leftover panels stay blank.
M, N = 20, 5
# squeeze=False keeps `ax` 2-D even when M or N is 1; 4x5 inches per panel
# reproduces the original (20, 100) figure for M=20, N=5.
fig, ax = plt.subplots(M, N, figsize=(4 * N, 5 * M), squeeze=False)

n_shown = min(M * N, len(dataset))
for i in range(M * N):
    m, n = i % M, i // M
    ax[m, n].axis('off')
    if i >= n_shown:
        continue  # fewer samples than panels: leave this one empty

    sample = dataset[i]  # fetch once instead of indexing the dataset twice
    image = sample[0][0].cpu().numpy()
    mask = sample[1].cpu().numpy()

    ax[m, n].imshow(image, cmap='gray')
    # Hide zero-valued pixels so only the labelled region is drawn over the image.
    ax[m, n].imshow(np.ma.masked_where(mask == 0, mask), cmap='Reds', alpha=0.8)

plt.show()