In [1]:
import sys
sys.path.append('../../../Fedot.Industrial')

import os
import numpy as np
from torch.utils.data import Subset
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
from core.architecture.datasets.splitters import get_dataset_mean_std, split_data, undersampling, dataset_info

# Root directory that holds every dataset used in this notebook.
# Set the DATASETS_ROOT environment variable to point it at another location;
# the original machine-specific path is kept as the fallback.
DATASETS_ROOT = os.environ.get('DATASETS_ROOT', '/media/n31v/data/datasets')

## MNIST
mean=(0.13066,), std=(0.3081,)

In [13]:
from torchvision.datasets import MNIST

# Load MNIST as tensors. download=True (as in the FashionMNIST cell) lets the cell
# run on a machine where the files are not yet present under DATASETS_ROOT.
mnist_ds = MNIST(
    root=DATASETS_ROOT,
    transform=ToTensor(),
    download=True,
)
mean, std = get_dataset_mean_std(mnist_ds)
# verbose=True requests the per-class printout; the prepare_dataset runs recorded in
# this notebook show that dataset_info prints nothing when called without it.
dataset_info(mnist_ds, verbose=True)
print(f'{mean=}, {std=}')

100%|██████████| 60000/60000 [00:06<00:00, 8974.01it/s]
100%|██████████| 60000/60000 [00:04<00:00, 14236.07it/s]

Class 0 contains 5923 samples.
Class 1 contains 6742 samples.
Class 2 contains 5958 samples.
Class 3 contains 6131 samples.
Class 4 contains 5842 samples.
Class 5 contains 5421 samples.
Class 6 contains 5918 samples.
Class 7 contains 6265 samples.
Class 8 contains 5851 samples.
Class 9 contains 5949 samples.
mean = (0.13066047797803165,), std = (0.3081078048756658,)





In [None]:
# Balance the classes with undersampling(n=5000). The result gets its own name so that
# `mnist_ds` keeps pointing at the full dataset: re-running this cell then always
# starts from the same input instead of undersampling an already reduced subset
# (whose `.indices` would presumably no longer address the full dataset).
mnist_balanced = undersampling(mnist_ds, n=5000)

# Five repeated 2-way splits. `.indices[...]` converts what split_data returns into
# indices of the full dataset (the check cell builds Subsets of the full dataset from
# them) - presumably split_data returns positions within `mnist_balanced`.
folds = []
for i in range(5):
    f1, f2 = split_data(mnist_balanced, 2)
    folds.append(np.array([mnist_balanced.indices[f1], mnist_balanced.indices[f2]]))
folds = np.array(folds)
# np.save appends '.npy', so the file read back later is .../MNIST/folds.npy.
np.save(os.path.join(DATASETS_ROOT, 'MNIST', 'folds'), folds)
print(folds.shape)

In [None]:
folds = np.load(os.path.join(DATASETS_ROOT, 'MNIST', 'folds.npy'))

# Reload the full dataset; the fold indices are applied to it through Subset.
mnist_ds = MNIST(
    root=DATASETS_ROOT,
    transform=ToTensor(),
)

# Iterate over the loaded array itself instead of a fixed range(5), so every stored
# repetition is checked whatever their number. verbose=True (as in check_dataset)
# makes dataset_info print the class counts of each half.
for repetition in folds:
    fold1 = Subset(dataset=mnist_ds, indices=repetition[0])
    dataset_info(fold1, verbose=True)
    fold2 = Subset(dataset=mnist_ds, indices=repetition[1])
    dataset_info(fold2, verbose=True)

## FashionMNIST
mean=(0.286,), std=(0.353,)

In [None]:
from torchvision.datasets import FashionMNIST

# Load FashionMNIST as tensors, downloading it into DATASETS_ROOT when missing.
fmnist_ds = FashionMNIST(
    root=DATASETS_ROOT,
    transform=ToTensor(),
    download=True,
)
mean, std = get_dataset_mean_std(fmnist_ds)
# verbose=True requests the per-class printout; the prepare_dataset runs recorded in
# this notebook show that dataset_info prints nothing when called without it.
dataset_info(fmnist_ds, verbose=True)
print(f'{mean=}, {std=}')

In [None]:
# Balance the classes with undersampling(n=5000). The result gets its own name so that
# `fmnist_ds` keeps pointing at the full dataset: re-running this cell then always
# starts from the same input instead of undersampling an already reduced subset
# (whose `.indices` would presumably no longer address the full dataset).
fmnist_balanced = undersampling(fmnist_ds, n=5000)

