In [3]:
from omegaconf import OmegaConf
from databases import load_data
import numpy as np

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import pickle
import gzip
from pathlib import Path

# Load the experiment configuration and point it at the local dataset root.
config = OmegaConf.load("./configs/default_config_linearprobe50.yaml")
config.datasets_path = './data'  # where you store datasets locally

# Destination folder for the exported splits. It is created up front so that
# gzip.open() below does not fail with FileNotFoundError on a fresh machine.
output_dir = Path("/home/mheuillet/Desktop/processed_dataset")
output_dir.mkdir(parents=True, exist_ok=True)

# NOTE: 'dtd' and 'eurosat' are currently not part of this export list.
for dataset in [
    'uc-merced-land-use-dataset',
    'flowers-102',
    'caltech101',
    'stanford_cars',
    'fgvc-aircraft-2013b',
    'oxford-iiit-pet',
]:

    print(dataset)

    config.dataset = dataset
    # Splits loaded with common_corruption=False.
    train_dataset, val_dataset, test_dataset, N = load_data(config, common_corruption=False)
    # Only the test split of the common_corruption=True variant is kept.
    # N is rebound by this second call, so the saved "N" is the value it returns
    # (presumably identical to the first one -- TODO confirm against load_data).
    _, _, common_test_dataset, N = load_data(config, common_corruption=True)

    # One gzip-compressed pickle per dataset, holding every split plus N.
    output_path = output_dir / "{}_processed.pkl.gz".format(config.dataset)
    with gzip.open(output_path, 'wb') as f:
        pickle.dump({
            "train": train_dataset,
            "val": val_dataset,
            "test": test_dataset,
            "test_common": common_test_dataset,
            "N": N
        }, f)

    print(f"✅ Saved to: {output_path}")

uc-merced-land-use-dataset
✅ Saved to: /home/mheuillet/Desktop/processed_dataset/uc-merced-land-use-dataset_processed.pkl.gz
flowers-102
✅ Saved to: /home/mheuillet/Desktop/processed_dataset/flowers-102_processed.pkl.gz
caltech101
✅ Saved to: /home/mheuillet/Desktop/processed_dataset/caltech101_processed.pkl.gz
stanford_cars
✅ Saved to: /home/mheuillet/Desktop/processed_dataset/stanford_cars_processed.pkl.gz
fgvc-aircraft-2013b
✅ Saved to: /home/mheuillet/Desktop/processed_dataset/fgvc-aircraft-2013b_processed.pkl.gz
oxford-iiit-pet
✅ Saved to: /home/mheuillet/Desktop/processed_dataset/oxford-iiit-pet_processed.pkl.gz


In [1]:
import os
import tarfile
import zstandard as zstd
from pathlib import Path
from torchvision.datasets import ImageFolder
from torchvision import transforms
from tempfile import TemporaryDirectory

def load_processed_dataset(archive_path, transform=None):
    """
    Extract a .tar.zst dataset archive into a temporary directory and wrap each
    split folder found there in an ImageFolder dataset.

    Args:
        archive_path (str or Path): Path to the .tar.zst archive.
        transform (callable, optional): Transform passed to every ImageFolder.
            When None, images are resized to 224x224 and converted to tensors.

    Returns:
        tuple: (datasets, temp_dir) where
            - datasets is a dict {split_name: ImageFolder} containing an entry
              for each of "train", "val", "test", "test_common" whose folder
              exists in the extracted archive;
            - temp_dir is the TemporaryDirectory holding the extracted files.
              The caller must keep it alive while the datasets are in use and
              call temp_dir.cleanup() when done.

    Raises:
        Any exception raised while opening, decompressing, extracting or
        indexing the archive is propagated unchanged; the temporary directory
        is removed before the exception leaves this function.
    """
    archive_path = Path(archive_path)

    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    temp_dir = TemporaryDirectory()
    try:
        temp_path = Path(temp_dir.name)

        # Decompress .tar.zst as a stream (mode 'r|'), without seeking.
        with open(archive_path, 'rb') as f_in:
            dctx = zstd.ZstdDecompressor()
            with dctx.stream_reader(f_in) as zst_stream:
                with tarfile.open(fileobj=zst_stream, mode='r|') as tar:
                    if hasattr(tarfile, "data_filter"):
                        # Refuse members that would escape temp_path (absolute
                        # paths, "..", links pointing outside) when the running
                        # Python provides extraction filters.
                        tar.extractall(path=temp_path, filter="data")
                    else:
                        # Older Python without extraction filters: only use
                        # this on archives you trust.
                        tar.extractall(path=temp_path)

        # Load datasets for each split that is present in the archive.
        datasets = {}
        for split in ["train", "val", "test", "test_common"]:
            split_path = temp_path / split
            if split_path.exists():
                datasets[split] = ImageFolder(split_path, transform=transform)
    except BaseException:
        # Do not leave a half-extracted directory behind on failure.
        temp_dir.cleanup()
        raise

    return datasets, temp_dir  # temp_dir must be kept alive while using datasets


# Extract the archive; `datasets` maps split names to ImageFolder datasets and
# `tmp_dir` owns the extracted files on disk.
# NOTE: this rebinds the name `datasets` at notebook level.
archive_path = "/home/mheuillet/Desktop/uc-merced-land-use-dataset_processed.tar.zst"
datasets, tmp_dir = load_processed_dataset(archive_path)

# Example: load a batch
from torch.utils.data import DataLoader

try:
    loader = DataLoader(datasets["train"], batch_size=16, shuffle=True)

    # Show the shape and labels of the first batch only.
    for imgs, labels in loader:
        print(imgs.shape, labels)
        break
finally:
    # Remove the extracted files even if building the loader or reading the
    # batch failed; the datasets are unusable after this point.
    tmp_dir.cleanup()


ModuleNotFoundError: No module named 'torchvision'

In [None]:
### don't pay attention to this cell

# Run bookkeeping per dataset:
# EuroSAT: only 8 runs completed, all the others ran out of memory (OOM)
# Caltech: 26 runs ok,
# Stanford Cars: 34 ok
# Aircraft: 26 ok
# DTD : 26 ok
# Oxford pet: 26 ok
# Flowers: 26 ok

# Total number of samples per dataset, obtained by summing three split sizes
# (presumably train / val / test -- TODO confirm).
stanford_cars = sum((6922, 1222, 8041))
caltech101 = sum((5899, 1042, 1736))
dtd = sum((1880, 1880, 1880))
eurosat = sum((18360, 3240, 5400))
fgvcaircraft2013b = sum((3334, 3333, 3333))
flowers102 = sum((1020, 1020, 6149))
oxfordiiitpet = sum((3128, 552, 3669))

print(eurosat, caltech101, dtd, fgvcaircraft2013b, flowers102, oxfordiiitpet, stanford_cars)

from PIL import Image
import torchvision.datasets as datasets

# Open the local Caltech101 copy without any transform, so samples are PIL images.
caltech101_sample = datasets.Caltech101(root='./data', download=False)
n_samples = len(caltech101_sample)
print(n_samples)

# Walk through every sample and report each image whose mode is not 'RGB'.
for idx in range(n_samples):
    image, label = caltech101_sample[idx]
    if image.mode == 'RGB':
        continue
    print('issue')