In [2]:
import sys
sys.path.append('../../../Fedot.Industrial')

import os
import numpy as np
from torch.utils.data import Subset
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
from core.architecture.datasets.splitters import get_dataset_mean_std, split_data, undersampling, dataset_info

# Root folder holding all datasets. Set the DATASETS_ROOT environment variable to run
# the notebook on another machine; the default is the original author's location.
DATASETS_ROOT = os.environ.get('DATASETS_ROOT', '/media/n31v/data/datasets')

## MNIST
Per-channel statistics of the full MNIST training set (computed before undersampling): mean=(0.13066,), std=(0.3081,)

In [13]:
from torchvision.datasets import MNIST

# Full MNIST dataset. download=True (as in the FashionMNIST cell) lets the cell run on
# a fresh machine; torchvision skips the download when the files are already present.
mnist_ds = MNIST(
    root=DATASETS_ROOT,
    transform=ToTensor(),
    download=True,
)
# Per-channel mean/std, later used for Normalize.
mean, std = get_dataset_mean_std(mnist_ds)
# verbose=True so the per-class sample counts are actually printed
# (same call style as check_dataset below).
dataset_info(mnist_ds, verbose=True)
print(f'{mean=}, {std=}')

100%|██████████| 60000/60000 [00:06<00:00, 8974.01it/s]
100%|██████████| 60000/60000 [00:04<00:00, 14236.07it/s]

Class 0 contains 5923 samples.
Class 1 contains 6742 samples.
Class 2 contains 5958 samples.
Class 3 contains 6131 samples.
Class 4 contains 5842 samples.
Class 5 contains 5421 samples.
Class 6 contains 5918 samples.
Class 7 contains 6265 samples.
Class 8 contains 5851 samples.
Class 9 contains 5949 samples.
mean = (0.13066047797803165,), std = (0.3081078048756658,)





In [None]:
# Balance MNIST to 5000 images per class. Using a new name keeps the full `mnist_ds`
# intact, so re-running this cell does not undersample an already-undersampled subset.
mnist_balanced = undersampling(mnist_ds, n=5000)

folds = []
for _ in range(5):  # 5 independent 2-way splits
    f1, f2 = split_data(mnist_balanced, 2)
    # split_data yields positions within the balanced subset; `.indices` maps them back
    # to indices of the full MNIST dataset, which the check cell below relies on.
    folds.append(np.array([mnist_balanced.indices[f1], mnist_balanced.indices[f2]]))
folds = np.array(folds)
np.save(os.path.join(DATASETS_ROOT, 'MNIST', 'folds'), folds)
print(folds.shape)

In [None]:
# Reload the saved folds and print the per-class counts of every half so the class
# balance can be inspected.
folds = np.load(os.path.join(DATASETS_ROOT, 'MNIST', 'folds.npy'))

mnist_ds = MNIST(
    root=DATASETS_ROOT,
    transform=ToTensor(),
)

# Iterate over whatever was saved instead of assuming exactly 5 repetitions;
# verbose=True so dataset_info actually prints the counts.
for fold_pair in folds:
    for fold_indices in fold_pair:
        dataset_info(Subset(dataset=mnist_ds, indices=fold_indices), verbose=True)

## FashionMNIST
Per-channel statistics of the full FashionMNIST training set (computed before undersampling): mean=(0.286,), std=(0.353,)

In [None]:
from torchvision.datasets import FashionMNIST

# Full FashionMNIST dataset (downloaded if missing).
fmnist_ds = FashionMNIST(
    root=DATASETS_ROOT,
    transform=ToTensor(),
    download=True,
)
# Per-channel mean/std, later used for Normalize.
mean, std = get_dataset_mean_std(fmnist_ds)
# verbose=True so the per-class sample counts are actually printed
# (same call style as check_dataset below).
dataset_info(fmnist_ds, verbose=True)
print(f'{mean=}, {std=}')

In [None]:
# Balance FashionMNIST to 5000 images per class. Using a new name keeps the full
# `fmnist_ds` intact, so re-running this cell does not undersample twice.
fmnist_balanced = undersampling(fmnist_ds, n=5000)

folds = []
for _ in range(5):  # 5 independent 2-way splits
    f1, f2 = split_data(fmnist_balanced, 2)
    # split_data yields positions within the balanced subset; `.indices` maps them back
    # to indices of the full FashionMNIST dataset, which the check cell below relies on.
    folds.append(np.array([fmnist_balanced.indices[f1], fmnist_balanced.indices[f2]]))
folds = np.array(folds)
np.save(os.path.join(DATASETS_ROOT, 'FashionMNIST', 'folds'), folds)
print(folds.shape)

In [None]:
# Reload the saved folds and print the per-class counts of every half so the class
# balance can be inspected.
folds = np.load(os.path.join(DATASETS_ROOT, 'FashionMNIST', 'folds.npy'))

# Named fmnist_ds (not mnist_ds) so the MNIST dataset variable is not clobbered.
fmnist_ds = FashionMNIST(
    root=DATASETS_ROOT,
    transform=ToTensor(),
)

# Iterate over whatever was saved instead of assuming exactly 5 repetitions;
# verbose=True so dataset_info actually prints the counts.
for fold_pair in folds:
    for fold_indices in fold_pair:
        dataset_info(Subset(dataset=fmnist_ds, indices=fold_indices), verbose=True)

## CIFAR10
Per-channel statistics of the full CIFAR-10 training set: mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)

In [16]:
from torchvision.datasets import CIFAR10

# Full CIFAR-10 dataset. download=True (as in the FashionMNIST cell) lets the cell run
# on a fresh machine; torchvision skips the download when the files are already present.
cifar10_ds = CIFAR10(
    root=os.path.join(DATASETS_ROOT, 'CIFAR10'),
    transform=ToTensor(),
    download=True,
)
# Per-channel mean/std, later used for Normalize.
mean, std = get_dataset_mean_std(cifar10_ds)
# verbose=True so the per-class sample counts are actually printed
# (same call style as check_dataset below).
dataset_info(cifar10_ds, verbose=True)
print(f'{mean=}, {std=}')

100%|██████████| 50000/50000 [00:06<00:00, 7721.50it/s]
100%|██████████| 50000/50000 [00:04<00:00, 11767.78it/s]

Class 0 contains 5000 samples.
Class 1 contains 5000 samples.
Class 2 contains 5000 samples.
Class 3 contains 5000 samples.
Class 4 contains 5000 samples.
Class 5 contains 5000 samples.
Class 6 contains 5000 samples.
Class 7 contains 5000 samples.
Class 8 contains 5000 samples.
Class 9 contains 5000 samples.
mean = (0.4913996927399561, 0.4821584222899936, 0.4465309280202538), std = (0.24703223297351337, 0.2434851287896555, 0.26158784042441807)





In [None]:
# CIFAR-10 is already balanced (5000 images per class), so there is no undersampling
# step: the positions returned by split_data index the full dataset directly.
# Build 5 independent 2-way splits.
folds = np.array([np.array(split_data(cifar10_ds, 2)) for _ in range(5)])
np.save(os.path.join(DATASETS_ROOT, 'CIFAR10', 'folds'), folds)
print(folds.shape)

In [None]:
# Reload the saved folds and print the per-class counts of every half so the class
# balance can be inspected.
folds = np.load(os.path.join(DATASETS_ROOT, 'CIFAR10', 'folds.npy'))

cifar10_ds = CIFAR10(
    root=os.path.join(DATASETS_ROOT, 'CIFAR10'),
    transform=ToTensor(),
)

# Iterate over whatever was saved instead of assuming exactly 5 repetitions;
# verbose=True so dataset_info actually prints the counts.
for fold_pair in folds:
    for fold_indices in fold_pair:
        dataset_info(Subset(dataset=cifar10_ds, indices=fold_indices), verbose=True)

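## ImageFolder datasets
Helper functions that compute statistics, balance classes by undersampling, and save and check 2-way folds for datasets stored in the `ImageFolder` directory layout.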

In [5]:
from typing import Optional
from torchvision.datasets import ImageFolder

def check_dataset(dataset: str):
    """Print per-class sample counts for every saved fold of an ImageFolder dataset.

    Args:
        dataset: Path of the dataset folder, joined onto ``DATASETS_ROOT``. The folder
            must contain the ``folds.npy`` file written by ``prepare_dataset``.
    """
    folds = np.load(os.path.join(DATASETS_ROOT, dataset, 'folds.npy'))
    ds = ImageFolder(
        root=os.path.join(DATASETS_ROOT, dataset),
        transform=ToTensor(),
    )

    # folds is indexed as (repetition, half, sample); iterate over all repetitions that
    # were saved rather than assuming there are exactly 5.
    for fold_pair in folds:
        for fold_indices in fold_pair:
            dataset_info(Subset(dataset=ds, indices=fold_indices), verbose=True)

def _print_section(title: str):
    """Print ``title`` framed by two separator lines."""
    separator = '------------------------------------------------------------------'
    print(separator)
    print(title)
    print(separator)


def prepare_dataset(dataset: str, check: bool = True, n: Optional[int] = None, n_folds: int = 5):
    """Compute statistics, undersample to balanced classes and save 2-way folds.

    Steps:
    1. Print per-class counts and per-channel mean/std of the full dataset.
    2. Undersample every class to ``n`` images and print the new mean/std.
    3. Split the balanced subset into two halves ``n_folds`` times and save the
       resulting indices (into the full ImageFolder) as ``<dataset>/folds.npy``.
    4. Optionally re-load the folds and print their class balance.

    Args:
        dataset: Path of the dataset folder, joined onto ``DATASETS_ROOT``.
        check: If True, run ``check_dataset`` on the saved folds.
        n: Images kept per class. Defaults to the smallest class count rounded down
            to an even number, so each class splits evenly into two halves; an explicit
            value should also be even for the same reason.
        n_folds: Number of independent 2-way splits to generate (default 5).
    """
    ds = ImageFolder(
        root=os.path.join(DATASETS_ROOT, dataset),
        transform=ToTensor(),
    )
    mean, std = get_dataset_mean_std(ds)
    classes = dataset_info(ds)  # mapping: class index -> number of samples
    _print_section('dataset info:')
    idx_to_class = {v: k for k, v in ds.class_to_idx.items()}
    for k, v in classes.items():
        print(f"Class {k} {idx_to_class[k]} contains {v} samples.")
    print(f'{mean=}, {std=}')
    _print_section('undersampling...')
    if n is None:
        n = min(classes.values())
        n -= n % 2  # round down to even so both halves get n // 2 per class
    ds = undersampling(ds, n=n)
    mean, std = get_dataset_mean_std(ds)
    print(f'{mean=}, {std=}')
    folds = []
    for _ in range(n_folds):
        f1, f2 = split_data(ds, 2)
        # Map positions inside the balanced subset back to indices of the full
        # ImageFolder, which is what check_dataset indexes with.
        folds.append(np.array([ds.indices[f1], ds.indices[f2]]))
    folds = np.array(folds)
    np.save(os.path.join(DATASETS_ROOT, dataset, 'folds'), folds)
    print(folds.shape)
    if check:
        _print_section('checking dataset...')
        check_dataset(dataset)

## Land-Use_Scene_Classification
Per-channel statistics of the class-balanced subset (200 images per class): mean=(0.458, 0.467, 0.437), std=(0.289, 0.281, 0.270)

In [6]:
# The path is relative to DATASETS_ROOT: prepare_dataset joins it with DATASETS_ROOT itself.
prepare_dataset('Land-Use_Scene_Classification/images', n=200)

100%|██████████| 10500/10500 [00:20<00:00, 509.58it/s]
100%|██████████| 10500/10500 [00:19<00:00, 539.53it/s]


------------------------------------------------------------------
dataset info:
------------------------------------------------------------------
Class 0 agricultural contains 500 samples.
Class 1 airplane contains 500 samples.
Class 2 baseballdiamond contains 500 samples.
Class 3 beach contains 500 samples.
Class 4 buildings contains 500 samples.
Class 5 chaparral contains 500 samples.
Class 6 denseresidential contains 500 samples.
Class 7 forest contains 500 samples.
Class 8 freeway contains 500 samples.
Class 9 golfcourse contains 500 samples.
Class 10 harbor contains 500 samples.
Class 11 intersection contains 500 samples.
Class 12 mediumresidential contains 500 samples.
Class 13 mobilehomepark contains 500 samples.
Class 14 overpass contains 500 samples.
Class 15 parkinglot contains 500 samples.
Class 16 river contains 500 samples.
Class 17 runway contains 500 samples.
Class 18 sparseresidential contains 500 samples.
Class 19 storagetanks contains 500 samples.
Class 20 tenniscou

100%|██████████| 10500/10500 [00:20<00:00, 518.65it/s]


New size of any class 200 samples.


100%|██████████| 4200/4200 [00:08<00:00, 507.05it/s]


mean=(0.45834349999101315, 0.4672601771127858, 0.4368523198774913), std=(0.2893816882307876, 0.2808980277810426, 0.2702369471265301)


100%|██████████| 4200/4200 [00:07<00:00, 567.74it/s]
100%|██████████| 4200/4200 [00:07<00:00, 557.37it/s]
100%|██████████| 4200/4200 [00:07<00:00, 568.31it/s]
100%|██████████| 4200/4200 [00:07<00:00, 554.09it/s]
100%|██████████| 4200/4200 [00:07<00:00, 565.46it/s]


(5, 2, 2100)
------------------------------------------------------------------
checking dataset...
------------------------------------------------------------------


100%|██████████| 2100/2100 [00:03<00:00, 575.81it/s]


Class 0 contains 100 samples.
Class 1 contains 100 samples.
Class 2 contains 100 samples.
Class 3 contains 100 samples.
Class 4 contains 100 samples.
Class 5 contains 100 samples.
Class 6 contains 100 samples.
Class 7 contains 100 samples.
Class 8 contains 100 samples.
Class 9 contains 100 samples.
Class 10 contains 100 samples.
Class 11 contains 100 samples.
Class 12 contains 100 samples.
Class 13 contains 100 samples.
Class 14 contains 100 samples.
Class 15 contains 100 samples.
Class 16 contains 100 samples.
Class 17 contains 100 samples.
Class 18 contains 100 samples.
Class 19 contains 100 samples.
Class 20 contains 100 samples.


100%|██████████| 2100/2100 [00:03<00:00, 582.15it/s]


Class 0 contains 100 samples.
Class 1 contains 100 samples.
Class 2 contains 100 samples.
Class 3 contains 100 samples.
Class 4 contains 100 samples.
Class 5 contains 100 samples.
Class 6 contains 100 samples.
Class 7 contains 100 samples.
Class 8 contains 100 samples.
Class 9 contains 100 samples.
Class 10 contains 100 samples.
Class 11 contains 100 samples.
Class 12 contains 100 samples.
Class 13 contains 100 samples.
Class 14 contains 100 samples.
Class 15 contains 100 samples.
Class 16 contains 100 samples.
Class 17 contains 100 samples.
Class 18 contains 100 samples.
Class 19 contains 100 samples.
Class 20 contains 100 samples.


100%|██████████| 2100/2100 [00:03<00:00, 562.64it/s]


Class 0 contains 100 samples.
Class 1 contains 100 samples.
Class 2 contains 100 samples.
Class 3 contains 100 samples.
Class 4 contains 100 samples.
Class 5 contains 100 samples.
Class 6 contains 100 samples.
Class 7 contains 100 samples.
Class 8 contains 100 samples.
Class 9 contains 100 samples.
Class 10 contains 100 samples.
Class 11 contains 100 samples.
Class 12 contains 100 samples.
Class 13 contains 100 samples.
Class 14 contains 100 samples.
Class 15 contains 100 samples.
Class 16 contains 100 samples.
Class 17 contains 100 samples.
Class 18 contains 100 samples.
Class 19 contains 100 samples.
Class 20 contains 100 samples.


100%|██████████| 2100/2100 [00:03<00:00, 576.60it/s]


Class 0 contains 100 samples.
Class 1 contains 100 samples.
Class 2 contains 100 samples.
Class 3 contains 100 samples.
Class 4 contains 100 samples.
Class 5 contains 100 samples.
Class 6 contains 100 samples.
Class 7 contains 100 samples.
Class 8 contains 100 samples.
Class 9 contains 100 samples.
Class 10 contains 100 samples.
Class 11 contains 100 samples.
Class 12 contains 100 samples.
Class 13 contains 100 samples.
Class 14 contains 100 samples.
Class 15 contains 100 samples.
Class 16 contains 100 samples.
Class 17 contains 100 samples.
Class 18 contains 100 samples.
Class 19 contains 100 samples.
Class 20 contains 100 samples.


100%|██████████| 2100/2100 [00:03<00:00, 568.84it/s]


Class 0 contains 100 samples.
Class 1 contains 100 samples.
Class 2 contains 100 samples.
Class 3 contains 100 samples.
Class 4 contains 100 samples.
Class 5 contains 100 samples.
Class 6 contains 100 samples.
Class 7 contains 100 samples.
Class 8 contains 100 samples.
Class 9 contains 100 samples.
Class 10 contains 100 samples.
Class 11 contains 100 samples.
Class 12 contains 100 samples.
Class 13 contains 100 samples.
Class 14 contains 100 samples.
Class 15 contains 100 samples.
Class 16 contains 100 samples.
Class 17 contains 100 samples.
Class 18 contains 100 samples.
Class 19 contains 100 samples.
Class 20 contains 100 samples.


100%|██████████| 2100/2100 [00:03<00:00, 561.92it/s]


Class 0 contains 100 samples.
Class 1 contains 100 samples.
Class 2 contains 100 samples.
Class 3 contains 100 samples.
Class 4 contains 100 samples.
Class 5 contains 100 samples.
Class 6 contains 100 samples.
Class 7 contains 100 samples.
Class 8 contains 100 samples.
Class 9 contains 100 samples.
Class 10 contains 100 samples.
Class 11 contains 100 samples.
Class 12 contains 100 samples.
Class 13 contains 100 samples.
Class 14 contains 100 samples.
Class 15 contains 100 samples.
Class 16 contains 100 samples.
Class 17 contains 100 samples.
Class 18 contains 100 samples.
Class 19 contains 100 samples.
Class 20 contains 100 samples.


100%|██████████| 2100/2100 [00:03<00:00, 568.70it/s]


Class 0 contains 100 samples.
Class 1 contains 100 samples.
Class 2 contains 100 samples.
Class 3 contains 100 samples.
Class 4 contains 100 samples.
Class 5 contains 100 samples.
Class 6 contains 100 samples.
Class 7 contains 100 samples.
Class 8 contains 100 samples.
Class 9 contains 100 samples.
Class 10 contains 100 samples.
Class 11 contains 100 samples.
Class 12 contains 100 samples.
Class 13 contains 100 samples.
Class 14 contains 100 samples.
Class 15 contains 100 samples.
Class 16 contains 100 samples.
Class 17 contains 100 samples.
Class 18 contains 100 samples.
Class 19 contains 100 samples.
Class 20 contains 100 samples.


100%|██████████| 2100/2100 [00:03<00:00, 581.76it/s]


Class 0 contains 100 samples.
Class 1 contains 100 samples.
Class 2 contains 100 samples.
Class 3 contains 100 samples.
Class 4 contains 100 samples.
Class 5 contains 100 samples.
Class 6 contains 100 samples.
Class 7 contains 100 samples.
Class 8 contains 100 samples.
Class 9 contains 100 samples.
Class 10 contains 100 samples.
Class 11 contains 100 samples.
Class 12 contains 100 samples.
Class 13 contains 100 samples.
Class 14 contains 100 samples.
Class 15 contains 100 samples.
Class 16 contains 100 samples.
Class 17 contains 100 samples.
Class 18 contains 100 samples.
Class 19 contains 100 samples.
Class 20 contains 100 samples.


100%|██████████| 2100/2100 [00:03<00:00, 547.61it/s]


Class 0 contains 100 samples.
Class 1 contains 100 samples.
Class 2 contains 100 samples.
Class 3 contains 100 samples.
Class 4 contains 100 samples.
Class 5 contains 100 samples.
Class 6 contains 100 samples.
Class 7 contains 100 samples.
Class 8 contains 100 samples.
Class 9 contains 100 samples.
Class 10 contains 100 samples.
Class 11 contains 100 samples.
Class 12 contains 100 samples.
Class 13 contains 100 samples.
Class 14 contains 100 samples.
Class 15 contains 100 samples.
Class 16 contains 100 samples.
Class 17 contains 100 samples.
Class 18 contains 100 samples.
Class 19 contains 100 samples.
Class 20 contains 100 samples.


100%|██████████| 2100/2100 [00:04<00:00, 520.69it/s]

Class 0 contains 100 samples.
Class 1 contains 100 samples.
Class 2 contains 100 samples.
Class 3 contains 100 samples.
Class 4 contains 100 samples.
Class 5 contains 100 samples.
Class 6 contains 100 samples.
Class 7 contains 100 samples.
Class 8 contains 100 samples.
Class 9 contains 100 samples.
Class 10 contains 100 samples.
Class 11 contains 100 samples.
Class 12 contains 100 samples.
Class 13 contains 100 samples.
Class 14 contains 100 samples.
Class 15 contains 100 samples.
Class 16 contains 100 samples.
Class 17 contains 100 samples.
Class 18 contains 100 samples.
Class 19 contains 100 samples.
Class 20 contains 100 samples.





## Minerals 200x200
mean=(0.291, 0.4226, 0.4654), std=(0.2227, 0.2412, 0.3168)