# Five repeated 2-way splits. `.indices[...]` converts what split_data returns into
# indices of the full dataset (the check cell builds Subsets of the full dataset from
# them) - presumably split_data returns positions within `fmnist_balanced`.
folds = []
for i in range(5):
    f1, f2 = split_data(fmnist_balanced, 2)
    folds.append(np.array([fmnist_balanced.indices[f1], fmnist_balanced.indices[f2]]))
folds = np.array(folds)
# np.save appends '.npy', so the file read back later is .../FashionMNIST/folds.npy.
np.save(os.path.join(DATASETS_ROOT, 'FashionMNIST', 'folds'), folds)
print(folds.shape)

In [None]:
folds = np.load(os.path.join(DATASETS_ROOT, 'FashionMNIST', 'folds.npy'))

# Reload the full FashionMNIST dataset under its own name. The original cell bound it
# to `mnist_ds`, which silently replaced the MNIST dataset held in that variable.
fmnist_ds = FashionMNIST(
    root=DATASETS_ROOT,
    transform=ToTensor(),
)

# Iterate over the loaded array itself instead of a fixed range(5), so every stored
# repetition is checked whatever their number. verbose=True (as in check_dataset)
# makes dataset_info print the class counts of each half.
for repetition in folds:
    fold1 = Subset(dataset=fmnist_ds, indices=repetition[0])
    dataset_info(fold1, verbose=True)
    fold2 = Subset(dataset=fmnist_ds, indices=repetition[1])
    dataset_info(fold2, verbose=True)

## CIFAR10
mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)

In [16]:
from torchvision.datasets import CIFAR10

# Load CIFAR10 as tensors from its own sub-directory. download=True (as in the
# FashionMNIST cell) lets the cell run on a machine where the files are not yet there.
cifar10_ds = CIFAR10(
    root=os.path.join(DATASETS_ROOT, 'CIFAR10'),
    transform=ToTensor(),
    download=True,
)
mean, std = get_dataset_mean_std(cifar10_ds)
# verbose=True requests the per-class printout; the prepare_dataset runs recorded in
# this notebook show that dataset_info prints nothing when called without it.
dataset_info(cifar10_ds, verbose=True)
print(f'{mean=}, {std=}')

100%|██████████| 50000/50000 [00:06<00:00, 7721.50it/s]
100%|██████████| 50000/50000 [00:04<00:00, 11767.78it/s]

Class 0 contains 5000 samples.
Class 1 contains 5000 samples.
Class 2 contains 5000 samples.
Class 3 contains 5000 samples.
Class 4 contains 5000 samples.
Class 5 contains 5000 samples.
Class 6 contains 5000 samples.
Class 7 contains 5000 samples.
Class 8 contains 5000 samples.
Class 9 contains 5000 samples.
mean = (0.4913996927399561, 0.4821584222899936, 0.4465309280202538), std = (0.24703223297351337, 0.2434851287896555, 0.26158784042441807)





In [None]:
# Run split_data(cifar10_ds, 2) five times, stack the results into one array and
# store it inside the CIFAR10 directory (np.save adds the '.npy' extension).
cifar10_splits = [np.array(split_data(cifar10_ds, 2)) for _ in range(5)]
folds = np.array(cifar10_splits)
folds_path = os.path.join(DATASETS_ROOT, 'CIFAR10', 'folds')
np.save(folds_path, folds)
print(folds.shape)

In [None]:
folds = np.load(os.path.join(DATASETS_ROOT, 'CIFAR10', 'folds.npy'))

# Reload the full dataset; the fold indices are applied to it through Subset.
cifar10_ds = CIFAR10(
    root=os.path.join(DATASETS_ROOT, 'CIFAR10'),
    transform=ToTensor(),
)

# Iterate over the loaded array itself instead of a fixed range(5), so every stored
# repetition is checked whatever their number. verbose=True (as in check_dataset)
# makes dataset_info print the class counts of each half.
for repetition in folds:
    fold1 = Subset(dataset=cifar10_ds, indices=repetition[0])
    dataset_info(fold1, verbose=True)
    fold2 = Subset(dataset=cifar10_ds, indices=repetition[1])
    dataset_info(fold2, verbose=True)

## ImageFolder

In [2]:
from torchvision.datasets import ImageFolder

