In [None]:
from datamodule import FabricDataModule
import albumentations as A
from albumentations.pytorch import ToTensorV2
from copy import deepcopy
from tqdm import tqdm

# Root directory of the olp-dataset, given as a relative path; it is passed to
# `FabricDataModule` as `root` further down. Adjust it to your local checkout.
root = "../olp-dataset/"

# FabricDataModule

To build overarching datasets that span more than one `FabricDataset`, `FabricDataModule` inherits from the `LightningDataModule` class and holds a `ConcatDataset` that contains a single `FabricDataset` per selected fabric.

In [None]:
# Transform pipelines handed to the individual FabricDatasets.
# Every composed pipeline uses the same bounding-box settings: boxes are in COCO
# format and their labels travel in the "instance_labels" field.

# Passed to the data module as `transform`: resizes each sample to 896x896.
load_transform = A.Compose(
    [A.Resize(896, 896)],
    bbox_params=A.BboxParams(format="coco", label_fields=["instance_labels"]),
)

# Shared tail of both final pipelines: conversion to torch tensors.
final_transform = [ToTensorV2()]

# Evaluation pipeline: only the shared tail. The tail is deep-copied so that the
# eval and train pipelines do not share transform instances.
final_eval_transform = A.Compose(
    [*deepcopy(final_transform)],
    bbox_params=A.BboxParams(format="coco", label_fields=["instance_labels"]),
)

# Training pipeline: a 224x224 crop via CropNonEmptyMaskIfExists, followed by a
# deep copy of the shared tail.
final_train_transform = A.Compose(
    [A.CropNonEmptyMaskIfExists(224, 224), *deepcopy(final_transform)],
    bbox_params=A.BboxParams(format="coco", label_fields=["instance_labels"]),
)

The `FabricDataModule` accepts the same arguments as `FabricDataset` (they are simply forwarded).
Additionally, we can select a range of fabrics via `textiles`, choose whether to `invert` the selection, and configure all other data-composition/sampling options, such as the `collate_fn`, oversampling of the anomalies, etc.

**NOTE: If you want to use the `FabricDataModule` for object detection/instance segmentation, you need to write a new `collate_fn`; refer to https://github.com/pytorch/vision/issues/2624#issuecomment-681811444. The `FabricDataset` does, however, provide all necessary information on a per-sample basis.**

In [None]:
# Fabric IDs to include: 1 through 38 (the upper bound of `range` is exclusive).
selected_textiles = list(range(1, 39))

# Assemble the data module from the dataset root and the transform pipelines
# defined earlier in this notebook.
datamodule = FabricDataModule(
    root=root,
    cache=False,
    textiles=selected_textiles,
    transform=load_transform,
    uncached_train_transform=final_train_transform,
    uncached_eval_transform=final_eval_transform,
    load_mode="both",
)

# Run the module's setup step before requesting any dataloaders.
datamodule.setup()

The loaders can then be used as you would use any `DataLoader` from pytorch, i.e.

In [None]:
# Each loader is requested right before it is iterated; tqdm only adds a
# progress bar around the iteration.

train_loader = datamodule.train_dataloader()
for batch in tqdm(train_loader):
    pass  # do the training loop

val_loader = datamodule.val_dataloader()
for batch in tqdm(val_loader):
    pass  # do the validation loop

test_loader = datamodule.test_dataloader()
for batch in tqdm(test_loader):
    pass  # do the testing loop

Alternatively, they can of course be used with pytorch-lightning, which was used by us in the original manuscript.