# Data splitting

## Usage

In [None]:
# Disabled by default: uncomment the lines below to run the dataset split.
# NOTE(review): presumably this writes temp/splits/{train,val,test}.txt, the
# files read by the validation cell below -- confirm against split_class_dataset.
# from src.util.data import split_class_dataset

# split_class_dataset("datasets/tea-std/images", "temp/splits")

## Validation

In [None]:
# Disabled sanity check for the generated split files (uncomment to run).
# It counts the files matching datasets/tea-std/images/*/*, then for each of
# temp/splits/{train,val,test}.txt prints the split name, the number of
# distinct classes (the parent directory name of each listed path) and the
# number of listed paths, and finally asserts that the on-disk file count does
# not exceed the summed split sizes.
# NOTE(review): `<=` also passes when a path is duplicated across splits; use
# `==` if the splits are meant to partition the image set exactly -- confirm.
# import glob

# real_tot = len(glob.glob("datasets/tea-std/images/*/*"))
# tot = 0
# for file in ["train", "val", "test"]:
#     with open(f"temp/splits/{file}.txt") as handler:
#         paths = handler.read().strip().split("\n")
#         classes = list(set([p.split("/")[-2] for p in paths]))
#         print(file, len(classes), len(paths))
#         tot += len(paths)
# assert real_tot <= tot, (real_tot, tot)

# ClassDataset

## Usage

In [None]:
from src.datasets import ClassDataset
import matplotlib.pyplot as plt
import random

# Load the dataset and preview two randomly chosen samples side by side.
ds = ClassDataset("datasets/tea-grade-v2")

# random.randint includes its upper bound, so randint(0, len(ds)) can yield
# len(ds) -- one past the last valid index. randrange excludes the bound.
id1 = random.randrange(len(ds))
id2 = random.randrange(len(ds))
sample1 = ds[id1]
sample2 = ds[id2]
lbl1, img1 = sample1["lbl"], sample1["img"]
lbl2, img2 = sample2["lbl"], sample2["img"]
print(f"Dataset length: {len(ds)}")
print(f"Sample shape: {img1.shape}")
print(f"Tea grade: {ds.lbl_cls_map[lbl1]}")
print(f"Dataset class-label map: {ds.cls_lbl_map}")
print(f"Dataset label-class map: {ds.lbl_cls_map}")

# Move axis 0 to the end before plotting (presumably CHW -> HWC, the layout
# imshow expects for colour images).
img1, img2 = img1.numpy(), img2.numpy()
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(img1.transpose(1, 2, 0))
ax[0].set_title(ds.lbl_cls_map[lbl1])
ax[1].imshow(img2.transpose(1, 2, 0))
ax[1].set_title(ds.lbl_cls_map[lbl2])
plt.show()

## Validate

In [None]:
import torch
from src.datasets import ClassDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.constants import img_wh

# Walk every split of every dataset and check the tensor contract of each
# batch: image shape and dtype, label shape, and image value range.
batch_size = 16
ds_names = [
    "datasets/uc-mlr-leaf",
    "datasets/new-plant-disease",
    "datasets/plant-doc",
    "datasets/tea-std",
    "datasets/tea-grade-v2",
]
# img_wh is reversed to build the trailing (3, ..., ...) shape each image
# batch is compared against; it is the same for every dataset and split.
expected_img_shape = torch.Size([3, *img_wh[::-1]])
scalar_shape = torch.Size([])

for root in ds_names:
    for split in ("val", "test", "train"):
        ds = ClassDataset(root, split, img_wh)
        dl = DataLoader(ds, batch_size=batch_size)
        for batch in tqdm(dl, desc=f"{root}: {split}"):
            imgs = batch["img"]
            lbls = batch["lbl"]
            assert imgs.shape[1:] == expected_img_shape, imgs.shape
            assert imgs.dtype == torch.float32
            # Labels carry no dimensions beyond the batch axis.
            assert lbls.shape[1:] == scalar_shape, lbls.shape
            # Pixel values must lie in [0, 1].
            assert imgs.min() >= 0
            assert imgs.max() <= 1

## ClassDataset with augmentation collate_fn

### Usage

In [None]:
import numpy as np
from src.datasets import ClassDataset
from src.datasets.collate_fns import aug_collate_fn, Augmentor
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from functools import partial

# Build a loader whose collate function augments every sample aug_count times,
# then preview a grid of random augmentations taken from the first batch.
ds = ClassDataset("datasets/tea-grade-v2")
aug = Augmentor([224, 224])
batch_size = 16
aug_count = 8
dl = DataLoader(
    ds,
    batch_size,
    shuffle=True,
    collate_fn=partial(aug_collate_fn, aug=aug, aug_count=aug_count),
)
for batch in dl:
    lbls, imgs = batch["lbl"], batch["img"]
    print(f"DataLoader length: {len(dl)}")
    print(f"Batch images shape: {imgs.shape}")
    print(f"Batch labels shape: {lbls.shape}")

    # Move axis 2 to the end so each imgs[i, j] can be passed to imshow
    # (presumably (batch, aug, C, H, W) -> (batch, aug, H, W, C)).
    imgs = imgs.numpy().transpose(0, 1, 3, 4, 2)
    row_count = 2
    col_count = 4
    fig, ax = plt.subplots(row_count, col_count, figsize=(col_count * 3, row_count * 3))
    # Sample from the real batch length: a dataset smaller than batch_size
    # yields a short batch, and indexing up to batch_size would overflow it.
    img_ids = np.random.randint(0, len(imgs), row_count)
    for row, img_id in enumerate(img_ids):
        aug_ids = np.random.randint(0, aug_count, col_count)
        for col, aug_id in enumerate(aug_ids):
            ax[row][col].imshow(imgs[img_id, aug_id])
            ax[row][col].set_title(f"{ds.lbl_cls_map[int(lbls[img_id][0])]}, img_id: {img_id}, aug_id: {aug_id}")
    plt.show()
    # Only the first batch is needed for the preview.
    break