def check_dataset(dataset: str):
    """Print the class balance of every saved fold of an ImageFolder dataset.

    Loads ``folds.npy`` from ``<DATASETS_ROOT>/<dataset>`` and, for each stored
    repetition, wraps its two index rows in ``Subset`` views of the image folder
    and passes each view to ``dataset_info(..., verbose=True)``.

    Args:
        dataset: Name of the dataset directory inside ``DATASETS_ROOT``.
    """
    dataset_dir = os.path.join(DATASETS_ROOT, dataset)
    folds = np.load(os.path.join(dataset_dir, 'folds.npy'))
    ds = ImageFolder(
        root=dataset_dir,
        transform=ToTensor(),
    )

    # Iterate over the stored array itself rather than a fixed range(5): a file with
    # fewer repetitions no longer raises IndexError and extra ones are not skipped.
    for repetition in folds:
        fold1 = Subset(dataset=ds, indices=repetition[0])
        dataset_info(fold1, verbose=True)
        fold2 = Subset(dataset=ds, indices=repetition[1])
        dataset_info(fold2, verbose=True)

def _print_banner(title: str) -> None:
    """Print ``title`` framed by two horizontal separator lines."""
    separator = '------------------------------------------------------------------'
    print(separator)
    print(title)
    print(separator)


def prepare_dataset(dataset: str, check: bool = True):
    """Report statistics of an ImageFolder dataset and save balanced 2-way folds.

    The function loads ``<DATASETS_ROOT>/<dataset>`` as an ``ImageFolder``, prints the
    mean/std returned by ``get_dataset_mean_std`` and the per-class sample counts,
    undersamples to the size of the smallest class (rounded down to an even number),
    runs ``split_data(ds, 2)`` five times and saves the resulting index array to
    ``<DATASETS_ROOT>/<dataset>/folds.npy``.

    Args:
        dataset: Name of the dataset directory inside ``DATASETS_ROOT``.
        check: When true, call ``check_dataset(dataset)`` on the saved folds afterwards.

    Raises:
        ValueError: If the smallest class has fewer than two samples, so that no
            even, non-zero per-class size exists.
    """
    ds = ImageFolder(
        root=os.path.join(DATASETS_ROOT, dataset),
        transform=ToTensor(),
    )
    mean, std = get_dataset_mean_std(ds)
    classes = dataset_info(ds)

    _print_banner('dataset info:')
    idx_to_class = {v: k for k, v in ds.class_to_idx.items()}
    for k, v in classes.items():
        print(f"Class {k} {idx_to_class[k]} contains {v} samples.")
    print(f'{mean=}, {std=}')

    _print_banner('undersampling...')
    # Per-class size: the smallest class, made even so it can be halved exactly.
    n = min(classes.values())
    n -= n % 2
    if n < 2:
        raise ValueError(
            f"Dataset '{dataset}': the smallest class has fewer than 2 samples, "
            "so the classes cannot be balanced and split into two halves."
        )
    ds = undersampling(ds, n=n)

    # `.indices[...]` converts what split_data returns into indices of the full image
    # folder (check_dataset applies the saved values to the full dataset) -
    # presumably split_data returns positions within the undersampled subset.
    folds = []
    for _ in range(5):
        f1, f2 = split_data(ds, 2)
        folds.append(np.array([ds.indices[f1], ds.indices[f2]]))
    folds = np.array(folds)
    # np.save appends '.npy', producing the folds.npy file that check_dataset loads.
    np.save(os.path.join(DATASETS_ROOT, dataset, 'folds'), folds)
    print(folds.shape)

    if check:
        _print_banner('checking dataset...')
        check_dataset(dataset)

## Minerals
mean=(0.291, 0.4226, 0.4654), std=(0.2227, 0.2412, 0.3168)

In [None]:
# Report statistics, save balanced folds and check them for the 'minerals' image folder.
prepare_dataset('minerals')

## Minerals (150x150)
mean=(0.4186, 0.4301, 0.4217), std=(0.228, 0.217, 0.2543)

In [None]:
# Report statistics, save balanced folds and check them for 'minerals_classification'.
prepare_dataset('minerals_classification')

## Minerals 200
mean=(0.444, 0.562, 0.556), std=(0.207, 0.235, 0.231)

In [None]:
# Report statistics, save balanced folds and check them for the 'minerals200' image folder.
prepare_dataset('minerals200')

In [3]:
# Report statistics, save balanced folds and check them for 'New_dataset_big'.
prepare_dataset('New_dataset_big')

