In [2]:
!pip install datasets

Defaulting to user installation because normal site-packages is not writeable
Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Using cached pyarrow-19.0.1-cp312-cp312-win_amd64.whl.metadata (3.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Using cached xxhash-3.5.0-cp312-cp312-win_amd64.whl.metadata (13 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading multiprocess-0.70.16-py312-none-any.whl (146 kB)
Using cached pyarrow-19.0.1-cp312-cp312-win_amd64.whl (25.3 MB)
Using cached xxhash-3.5.0-cp312-cp312-win_amd64.whl (30 kB)
Installing collected packages: xxhash, pyarrow, dill, multiprocess, datasets
Successfully installed datas

In [None]:
import torchvision.transforms.functional as TF
from datasets import load_dataset
from torch.utils.data import DataLoader

# image processing function
def process_img(x, size=(1024, 1024)):
    """Resize an image and convert it to a tensor.

    Args:
        x: the image to process. `TF.resize` accepts PIL images or tensors and
            `TF.to_tensor` accepts PIL images or numpy arrays, so in practice this
            must be a PIL image (presumably the decoded "basecolor" image
            produced by `datasets` — confirm).
        size: target size passed to `TF.resize`; defaults to the previously
            hard-coded (1024, 1024).

    Returns:
        The output of `TF.to_tensor` applied to the resized image.
    """
    resized = TF.resize(x, size)
    return TF.to_tensor(resized)

# batch processing function (used with `map(..., batched=True)`)
def process_batch(examples):
    """Apply `process_img` to every "basecolor" entry of a batch.

    `examples` maps column names to lists of values (the batched `map` layout).
    The "basecolor" list is replaced with the processed images, and the same
    dict object is returned.
    """
    images = examples["basecolor"]
    examples["basecolor"] = list(map(process_img, images))
    return examples

# load ONE split of the dataset in streaming mode. Without `split`, load_dataset
# returns an IterableDatasetDict, and iterating it yields split names, not samples.
SPLIT = "train"
ds = load_dataset(
    "gvecchio/MatSynth",
    split=SPLIT,
    streaming=True,
)

# keep only the columns used below (same result as removing all the others)
ds = ds.select_columns(["metadata", "basecolor"])


def keep_material(metadata):
    """Return True for CC0-licensed materials whose source is not "deschaintre_2020".

    Receives only the "metadata" value of an example (see `input_columns` below).
    A missing/None metadata value has no license, so the example is dropped.
    """
    metadata = metadata or {}
    # NOTE(review): an earlier comment described this as "Deschaintre et al. 2018",
    # but the tag filtered is "deschaintre_2020" — confirm which source is intended.
    return (
        metadata.get("license") == "CC0"
        and metadata.get("source") != "deschaintre_2020"
    )


# filter before shuffling so the shuffle buffer only holds materials we keep;
# passing only the metadata column to the predicate
ds = ds.filter(keep_material, input_columns="metadata")

# shuffle with a fixed seed so the sample order is reproducible
SEED = 42
ds = ds.shuffle(seed=SEED, buffer_size=100)

# resize/tensorize basecolor maps in batches of 8
ds = ds.map(process_batch, batched=True, batch_size=8)

# set format for usage in torch
ds = ds.with_format("torch")

# preview a few samples instead of streaming (and printing) the entire dataset
N_PREVIEW = 2
for sample in ds.take(N_PREVIEW):
    meta = sample["metadata"]
    print(meta.get("license"), meta.get("source"), tuple(sample["basecolor"].shape))