### Validate

In [None]:
import torch
from src.datasets import ClassDataset
from src.datasets.collate_fns import aug_collate_fn, Augmentor
from torch.utils.data import DataLoader
from functools import partial
from tqdm import tqdm
from src.constants import img_wh

# Check the augmented-batch contract (dtypes, value range, shapes) over every
# split of the tea datasets.
# NOTE(review): the augmentor is built with a hard-coded [224, 224] while the
# image shape check below uses img_wh -- confirm the two agree.
aug = Augmentor([224,224])
batch_size = 16
aug_count = 8
ds_names = ["tea-std", "tea-grade-v2"]

# Loop invariants: the collate callable and the expected per-batch shapes.
collate = partial(aug_collate_fn, aug=aug, aug_count=aug_count)
expected_lbl_shape = torch.Size([aug_count])
expected_img_shape = torch.Size([aug_count, 3, *img_wh[::-1]])

for ds_name in ds_names:
    for split in ("train", "test", "val"):
        ds = ClassDataset(f"datasets/{ds_name}", split, dataset=ds_name)
        dl = DataLoader(ds, batch_size, collate_fn=collate, num_workers=4)
        for batch in tqdm(dl, desc=f"{ds_name}: {split}"):
            lbls = batch["lbl"]
            imgs = batch["img"]
            assert lbls.dtype == torch.int64
            assert imgs.dtype == torch.float32
            # Pixel values must lie in [0, 1].
            assert imgs.min() >= 0
            assert imgs.max() <= 1
            # One label and one image per augmentation of each sample.
            assert lbls.shape[1:] == expected_lbl_shape
            assert imgs.shape[1:] == expected_img_shape

# ConcatSet

## Usage

In [None]:
from src.datasets import ConcatSet
import random
import matplotlib.pyplot as plt


# Combine several class datasets into one ConcatSet and preview two random
# samples from its "train" split.
root = [
    "datasets/tea-grade-v2",
    "datasets/tea-std",
    "datasets/plant-doc",
]
# Every root uses the same configuration; the comprehension builds one
# independent dict per root instead of repeating the literal three times.
conf = [
    {
        "target": "src.datasets.ClassDataset",
        "reps": 1,
        "split_mix": {"train": ["train"], "val": ["val", "test"]},
        "params": {"resize_wh": [300, 300]},
    }
    for _ in root
]

ds = ConcatSet(root, "train", conf)

# random.randint includes its upper bound, so randint(0, len(ds)) can yield
# len(ds) -- one past the last valid index. randrange excludes the bound.
id1 = random.randrange(len(ds))
id2 = random.randrange(len(ds))
sample1 = ds[id1]
sample2 = ds[id2]
lbl1, img1 = sample1["lbl"], sample1["img"]
lbl2, img2 = sample2["lbl"], sample2["img"]
print(f"Dataset length: {len(ds)}")
print(f"Sample shape: {img1.shape}")

# Move axis 0 to the end before plotting (presumably CHW -> HWC, the layout
# imshow expects for colour images).
img1, img2 = img1.numpy(), img2.numpy()
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(img1.transpose(1, 2, 0))
ax[0].set_title(lbl1)
ax[1].imshow(img2.transpose(1, 2, 0))
ax[1].set_title(lbl2)
plt.show()

## Validation

In [None]:
import torch
from src.datasets import ConcatSet
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.constants import img_wh


# Check the tensor contract of every batch produced by a ConcatSet built from
# three class datasets that share one configuration.
batch_size = 16
resize_shape = [300, 300]
root = [
    "datasets/tea-grade-v2",
    "datasets/tea-std",
    "datasets/plant-doc",
]
# One independent config dict per root; all of them reference resize_shape.
conf = [
    {
        "target": "src.datasets.ClassDataset",
        "reps": 1,
        "split_mix": {"train": ["train"], "val": ["val", "test"]},
        "params": {"resize_wh": resize_shape},
    }
    for _ in root
]

# resize_shape is reversed to build the trailing (3, ..., ...) shape each
# image batch is compared against.
expected_img_shape = torch.Size([3, *resize_shape[::-1]])
scalar_shape = torch.Size([])

# NOTE(review): "test" is iterated here but is not a key of split_mix above
# (only "train" and "val" are) -- confirm ConcatSet accepts it.
for split in ("val", "test", "train"):
    ds = ConcatSet(root, split, conf)
    dl = DataLoader(ds, batch_size=batch_size)
    for batch in tqdm(dl, desc=f"{split}"):
        imgs = batch["img"]
        lbls = batch["lbl"]
        assert imgs.shape[1:] == expected_img_shape, imgs.shape
        assert imgs.dtype == torch.float32
        # Labels carry no dimensions beyond the batch axis.
        assert lbls.shape[1:] == scalar_shape, lbls.shape
        # Pixel values must lie in [0, 1].
        assert imgs.min() >= 0
        assert imgs.max() <= 1