100%|██████████| 18495/18495 [02:01<00:00, 152.27it/s] 
100%|██████████| 18495/18495 [00:14<00:00, 1262.47it/s]


------------------------------------------------------------------
dataset info:
------------------------------------------------------------------
Class 0 01_Almandin contains 624 samples.
Class 1 02_Amazonit contains 604 samples.
Class 2 03_Apatit contains 829 samples.
Class 3 04_Barit contains 782 samples.
Class 4 05_Berill contains 660 samples.
Class 5 06_Biotit contains 795 samples.
Class 6 07_Granat contains 888 samples.
Class 7 08_Diopsit contains 633 samples.
Class 8 09_Gold contains 870 samples.
Class 9 10_Calcit contains 796 samples.
Class 10 11_Quarz_1 contains 746 samples.
Class 11 12_Quarz_2 contains 856 samples.
Class 12 13_Kianit contains 635 samples.
Class 13 14_Kordierit contains 767 samples.
Class 14 15_Korund contains 667 samples.
Class 15 16_Magnetit contains 575 samples.
Class 16 17_Microclin contains 799 samples.
Class 17 18_Muscovit contains 851 samples.
Class 18 19_Olivin contains 922 samples.
Class 19 20_Pirit contains 670 samples.
Class 20 21_Fluorit contains 

100%|██████████| 18495/18495 [00:14<00:00, 1279.70it/s]


New size of any class 574 samples.


100%|██████████| 13776/13776 [00:11<00:00, 1196.29it/s]
100%|██████████| 13776/13776 [00:11<00:00, 1195.75it/s]
100%|██████████| 13776/13776 [00:11<00:00, 1169.62it/s]
100%|██████████| 13776/13776 [00:12<00:00, 1083.53it/s]
100%|██████████| 13776/13776 [00:11<00:00, 1220.75it/s]


(5, 2, 6888)
------------------------------------------------------------------
checking dataset...
------------------------------------------------------------------


100%|██████████| 6888/6888 [00:05<00:00, 1219.27it/s]


Class 0 contains 287 samples.
Class 1 contains 287 samples.
Class 2 contains 287 samples.
Class 3 contains 287 samples.
Class 4 contains 287 samples.
Class 5 contains 287 samples.
Class 6 contains 287 samples.
Class 7 contains 287 samples.
Class 8 contains 287 samples.
Class 9 contains 287 samples.
Class 10 contains 287 samples.
Class 11 contains 287 samples.
Class 12 contains 287 samples.
Class 13 contains 287 samples.
Class 14 contains 287 samples.
Class 15 contains 287 samples.
Class 16 contains 287 samples.
Class 17 contains 287 samples.
Class 18 contains 287 samples.
Class 19 contains 287 samples.
Class 20 contains 287 samples.
Class 21 contains 287 samples.
Class 22 contains 287 samples.
Class 23 contains 287 samples.


100%|██████████| 6888/6888 [00:05<00:00, 1211.61it/s]


Class 0 contains 287 samples.
Class 1 contains 287 samples.
Class 2 contains 287 samples.
Class 3 contains 287 samples.
Class 4 contains 287 samples.
Class 5 contains 287 samples.
Class 6 contains 287 samples.
Class 7 contains 287 samples.
Class 8 contains 287 samples.
Class 9 contains 287 samples.
Class 10 contains 287 samples.
Class 11 contains 287 samples.
Class 12 contains 287 samples.
Class 13 contains 287 samples.
Class 14 contains 287 samples.
Class 15 contains 287 samples.
Class 16 contains 287 samples.
Class 17 contains 287 samples.
Class 18 contains 287 samples.
Class 19 contains 287 samples.
Class 20 contains 287 samples.
Class 21 contains 287 samples.
Class 22 contains 287 samples.
Class 23 contains 287 samples.


100%|██████████| 6888/6888 [00:05<00:00, 1224.65it/s]


Class 0 contains 287 samples.
Class 1 contains 287 samples.
Class 2 contains 287 samples.
Class 3 contains 287 samples.
Class 4 contains 287 samples.
Class 5 contains 287 samples.
Class 6 contains 287 samples.
Class 7 contains 287 samples.
Class 8 contains 287 samples.
Class 9 contains 287 samples.
Class 10 contains 287 samples.
Class 11 contains 287 samples.
Class 12 contains 287 samples.
Class 13 contains 287 samples.
Class 14 contains 287 samples.
Class 15 contains 287 samples.
Class 16 contains 287 samples.
Class 17 contains 287 samples.
Class 18 contains 287 samples.
Class 19 contains 287 samples.
Class 20 contains 287 samples.
Class 21 contains 287 samples.
Class 22 contains 287 samples.
Class 23 contains 287 samples.


