## Example usage of `InMemoryDataset` and the PyTorch `DataLoader`

In [None]:
from cellgroup.data import InMemoryDataset
from cellgroup.configs import DataConfig
from cellgroup.data.datasets.harvard import SampleHarvard, ChannelHarvard, get_fnames
from cellgroup.data.utils import in_memory_collate_fn
from cellgroup.data.preprocessing import standardize

#### 1. Sequential patching

In [None]:
# Dataset config for sequential patching: `patch_overlap=None`, so patches are
# presumably tiled without overlap — TODO confirm against DataConfig.
dset_config = DataConfig(
    samples=[SampleHarvard.A06],
    channels=[ChannelHarvard.Ch1, ChannelHarvard.Ch13],
    time_steps=(32, 42, 2),  # presumably (start, stop, step) — confirm in DataConfig
    img_dim="2D",
    patch_size=(256, 256),
    patch_overlap=None,
)

In [None]:
# Build the in-memory dataset from files under `data_dir`, configured by
# `dset_config`. `get_fnames` is the Harvard-specific callable passed as
# `get_fnames_fn` (presumably used to list the files to load — confirm).
dset = InMemoryDataset(
    data_dir="/group/jug/federico/data/Cellgroup",
    data_config=dset_config,
    get_fnames_fn=get_fnames,
)

In [None]:
# Shape of the array held in `dset.data`.
dset.data.shape

In [None]:
# Shape of `dset.patches` (presumably the patches cut from `dset.data`
# according to `patch_size` / `patch_overlap`).
dset.patches.shape

In [None]:
# Number of items the dataset exposes, i.e. the valid range for `dset[i]`.
len(dset)

In [None]:
# Fetch a single item; each item unpacks into a patch and its coordinates.
# NOTE(review): the index is hard-coded and must be < len(dset) for this config.
patch, coords = dset[1123]

In [None]:
# Display the fetched patch's shape followed by its coordinates, one per line.
print(patch.shape, coords, sep="\n")

In [None]:
import matplotlib.pyplot as plt

# Preview the first patches of the dataset on a 10x10 grid.
# Only index patches that exist: `dset[i]` for i >= len(dset) would raise,
# so with fewer than 100 patches the remaining grid cells stay empty.
n_patches = len(dset)
_, axes = plt.subplots(10, 10, figsize=(30, 30))
for i, ax in enumerate(axes.flat):
    if i < n_patches:
        patch, _ = dset[i]
        # NOTE(review): imshow needs a 2D (or RGB/RGBA) array; this assumes each
        # patch is a single-channel 2D image — check the patch shape printed above.
        ax.imshow(patch, cmap="viridis")
    ax.axis("off")

#### 2. Overlapped patching

In [None]:
# Dataset config for overlapped patching, plus the settings the DataLoader
# cell in section 3 reads (`batch_size`, `dloader_kwargs`).
dset_config = DataConfig(
    samples=[SampleHarvard.A06],
    channels=[ChannelHarvard.Ch1, ChannelHarvard.Ch13],
    time_steps=(32, 36, 2),  # presumably (start, stop, step) — confirm in DataConfig
    img_dim="2D",
    patch_size=(256, 256),
    patch_overlap=(64, 64),
    batch_size=32,
    preprocessing_funcs=[standardize],
    # Read by the DataLoader construction in section 3.
    dloader_kwargs={
        "num_workers": 0,
        "collate_fn": in_memory_collate_fn,
    },
)

In [None]:
# Rebuild the dataset with the overlapped-patching config above.
# `get_fnames` is passed as `get_fnames_fn` (presumably resolves the files to load).
dset = InMemoryDataset(
    data_dir="/group/jug/federico/data/Cellgroup",
    data_config=dset_config,
    get_fnames_fn=get_fnames,
)

In [None]:
# Shape of the array held in `dset.data` for this config.
dset.data.shape

In [None]:
# Coordinate labels and dimension names of the data array.
dset.data.coords, dset.data.dims

In [None]:
# Shape of `dset.patches`; with overlap, expect more patches than in section 1
# for the same image size.
dset.patches.shape

In [None]:
# Coordinate labels and dimension names of the patch array.
dset.patches.coords, dset.patches.dims

In [None]:
# Disabled preview of the first 25 patches on a 5x5 grid; uncomment to run.
# import matplotlib.pyplot as plt

# _, axes = plt.subplots(5, 5, figsize=(20, 20))
# for i, ax in enumerate(axes.flat):
#     patch, _ = dset[i]
#     ax.imshow(patch, cmap="viridis")
#     ax.axis("off")

#### 3. Stitching overlapped patches

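Simulate stitching of segmented patches: iterate over the dataset with a `DataLoader`, stitch the collected patches back into images, and check that the result matches the original data.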

In [None]:
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

from cellgroup.data.patching import stitch_patches
from cellgroup.data.utils import reorder_images

In [None]:
# Iterate the dataset in its natural order (shuffle=False), using the batch
# size and loader options stored in the config.
# Missing keys fall back to DataLoader's defaults: passing `num_workers=None`
# (what a bare `.get("num_workers")` yields) makes DataLoader raise a TypeError.
dloader = DataLoader(
    dset,
    batch_size=dset_config.batch_size,
    shuffle=False,
    num_workers=(dset_config.dloader_kwargs or {}).get("num_workers", 0),
    collate_fn=(dset_config.dloader_kwargs or {}).get("collate_fn"),
)

In [None]:
# Drain the loader: keep each batch's patch array and flatten the second
# batch element (presumably one info record per patch) into a single list,
# then join all patch batches along axis 0.
patches, infos = [], []
for batch in tqdm(dloader):
    batch_patches, batch_infos = batch[0], batch[1]
    patches.append(batch_patches)
    infos.extend(batch_infos)
patches = np.concatenate(patches, axis=0)

In [None]:
# Stitch the collected patches back into images using their info records;
# returns the stitched images together with per-image info.
imgs, imgs_info = stitch_patches(patches, infos)

In [None]:
# Sanity check: number of info entries, number of stitched images, and the
# shape of the first stitched image.
len(imgs_info), len(imgs), imgs[0].shape

In [None]:
# Info associated with the first stitched image.
imgs_info[0]

In [None]:
# Combine the stitched images into a single labelled array, ordered according
# to their info, so it can be compared with `dset.data`.
img_arr = reorder_images(imgs, imgs_info)

In [None]:
# Inspect the reordered result: its type, shape, coordinates and dimension names.
type(img_arr), img_arr.shape, img_arr.coords, img_arr.dims

In [None]:
# Reference: shape and coordinates of the original data, to compare with `img_arr`.
dset.data.shape, dset.data.coords

In [None]:
# Round-trip check: the stitched array should reproduce the original data.
# Compare shapes first: np.allclose broadcasts its inputs, so arrays of
# different but broadcast-compatible shapes could otherwise report True
# (and incompatible shapes would raise instead of reporting False).
dset.data.shape == img_arr.shape and np.allclose(dset.data.values, img_arr.values)