100%|██████████| 6888/6888 [00:05<00:00, 1201.39it/s]


Class 0 contains 287 samples.
Class 1 contains 287 samples.
Class 2 contains 287 samples.
Class 3 contains 287 samples.
Class 4 contains 287 samples.
Class 5 contains 287 samples.
Class 6 contains 287 samples.
Class 7 contains 287 samples.
Class 8 contains 287 samples.
Class 9 contains 287 samples.
Class 10 contains 287 samples.
Class 11 contains 287 samples.
Class 12 contains 287 samples.
Class 13 contains 287 samples.
Class 14 contains 287 samples.
Class 15 contains 287 samples.
Class 16 contains 287 samples.
Class 17 contains 287 samples.
Class 18 contains 287 samples.
Class 19 contains 287 samples.
Class 20 contains 287 samples.
Class 21 contains 287 samples.
Class 22 contains 287 samples.
Class 23 contains 287 samples.


100%|██████████| 6888/6888 [00:05<00:00, 1215.92it/s]


Class 0 contains 287 samples.
Class 1 contains 287 samples.
Class 2 contains 287 samples.
Class 3 contains 287 samples.
Class 4 contains 287 samples.
Class 5 contains 287 samples.
Class 6 contains 287 samples.
Class 7 contains 287 samples.
Class 8 contains 287 samples.
Class 9 contains 287 samples.
Class 10 contains 287 samples.
Class 11 contains 287 samples.
Class 12 contains 287 samples.
Class 13 contains 287 samples.
Class 14 contains 287 samples.
Class 15 contains 287 samples.
Class 16 contains 287 samples.
Class 17 contains 287 samples.
Class 18 contains 287 samples.
Class 19 contains 287 samples.
Class 20 contains 287 samples.
Class 21 contains 287 samples.
Class 22 contains 287 samples.
Class 23 contains 287 samples.


100%|██████████| 6888/6888 [00:05<00:00, 1211.06it/s]


Class 0 contains 287 samples.
Class 1 contains 287 samples.
Class 2 contains 287 samples.
Class 3 contains 287 samples.
Class 4 contains 287 samples.
Class 5 contains 287 samples.
Class 6 contains 287 samples.
Class 7 contains 287 samples.
Class 8 contains 287 samples.
Class 9 contains 287 samples.
Class 10 contains 287 samples.
Class 11 contains 287 samples.
Class 12 contains 287 samples.
Class 13 contains 287 samples.
Class 14 contains 287 samples.
Class 15 contains 287 samples.
Class 16 contains 287 samples.
Class 17 contains 287 samples.
Class 18 contains 287 samples.
Class 19 contains 287 samples.
Class 20 contains 287 samples.
Class 21 contains 287 samples.
Class 22 contains 287 samples.
Class 23 contains 287 samples.


100%|██████████| 6888/6888 [00:05<00:00, 1229.75it/s]


Class 0 contains 287 samples.
Class 1 contains 287 samples.
Class 2 contains 287 samples.
Class 3 contains 287 samples.
Class 4 contains 287 samples.
Class 5 contains 287 samples.
Class 6 contains 287 samples.
Class 7 contains 287 samples.
Class 8 contains 287 samples.
Class 9 contains 287 samples.
Class 10 contains 287 samples.
Class 11 contains 287 samples.
Class 12 contains 287 samples.
Class 13 contains 287 samples.
Class 14 contains 287 samples.
Class 15 contains 287 samples.
Class 16 contains 287 samples.
Class 17 contains 287 samples.
Class 18 contains 287 samples.
Class 19 contains 287 samples.
Class 20 contains 287 samples.
Class 21 contains 287 samples.
Class 22 contains 287 samples.
Class 23 contains 287 samples.


100%|██████████| 6888/6888 [00:05<00:00, 1216.80it/s]


Class 0 contains 287 samples.
Class 1 contains 287 samples.
Class 2 contains 287 samples.
Class 3 contains 287 samples.
Class 4 contains 287 samples.
Class 5 contains 287 samples.
Class 6 contains 287 samples.
Class 7 contains 287 samples.
Class 8 contains 287 samples.
Class 9 contains 287 samples.
Class 10 contains 287 samples.
Class 11 contains 287 samples.
Class 12 contains 287 samples.
Class 13 contains 287 samples.
Class 14 contains 287 samples.
Class 15 contains 287 samples.
Class 16 contains 287 samples.
Class 17 contains 287 samples.
Class 18 contains 287 samples.
Class 19 contains 287 samples.
Class 20 contains 287 samples.
Class 21 contains 287 samples.
Class 22 contains 287 samples.
Class 23 contains 287 samples.


100%|██████████| 6888/6888 [00:05<00:00, 1234.72it/s]


Class 0 contains 287 samples.
Class 1 contains 287 samples.
Class 2 contains 287 samples.
Class 3 contains 287 samples.
Class 4 contains 287 samples.
Class 5 contains 287 samples.
Class 6 contains 287 samples.
Class 7 contains 287 samples.
Class 8 contains 287 samples.
Class 9 contains 287 samples.
Class 10 contains 287 samples.
Class 11 contains 287 samples.
Class 12 contains 287 samples.
Class 13 contains 287 samples.
Class 14 contains 287 samples.
Class 15 contains 287 samples.
Class 16 contains 287 samples.
Class 17 contains 287 samples.
Class 18 contains 287 samples.
Class 19 contains 287 samples.
Class 20 contains 287 samples.
Class 21 contains 287 samples.
Class 22 contains 287 samples.
Class 23 contains 287 samples.


100%|██████████| 6888/6888 [00:05<00:00, 1172.12it/s]

Class 0 contains 287 samples.
Class 1 contains 287 samples.
Class 2 contains 287 samples.
Class 3 contains 287 samples.
Class 4 contains 287 samples.
Class 5 contains 287 samples.
Class 6 contains 287 samples.
Class 7 contains 287 samples.
Class 8 contains 287 samples.
Class 9 contains 287 samples.
Class 10 contains 287 samples.
Class 11 contains 287 samples.
Class 12 contains 287 samples.
Class 13 contains 287 samples.
Class 14 contains 287 samples.
Class 15 contains 287 samples.
Class 16 contains 287 samples.
Class 17 contains 287 samples.
Class 18 contains 287 samples.
Class 19 contains 287 samples.
Class 20 contains 287 samples.
Class 21 contains 287 samples.
Class 22 contains 287 samples.
Class 23 contains 287 samples.





In [4]:
# Report statistics, save balanced folds and check them for 'New_dataset_small'.
# NOTE(review): the recorded run of this cell ends with KeyboardInterrupt during the
# fifth split, i.e. before the folds were saved - re-run it to completion.
prepare_dataset('New_dataset_small')

100%|██████████| 9617/9617 [00:15<00:00, 606.08it/s]
100%|██████████| 9617/9617 [00:10<00:00, 957.54it/s] 


------------------------------------------------------------------
dataset info:
------------------------------------------------------------------
Class 0 01_Almandin contains 305 samples.
Class 1 02_Amazonit contains 318 samples.
Class 2 03_Apatit contains 354 samples.
Class 3 04_Barit contains 337 samples.
Class 4 05_Berill contains 368 samples.
Class 5 06_Biotit contains 362 samples.
Class 6 07_Granat contains 362 samples.
Class 7 08_Diopsit contains 406 samples.
Class 8 09_Gold contains 334 samples.
Class 9 10_Calcit contains 368 samples.
Class 10 11_Quarz_1 contains 388 samples.
Class 11 12_Quarz_2 contains 424 samples.
Class 12 13_Kianit contains 391 samples.
Class 13 14_Kordierit contains 446 samples.
Class 14 15_Korund contains 397 samples.
Class 15 16_Magnetit contains 356 samples.
Class 16 17_Microclin contains 406 samples.
Class 17 18_Muscovit contains 608 samples.
Class 18 19_Olivin contains 471 samples.
Class 19 20_Pirit contains 411 samples.
Class 20 21_Fluorit contains 

100%|██████████| 9617/9617 [00:09<00:00, 1063.18it/s]


New size of any class 304 samples.


100%|██████████| 7296/7296 [00:06<00:00, 1140.79it/s]
100%|██████████| 7296/7296 [00:07<00:00, 1028.58it/s]
100%|██████████| 7296/7296 [00:07<00:00, 945.86it/s] 
100%|██████████| 7296/7296 [00:06<00:00, 1052.15it/s]
 58%|█████▊    | 4255/7296 [02:12<01:34, 32.07it/s]  


KeyboardInterrupt: 