# Stratified Splitting

This notebook provides several tutorials on how to use the algorithms in the
**straSplit** package to split a multi-label dataset with less explored
[stratified strategies](https://bit.ly/3s3IDA8). Please install
[Anaconda](https://www.anaconda.com/) and the other modules listed
in [requirements.txt](../../requirements.txt).

## Load modules and datasets

In [1]:
import os
os.chdir('../model')
os.getcwd()

'D:\\MultiLabel\\straSplit\\src\\model'

In [2]:
import pickle as pkl

## load utilities (dataset/result paths and the data_properties reporting helper)
from utils import DATASET_PATH, RESULT_PATH, data_properties

## load modules
from naive2split import NaiveStratification
from iterative2split import IterativeStratification
from extreme2split import ExtremeStratification
from plssvd2split import ClusterStratification
from eigencluster2split import ClusteringEigenStratification
from comm2split import CommunityStratification
from enhance2split import LabelEnhancementStratification
from active2split import ActiveStratification
from gan2split import GANStratification

In [3]:
split_type = "extreme"
split_size = 0.80
num_epochs = 5
num_jobs = 2
use_solver = False

In [4]:
dsname = "birds"
X_name = dsname + "_X.pkl"
y_name = dsname + "_y.pkl"

## load the label matrix and keep only examples with at least one label
file_path = os.path.join(DATASET_PATH, y_name)
with open(file_path, mode="rb") as f_in:
    y = pkl.load(f_in)
    idx = sorted(set(y.nonzero()[0]))
    y = y[idx]

## load the feature matrix and keep the same examples as in y
file_path = os.path.join(DATASET_PATH, X_name)
with open(file_path, mode="rb") as f_in:
    X = pkl.load(f_in)
    X = X[idx]
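The cell above keeps only the examples that carry at least one label, using the row indices of the non-zero entries of *y* to index both *X* and *y* so that the two stay aligned. The dependency-free sketch below shows the same idea on a small list-of-lists matrix; the toy data and the `drop_unlabelled` helper are hypothetical and only for illustration.

```python
def drop_unlabelled(X_rows, y_rows):
    """Keep the rows of X and y whose label vector has at least one non-zero entry."""
    # Row indices with a positive label, sorted so the order is deterministic.
    idx = sorted({i for i, row in enumerate(y_rows) for v in row if v != 0})
    return [X_rows[i] for i in idx], [y_rows[i] for i in idx], idx

X_rows = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
y_rows = [[1, 0, 0], [0, 0, 0], [0, 1, 1], [0, 0, 0]]
X_kept, y_kept, idx = drop_unlabelled(X_rows, y_rows)
print(idx)  # → [0, 2]  (rows 1 and 3 have no labels)
```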

## Naive strategy

The naive strategy neither addresses the class-imbalance problem nor takes
label correlations into account when splitting a dataset. You can run it
with the following command:
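A plain random permutation split is one way to picture a strategy that ignores both label imbalance and label correlations. The sketch below is not the package's `NaiveStratification` implementation (which also takes `batch_size` and `num_jobs`); the `random_split` helper is hypothetical.

```python
import random

def random_split(num_examples, split_size=0.8, shuffle=True, seed=0):
    """Return (training_idx, test_idx) by cutting a (possibly shuffled) index list."""
    indices = list(range(num_examples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    cut = int(round(split_size * num_examples))
    return sorted(indices[:cut]), sorted(indices[cut:])

training_idx, test_idx = random_split(351, split_size=0.8)
print(len(training_idx), len(test_idx))  # → 281 70
```

Because only the size of each subset is controlled, rare labels can end up entirely in one subset, which is what the stratified strategies below try to avoid.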

In [5]:
st = NaiveStratification(shuffle=True, split_size=split_size, batch_size=500,
                         num_jobs=num_jobs)
training_idx, test_idx = st.fit(y=y)

## Configuration parameters to naive based stratified multi-label dataset
   splitting:
		1. Shuffle the dataset? True
		2. Split size: 0.8
		3. Number of examples to use in each iteration: 500
		4. Number of parallel workers: 2


	>> Perform splitting...
		--> Splitting progress: 100.00%...


where *training_idx* and *test_idx* are two lists containing the indices of
the examples (i.e., rows of *y*) assigned to the training and test sets. Let
us explore the properties of the full dataset and of the resulting splits.
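Because the strategies return index lists rather than data, the subsets themselves are obtained by indexing *X* and *y* with them, just as the loading cell does with `idx`. The sketch below does this on plain Python lists and also checks that the two index lists are disjoint, contain no duplicates, and together cover every example; the `take` and `check_split` helpers are hypothetical.

```python
def take(rows, indices):
    """Select rows by position, mirroring y[training_idx] on a matrix."""
    return [rows[i] for i in indices]

def check_split(training_idx, test_idx, num_examples):
    """True if the index lists are disjoint, duplicate-free, and cover all examples."""
    train, test = set(training_idx), set(test_idx)
    no_duplicates = len(train) + len(test) == len(training_idx) + len(test_idx)
    return no_duplicates and train.isdisjoint(test) and train | test == set(range(num_examples))

y_rows = [[1, 0], [0, 1], [1, 1], [0, 1], [1, 0]]
training_idx, test_idx = [0, 2, 3], [1, 4]
y_train, y_test = take(y_rows, training_idx), take(y_rows, test_idx)
print(check_split(training_idx, test_idx, len(y_rows)))  # → True
```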

In [6]:
model_name = "naive2split"
data_properties(y=y.toarray(), selected_examples=training_idx, num_tails=1,
                display_full_properties=True, dataset_name=dsname,
                model_name=model_name, split_set_name="training",
                rspath=RESULT_PATH)
data_properties(y=y.toarray(), selected_examples=test_idx, num_tails=1,
                display_full_properties=False, dataset_name=dsname,
                model_name=model_name, split_set_name="test",
                rspath=RESULT_PATH, mode="a")

## DATA PROPERTIES for birds...
	>> Number of examples: 351
	>> Number of labels: 654
	>> Label cardinality: 1.863248
	>> Label density: 0.002849
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.054131
	>> Number of tail labels of size 1: 0
	>> Number of dominant labels of size 2: 19
## SELECTED (training set) DATA PROPERTIES for birds...
	>> Number of examples: 281
	>> Number of labels: 524
	>> Label cardinality: 1.864769
	>> Label density: 0.003559
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.067616
	>> KL difference between two full and selected examples labels distributions: 0.000883
## SELECTED (test set) DATA PROPERTIES for birds...
	>> Number of examples: 70
	>> Number of labels: 130
	>> Label cardinality: 1.857143
	>> Label density: 0.014286
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.271429
	>> KL difference between two full and selected examples labels distributions: 0.016654
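The last line of each selected-set summary reports a KL difference between the label distribution of the full dataset and that of the selected examples; smaller values mean the subset's label frequencies track the full dataset more closely. The exact formula used by `data_properties` is not shown in this notebook, so the sketch below assumes the distribution is each label's share of all label assignments and computes KL(full || subset), with a small `eps` to avoid division by zero when a label is absent from the subset.

```python
import math
from collections import Counter

def label_distribution(label_sets, num_labels):
    """Share of all label assignments that fall on each label."""
    counts = Counter(label for labels in label_sets for label in labels)
    total = sum(counts.values())
    return [counts[j] / total for j in range(num_labels)]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q); eps avoids division by zero when a label is missing from q."""
    return sum(pi * math.log(pi / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)

# Toy data: each example is the set of label indices it carries.
y_full = [{0}, {0, 1}, {1}, {2}, {0, 2}, {1, 2}]
training_idx, test_idx = [0, 1, 2, 3], [4, 5]

p_full = label_distribution(y_full, num_labels=3)
p_train = label_distribution([y_full[i] for i in training_idx], num_labels=3)
p_test = label_distribution([y_full[i] for i in test_idx], num_labels=3)
kl_train = kl_divergence(p_full, p_train)
kl_test = kl_divergence(p_full, p_test)
print(round(kl_train, 4), round(kl_test, 4))
```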


## Iterative strategy

As described in the [paper](https://bit.ly/2QqHd4V), this algorithm splits
the dataset iteratively. Like the naive strategy, it only needs the label
matrix *y*:
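The standard iterative stratification procedure of Sechidis et al. (2011), which is presumably what the linked paper describes, can be sketched as follows: repeatedly take the label with the fewest remaining examples, and send each of its examples to the subset that still wants that label most, breaking ties by the subset that still wants the most examples. This is an illustrative re-implementation under that assumption, not the package's `IterativeStratification` code.

```python
import random
from collections import defaultdict

def iterative_stratification(label_sets, ratios, seed=0):
    """Return, for each example, the index of the subset it is assigned to."""
    rng = random.Random(seed)
    n = len(label_sets)
    desired = [r * n for r in ratios]  # examples each subset still wants
    label_totals = defaultdict(int)
    for labels in label_sets:
        for label in labels:
            label_totals[label] += 1
    # How many examples of each label each subset still wants.
    desired_label = {l: [r * c for r in ratios] for l, c in label_totals.items()}
    assignment = [None] * n
    remaining = set(range(n))
    while remaining:
        by_label = defaultdict(list)
        for i in remaining:
            for label in label_sets[i]:
                by_label[label].append(i)
        if not by_label:
            # Only unlabelled examples are left: give each to the subset wanting the most.
            for i in sorted(remaining):
                j = max(range(len(ratios)), key=lambda s: desired[s])
                assignment[i] = j
                desired[j] -= 1
            break
        # Rarest remaining label first (ties broken by label id).
        label = min(by_label, key=lambda l: (len(by_label[l]), l))
        for i in sorted(by_label[label]):
            wants = desired_label[label]
            candidates = [s for s in range(len(ratios)) if wants[s] == max(wants)]
            if len(candidates) > 1:
                top = max(desired[s] for s in candidates)
                candidates = [s for s in candidates if desired[s] == top]
            j = rng.choice(candidates)  # remaining ties are broken at random
            assignment[i] = j
            desired[j] -= 1
            for other in label_sets[i]:
                desired_label[other][j] -= 1
            remaining.discard(i)
    return assignment

label_sets = [{0}, {0, 1}, {1}, {2}, {0, 2}, {1, 2}, {0}, {1}, {2}, {0, 1, 2}]
assignment = iterative_stratification(label_sets, ratios=[0.8, 0.2])
training_idx = [i for i, s in enumerate(assignment) if s == 0]
test_idx = [i for i, s in enumerate(assignment) if s == 1]
print(training_idx, test_idx)
```

Handling the rarest labels first is the key design choice: those labels have the least slack, so placing them before the common ones makes it less likely that a rare label ends up missing from the test set.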

In [7]:
st = IterativeStratification(shuffle=True, split_size=split_size)
training_idx, test_idx = st.fit(y=y)

## Configuration parameters to iteratively stratifying a multi-label
   dataset splitting:
		1. Shuffle the dataset? True
		2. Split size: 0.8


	>> Perform splitting (iterative)...
		--> Splitting progress: 100.00%...

In [8]:
model_name = "iterative2split"
data_properties(y=y.toarray(), selected_examples=training_idx, num_tails=1,
                display_full_properties=True, dataset_name=dsname,
                model_name=model_name, split_set_name="training",
                rspath=RESULT_PATH)
data_properties(y=y.toarray(), selected_examples=test_idx, num_tails=1,
                display_full_properties=False, dataset_name=dsname,
                model_name=model_name, split_set_name="test",
                rspath=RESULT_PATH, mode="a")

## DATA PROPERTIES for birds...
	>> Number of examples: 351
	>> Number of labels: 654
	>> Label cardinality: 1.863248
	>> Label density: 0.002849
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.054131
	>> Number of tail labels of size 1: 0
	>> Number of dominant labels of size 2: 19
## SELECTED (training set) DATA PROPERTIES for birds...
	>> Number of examples: 279
	>> Number of labels: 521
	>> Label cardinality: 1.867384
	>> Label density: 0.003584
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.068100
	>> KL difference between two full and selected examples labels distributions: 0.001513
## SELECTED (test set) DATA PROPERTIES for birds...
	>> Number of examples: 72
	>> Number of labels: 133
	>> Label cardinality: 1.847222
	>> Label density: 0.013889
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.263889
	>> KL difference between two full and selected examples labels distributions: 0.024220


## Extreme strategy

This strategy uses both the features *X* and the labels *y*. It is controlled by three hyper-parameters (`swap_probability`, `threshold_proportion`, and `decay`) and loops `num_epochs` times over the dataset; the output reports a starting score and the final score after splitting.
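Whatever strategy is used, one simple way to see how closely a split respects `split_size` label by label is to compute, for every label, the fraction of its examples that landed in the training set. The sketch below does this on toy label sets; `per_label_train_fraction` and `max_deviation` are hypothetical helpers, and neither corresponds to the score printed by `ExtremeStratification`.

```python
from collections import Counter

def per_label_train_fraction(label_sets, training_idx):
    """Fraction of each label's examples that were placed in the training set."""
    total = Counter(l for labels in label_sets for l in labels)
    in_train = Counter(l for i in training_idx for l in label_sets[i])
    return {l: in_train[l] / total[l] for l in sorted(total)}

def max_deviation(fractions, split_size):
    """Largest absolute gap between a label's training fraction and split_size."""
    return max(abs(f - split_size) for f in fractions.values())

label_sets = [{0}, {0, 1}, {1}, {2}, {0, 2}, {1, 2}, {0}, {1}, {2}, {0, 1, 2}]
training_idx = [0, 1, 2, 3, 4, 5, 6, 7]
fractions = per_label_train_fraction(label_sets, training_idx)
print(fractions)  # → {0: 0.8, 1: 0.8, 2: 0.6}
print(round(max_deviation(fractions, split_size=0.8), 2))  # → 0.2
```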

In [9]:
st = ExtremeStratification(swap_probability=0.1, threshold_proportion=0.1, decay=0.1,
                           shuffle=True, split_size=split_size, num_epochs=num_epochs)
training_idx, test_idx = st.fit(X=X, y=y)

## Configuration parameters to stratifying a large scale multi-label
   dataset splitting:
		1. A hyper-parameter for extreme stratification: 0.1
		2. A hyper-parameter for extreme stratification: 0.1
		3. A hyper-parameter for extreme stratification: 0.1
		4. Shuffle the dataset? True
		5. Split size: 0.8
		6. Number of loops over a dataset: 5


	>> Perform splitting (extreme)...
		--> Starting score: 22
		--> Splitting progress: 100.00%; score: 2.84


In [10]:
model_name = "extreme2split"
data_properties(y=y.toarray(), selected_examples=training_idx, num_tails=1,
                display_full_properties=True, dataset_name=dsname,
                model_name=model_name, split_set_name="training",
                rspath=RESULT_PATH)
data_properties(y=y.toarray(), selected_examples=test_idx, num_tails=1,
                display_full_properties=False, dataset_name=dsname,
                model_name=model_name, split_set_name="test",
                rspath=RESULT_PATH, mode="a")

## DATA PROPERTIES for birds...
	>> Number of examples: 351
	>> Number of labels: 654
	>> Label cardinality: 1.863248
	>> Label density: 0.002849
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.054131
	>> Number of tail labels of size 1: 0
	>> Number of dominant labels of size 2: 19
## SELECTED (training set) DATA PROPERTIES for birds...
	>> Number of examples: 274
	>> Number of labels: 513
	>> Label cardinality: 1.872263
	>> Label density: 0.003650
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.069343
	>> KL difference between two full and selected examples labels distributions: 0.002273
## SELECTED (test set) DATA PROPERTIES for birds...
	>> Number of examples: 77
	>> Number of labels: 141
	>> Label cardinality: 1.831169
	>> Label density: 0.012987
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.246753
	>> KL difference between two full and selected examples labels distributions: 0.030994


## Clustering based strategy

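This strategy computes the covariance of *X* and *y* using PLSSVD, projects the examples onto the resulting low-dimensional orthonormal basis *U*, clusters them, and then performs the extreme splitting step (see the progress messages below). As in the other examples, `num_epochs` is set to 5 (In [3]) to keep the run time short; more epochs may yield a better-balanced split.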
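PLSSVD starts from the cross-covariance between the centred features and the centred labels and takes its singular value decomposition; scikit-learn exposes this as `sklearn.cross_decomposition.PLSSVD`. The dependency-free sketch below computes only that first step, the cross-covariance matrix, on toy data; how straSplit centres or scales its inputs is an assumption here.

```python
def cross_covariance(X_rows, y_rows):
    """Return the (num_features x num_labels) matrix Xc^T Yc / (n - 1) of centred data."""
    n = len(X_rows)
    x_means = [sum(col) / n for col in zip(*X_rows)]
    y_means = [sum(col) / n for col in zip(*y_rows)]
    Xc = [[v - m for v, m in zip(row, x_means)] for row in X_rows]
    Yc = [[v - m for v, m in zip(row, y_means)] for row in y_rows]
    return [
        [sum(Xc[i][a] * Yc[i][b] for i in range(n)) / (n - 1) for b in range(len(y_means))]
        for a in range(len(x_means))
    ]

X_rows = [[1.0, 0.0], [2.0, 1.0], [3.0, 0.0], [4.0, 1.0]]
y_rows = [[0, 1], [0, 1], [1, 0], [1, 0]]
C = cross_covariance(X_rows, y_rows)
print(C)
```

Here the first feature grows with label 0 and shrinks with label 1, so its row of `C` is positive then negative, while the second feature is unrelated to the labels and its row is zero. The SVD of `C` would then give the basis onto which the examples are projected.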

In [11]:
st = ClusterStratification(num_clusters=5, swap_probability=0.1, threshold_proportion=0.1,
                           decay=0.1, shuffle=True, split_size=split_size, batch_size=100,
                           num_epochs=num_epochs, lr=0.0001, num_jobs=num_jobs)
training_idx, test_idx = st.fit(X=X, y=y)

## Configuration parameters to stratifying a multi-label dataset splitting
   based on clustering the covariance of X and y using PLSSVD:
		1. Number of clusters to form: 5
		2. A hyper-parameter: 0.1
		3. A hyper-parameter: 0.1
		4. A hyper-parameter: 0.1
		5. Shuffle the dataset? True
		6. Split size: 0.8
		7. Number of examples to use in each iteration: 100
		8. Number of loops over training set: 5
		9. Learning rate: 0.0001
		10. Number of parallel workers: 2


	>> Computing the covariance of X and y using PLSSVD: 100.00%...
	>> Projecting examples onto the obtained low dimensional U orthonormal basis...
	>> Clustering the resulted low dimensional examples...
	>> Perform splitting (extreme)...
		--> Starting score: -10
		--> Splitting progress: 100.00%; score: -4.47


In [12]:
model_name = "plssvd2split"
data_properties(y=y.toarray(), selected_examples=training_idx, num_tails=1,
                display_full_properties=True, dataset_name=dsname,
                model_name=model_name, split_set_name="training",
                rspath=RESULT_PATH)
data_properties(y=y.toarray(), selected_examples=test_idx, num_tails=1,
                display_full_properties=False, dataset_name=dsname,
                model_name=model_name, split_set_name="test",
                rspath=RESULT_PATH, mode="a")

## DATA PROPERTIES for birds...
	>> Number of examples: 351
	>> Number of labels: 654
	>> Label cardinality: 1.863248
	>> Label density: 0.002849
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.054131
	>> Number of tail labels of size 1: 0
	>> Number of dominant labels of size 2: 19
## SELECTED (training set) DATA PROPERTIES for birds...
	>> Number of examples: 279
	>> Number of labels: 516
	>> Label cardinality: 1.849462
	>> Label density: 0.003584
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.068100
	>> KL difference between two full and selected examples labels distributions: 0.002416
## SELECTED (test set) DATA PROPERTIES for birds...
	>> Number of examples: 72
	>> Number of labels: 138
	>> Label cardinality: 1.916667
	>> Label density: 0.013889
	>> Distinct label sets: 18
	>> Proportion of distinct label sets: 0.250000
	>> KL difference between two full and selected examples labels distributions: 0.044002


## Clustering eigenvalues based strategy

This strategy extracts clusters based on the eigenvalues of the label adjacency matrix and then performs the extreme splitting step. The `sigma` argument is reported in the output as the constant that scales the amount of Laplacian norm regularization.
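A label adjacency matrix is commonly built from label co-occurrence: entry (a, b) counts the examples carrying both labels a and b, i.e. the off-diagonal part of `Y^T Y` for a binary label matrix `Y`. The sketch below builds it from toy label sets together with the unnormalised graph Laplacian `L = D - A`; whether straSplit uses raw counts, a normalised matrix, or this Laplacian is an assumption not confirmed by this notebook.

```python
def label_adjacency(label_sets, num_labels):
    """Co-occurrence counts: A[a][b] = number of examples with both labels a and b."""
    A = [[0] * num_labels for _ in range(num_labels)]
    for labels in label_sets:
        for a in labels:
            for b in labels:
                if a != b:  # no self-loops on the diagonal
                    A[a][b] += 1
    return A

def laplacian(A):
    """Unnormalised graph Laplacian L = D - A, with D the diagonal degree matrix."""
    n = len(A)
    return [[(sum(A[i]) if i == j else 0) - A[i][j] for j in range(n)] for i in range(n)]

label_sets = [{0, 1}, {0, 1}, {1, 2}, {3}]
A = label_adjacency(label_sets, num_labels=4)
L = laplacian(A)
print(A)
print(L)
```

The eigen-decomposition itself would typically be delegated to a linear-algebra routine such as `numpy.linalg.eigh`, which handles symmetric matrices like `L`.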

In [13]:
st = ClusteringEigenStratification(num_subsamples=10000, num_clusters=5, sigma=2, swap_probability=0.1,
                                   threshold_proportion=0.1, decay=0.1, shuffle=True, split_size=split_size,
                                   batch_size=500, num_epochs=num_epochs, num_jobs=num_jobs)
training_idx, test_idx = st.fit(X=X, y=y)

## Configuration parameters to stratifying a multi-label dataset splitting
   based on clustering eigen values of the label adjacency matrix:
		1. Subsampling input size: 10000
		2. Number of communities: 5
		3. Constant that scales the amount of laplacian norm regularization: 2
		4. A hyper-parameter: 0.1
		5. A hyper-parameter: 0.1
		6. A hyper-parameter: 0.1
		7. Shuffle the dataset? True
		8. Split size: 0.8
		9. Number of examples to use in each iteration: 500
		10. Number of loops over training set: 5
		11. Number of parallel workers: 2


	>> Extracting clusters...
	>> Perform splitting (extreme)...
		--> Starting score: 3
		--> Splitting progress: 100.00%; score: -2.47


In [14]:
model_name = "eigencluster2split"
data_properties(y=y.toarray(), selected_examples=training_idx, num_tails=1,
                display_full_properties=True, dataset_name=dsname,
                model_name=model_name, split_set_name="training",
                rspath=RESULT_PATH)
data_properties(y=y.toarray(), selected_examples=test_idx, num_tails=1,
                display_full_properties=False, dataset_name=dsname,
                model_name=model_name, split_set_name="test",
                rspath=RESULT_PATH, mode="a")

## DATA PROPERTIES for birds...
	>> Number of examples: 351
	>> Number of labels: 654
	>> Label cardinality: 1.863248
	>> Label density: 0.002849
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.054131
	>> Number of tail labels of size 1: 0
	>> Number of dominant labels of size 2: 19
## SELECTED (training set) DATA PROPERTIES for birds...
	>> Number of examples: 274
	>> Number of labels: 509
	>> Label cardinality: 1.857664
	>> Label density: 0.003650
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.069343
	>> KL difference between two full and selected examples labels distributions: 0.002967
## SELECTED (test set) DATA PROPERTIES for birds...
	>> Number of examples: 77
	>> Number of labels: 145
	>> Label cardinality: 1.883117
	>> Label density: 0.012987
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.246753
	>> KL difference between two full and selected examples labels distributions: 0.038943


## Community based splitting strategy

This strategy builds a graph and splits the dataset based on a community detection approach, followed by the extreme splitting step (see the `>> Building Graph...` message in the output below). In addition to the shared hyper-parameters, it takes the subsampling size `num_subsamples`, the number of communities `num_communities`, and `sigma`.
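Community detection groups the nodes of a graph into densely connected clusters. As a deliberately crude stand-in, the sketch below builds a label co-occurrence graph and treats its connected components as communities; real community detection methods (for example modularity-based ones) are finer-grained, and the graph straSplit builds may differ. Both helpers are hypothetical.

```python
def label_graph(label_sets):
    """Undirected graph (adjacency sets) linking labels that co-occur in an example."""
    graph = {}
    for labels in label_sets:
        for a in labels:
            neighbours = graph.setdefault(a, set())  # isolated labels still become nodes
            neighbours.update(b for b in labels if b != a)
    return graph

def connected_components(graph):
    """Group nodes reachable from one another (a stand-in for community detection)."""
    seen, components = set(), []
    for start in sorted(graph):
        if start in seen:
            continue
        stack, component = [start], []
        seen.add(start)
        while stack:
            node = stack.pop()
            component.append(node)
            for nb in graph[node]:
                if nb not in seen:
                    seen.add(nb)
                    stack.append(nb)
        components.append(sorted(component))
    return components

label_sets = [{0, 1}, {1, 2}, {3, 4}, {5}]
print(connected_components(label_graph(label_sets)))  # → [[0, 1, 2], [3, 4], [5]]
```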

In [15]:
st = CommunityStratification(num_subsamples=20000, num_communities=5, sigma=2, swap_probability=0.1,
                             threshold_proportion=0.1, decay=0.1, shuffle=True, split_size=split_size,
                             batch_size=500, num_epochs=num_epochs, num_jobs=num_jobs)
training_idx, test_idx = st.fit(X=X, y=y)

## Configuration parameters to stratifying a multi-label dataset splitting
   based on community detection approach:
		1. Subsampling input size: 20000
		2. Number of communities: 5
		3. Constant that scales the amount of laplacian norm regularization: 2
		4. A hyper-parameter: 0.1
		5. A hyper-parameter: 0.1
		6. A hyper-parameter: 0.1
		7. Shuffle the dataset? True
		8. Split size: 0.8
		9. Number of examples to use in each iteration: 500
		10. Number of loops over training set: 5
		11. Number of parallel workers: 2


	>> Building Graph...
	>> Perform splitting (extreme)...
		--> Starting score: 17
		--> Splitting progress: 100.00%; score: 14.16


In [16]:
model_name = "comm2split"
data_properties(y=y.toarray(), selected_examples=training_idx, num_tails=1,
                display_full_properties=True, dataset_name=dsname,
                model_name=model_name, split_set_name="training",
                rspath=RESULT_PATH)
data_properties(y=y.toarray(), selected_examples=test_idx, num_tails=1,
                display_full_properties=False, dataset_name=dsname,
                model_name=model_name, split_set_name="test",
                rspath=RESULT_PATH, mode="a")

## DATA PROPERTIES for birds...
	>> Number of examples: 351
	>> Number of labels: 654
	>> Label cardinality: 1.863248
	>> Label density: 0.002849
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.054131
	>> Number of tail labels of size 1: 0
	>> Number of dominant labels of size 2: 19
## SELECTED (training set) DATA PROPERTIES for birds...
	>> Number of examples: 286
	>> Number of labels: 533
	>> Label cardinality: 1.863636
	>> Label density: 0.003497
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.066434
	>> KL difference between two full and selected examples labels distributions: 0.003200
## SELECTED (test set) DATA PROPERTIES for birds...
	>> Number of examples: 65
	>> Number of labels: 121
	>> Label cardinality: 1.861538
	>> Label density: 0.015385
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.292308
	>> KL difference between two full and selected examples labels distributions: 0.079500


## Label enhancement based strategy

This strategy builds a graph and splits the dataset based on a label enhancement approach. Its `alpha` argument balances the information inherited from label propagation against the original label matrix, and the extreme splitting step produces the final split.
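The balance that `alpha` controls resembles the classic label propagation update of Zhou et al., `F <- alpha * S F + (1 - alpha) * Y`, where `S` is a normalised similarity matrix and `Y` holds the original labels. The sketch below iterates that update on a tiny hand-made graph; the row normalisation, the number of iterations, and which term `alpha` weights are assumptions, not straSplit's confirmed formulation.

```python
def row_normalise(W):
    """Divide each row of a non-negative similarity matrix by its sum."""
    normalised = []
    for row in W:
        total = sum(row)
        normalised.append([w / total if total else 0.0 for w in row])
    return normalised

def matmul(A, B):
    """Plain list-of-lists matrix product."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def propagate_labels(W, Y, alpha=0.2, num_iters=50):
    """Iterate F <- alpha * S @ F + (1 - alpha) * Y, with S the row-normalised W."""
    S = row_normalise(W)
    F = [row[:] for row in Y]
    for _ in range(num_iters):
        SF = matmul(S, F)
        F = [[alpha * sf + (1 - alpha) * y for sf, y in zip(sf_row, y_row)]
             for sf_row, y_row in zip(SF, Y)]
    return F

# Three examples in a chain (0 - 1 - 2); example 1 has no labels of its own.
W = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
Y = [[1, 0], [0, 0], [0, 1]]
F = propagate_labels(W, Y, alpha=0.2)
print([[round(v, 4) for v in row] for row in F])
```

Example 1 ends up with small, equal scores for both labels inherited from its neighbours, while examples 0 and 2 keep most of their original label.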

In [17]:
st = LabelEnhancementStratification(num_subsamples=10000, num_communities=10, sigma=2, alpha=0.2,
                                    swap_probability=0.1, threshold_proportion=0.1, decay=0.1, shuffle=True,
                                    split_size=split_size, batch_size=500, num_epochs=num_epochs,
                                    num_jobs=num_jobs)
training_idx, test_idx = st.fit(X=X, y=y)

## Configuration parameters to stratifying a multi-label dataset splitting
   based on label enhancement approach:
		1. Subsampling input size: 10000
		2. Number of communities: 10
		3. Constant that scales the amount of laplacian norm regularization: 2
		4. A hyperparameter to balancing parameterwhich controls the fraction of the information inherited from the label propagation and the label matrix.: 0.2
		5. A hyper-parameter: 0.1
		6. A hyper-parameter: 0.1
		7. A hyper-parameter: 0.1
		8. Shuffle the dataset? True
		9. Split size: 0.8
		10. Number of examples to use in each iteration: 500
		11. Number of loops over training set: 5
		12. Number of parallel workers: 2


	>> Building Graph...
	>> Perform splitting (extreme)...
		--> Starting score: 33
		--> Splitting progress: 100.00%; score: -1.66


In [18]:
model_name = "enhance2split"
data_properties(y=y.toarray(), selected_examples=training_idx, num_tails=1,
                display_full_properties=True, dataset_name=dsname,
                model_name=model_name, split_set_name="training",
                rspath=RESULT_PATH)
data_properties(y=y.toarray(), selected_examples=test_idx, num_tails=1,
                display_full_properties=False, dataset_name=dsname,
                model_name=model_name, split_set_name="test",
                rspath=RESULT_PATH, mode="a")

## DATA PROPERTIES for birds...
	>> Number of examples: 351
	>> Number of labels: 654
	>> Label cardinality: 1.863248
	>> Label density: 0.002849
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.054131
	>> Number of tail labels of size 1: 0
	>> Number of dominant labels of size 2: 19
## SELECTED (training set) DATA PROPERTIES for birds...
	>> Number of examples: 273
	>> Number of labels: 501
	>> Label cardinality: 1.835165
	>> Label density: 0.003663
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.069597
	>> KL difference between two full and selected examples labels distributions: 0.004671
## SELECTED (test set) DATA PROPERTIES for birds...
	>> Number of examples: 78
	>> Number of labels: 153
	>> Label cardinality: 1.961538
	>> Label density: 0.012821
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.243590
	>> KL difference between two full and selected examples labels distributions: 0.048639


## Active learning based splitting strategy

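This strategy estimates the predictive uncertainty of each example and groups examples with high informativeness into the training set. In this example, uncertainty is measured with entropy (`acquisition_type="entropy"`), and the underlying model uses a hinge loss with an elastic-net penalty.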
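With entropy as the acquisition function, an example is more informative when the model is less certain about its labels. Assuming per-label probabilities are available, one common score sums the binary entropy of every label; the sketch below computes that score and returns the indices of the `k` most uncertain examples. It is illustrative only and does not reproduce the package's cost or selection rule.

```python
import math

def binary_entropy(p, eps=1e-12):
    """Entropy (in nats) of a Bernoulli variable with success probability p."""
    p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))

def most_uncertain(prob_rows, k):
    """Indices of the k examples whose summed per-label entropy is largest."""
    scores = [sum(binary_entropy(p) for p in row) for row in prob_rows]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

probs = [[0.95, 0.02], [0.5, 0.5], [0.7, 0.4], [0.01, 0.99]]
print(most_uncertain(probs, k=2))  # → [1, 2]
```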

In [19]:
st = ActiveStratification(subsample_labels_size=10, acquisition_type="entropy", top_k=5, calc_ads=False,
                          ads_percent=0.7, use_solver=use_solver, loss_function="hinge", swap_probability=0.1,
                          threshold_proportion=0.1, decay=0.1, penalty='elasticnet', alpha_elastic=0.0001,
                          l1_ratio=0.65, alpha_l21=0.01, loss_threshold=0.05, shuffle=True,
                          split_size=split_size, batch_size=500, num_epochs=num_epochs, lr=1e-3,
                          display_interval=1, num_jobs=num_jobs)
training_idx, test_idx = st.fit(X=X, y=y)

## Configuration parameters to estimating examples predictive uncertainty
   scores to group example with high informativeness into training set
   using a modified approach to splitting an extreme large scale multi-
   label dataset:
		1. Subsampling labels: 10
		2. The acquisition function for estimating the predictive uncertainty: entropy
		3. Apply sklearn optimizers? False
		4. The loss function: hinge
		5. A hyper-parameter for extreme stratification: 0.1
		6. A hyper-parameter for extreme stratification: 0.1
		7. A hyper-parameter for extreme stratification: 0.1
		8. The penalty (aka regularization term): elasticnet
		9. Constant controlling the elastic term: 0.0001
		10. The elastic net mixing parameter: 0.65
		11. A cutoff threshold between two consecutive rounds: 0.05
		12. Shuffle the dataset? True
		13. Split size: 0.8
		14. Number of examples to use in each iteration: 500
		15. Number of loops over training set: 5
		16. Learning rate: 0.001
		17. How often to evaluate? 1
	

	>> Training to learn a model...
	   1)- Epoch count (1/5)...
  		<<<------------<<<------------<<<
  		>> Feed-Backward...
			--> Optimizing Theta: 100.00%...
  		>>>------------>>>------------>>>
  		>> Feed-Forward...
  		>> Predictive uncertainty using entropy...
  		>> Compute cost...
			--> New cost: 0.7251; Old cost: inf-> Calculating cost: 94.74%...
			--> Epoch 1 took 0.046 seconds...
	   2)- Epoch count (2/5)...
  		<<<------------<<<------------<<<
  		>> Feed-Backward...
			--> Optimizing Theta: 100.00%...
  		>>>------------>>>------------>>>
  		>> Feed-Forward...
  		>> Predictive uncertainty using entropy...
  		>> Compute cost...
			--> New cost: 0.7650; Old cost: 0.7251Calculating cost: 78.95%...
			--> Epoch 2 took 0.058 seconds...
	   3)- Epoch count (3/5)...
  		<<<------------<<<------------<<<
  		>> Feed-Backward...
			--> Optimizing Theta: 100.00%...
  		>>>------------>>>------------>>>
  		>> Feed-Forward...
  		>> Predictive uncertainty using entropy...
  		

In [20]:
model_name = "active2split"
data_properties(y=y.toarray(), selected_examples=training_idx, num_tails=1,
                display_full_properties=True, dataset_name=dsname,
                model_name=model_name, split_set_name="training",
                rspath=RESULT_PATH)
data_properties(y=y.toarray(), selected_examples=test_idx, num_tails=1,
                display_full_properties=False, dataset_name=dsname,
                model_name=model_name, split_set_name="test",
                rspath=RESULT_PATH, mode="a")

## DATA PROPERTIES for birds...
	>> Number of examples: 351
	>> Number of labels: 654
	>> Label cardinality: 1.863248
	>> Label density: 0.002849
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.054131
	>> Number of tail labels of size 1: 0
	>> Number of dominant labels of size 2: 19
## SELECTED (training set) DATA PROPERTIES for birds...
	>> Number of examples: 277
	>> Number of labels: 505
	>> Label cardinality: 1.823105
	>> Label density: 0.003610
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.068592
	>> KL difference between two full and selected examples labels distributions: 0.001761
## SELECTED (test set) DATA PROPERTIES for birds...
	>> Number of examples: 74
	>> Number of labels: 149
	>> Label cardinality: 2.013514
	>> Label density: 0.013514
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.256757
	>> KL difference between two full and selected examples labels distributions: 0.020645


## GAN learning based splitting strategy

This strategy clusters embeddings obtained from a GAN2Embed model: it builds a graph and BFS-trees, trains the GAN model, extracts clusters from the learned embeddings, and then performs the extreme splitting step (see the progress messages below).
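Once embeddings are available, the clustering step groups nearby vectors. The sketch below uses plain k-means (Lloyd's algorithm) on a few toy 2-D points; k-means is a stand-in chosen for illustration, since the notebook does not state which clustering algorithm straSplit applies to the GAN embeddings.

```python
import random

def kmeans(points, k, num_iters=20, seed=0):
    """Lloyd's algorithm: alternate nearest-centroid assignment and centroid updates."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]
    assignment = [0] * len(points)
    for _ in range(num_iters):
        # Assign each point to its nearest centroid (squared Euclidean distance).
        assignment = [
            min(range(k), key=lambda c: sum((a - b) ** 2 for a, b in zip(p, centroids[c])))
            for p in points
        ]
        # Move each centroid to the mean of its points; keep it if the cluster is empty.
        for c in range(k):
            members = [p for p, a in zip(points, assignment) if a == c]
            if members:
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return assignment, centroids

embeddings = [(0.0, 0.1), (0.2, 0.0), (0.1, 0.2), (5.0, 5.1), (5.2, 4.9), (4.9, 5.0)]
labels, centres = kmeans(embeddings, k=2)
print(labels)
```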

In [21]:
st = GANStratification(dimension_size=50, num_examples2gen=20, update_ratio=1, window_size=2,
                       num_subsamples=10000, num_clusters=5, sigma=2, swap_probability=0.1,
                       threshold_proportion=0.1, decay=0.1, shuffle=True, split_size=split_size,
                       batch_size=1000, max_iter_gen=num_epochs, max_iter_dis=num_epochs, 
                       num_epochs=num_epochs, lambda_gen=1e-5, lambda_dis=1e-5, lr=1e-3, 
                       display_interval=2, num_jobs=num_jobs)
training_idx, test_idx = st.fit(X=X, y=y)

## Configuration parameters to stratifying a multi-label dataset splitting
   based on clustering embeddings obtained from GAN2Embed model:
		1. The dimension size of embeddings: 50
		2. The number of samples for the generator.: 20
		3. Subsampling input size: 10000
		4. Number of communities: 5
		5. Constant that scales the amount of laplacian norm regularization: 2
		6. Updating ratio when choose the trees: 2
		7. Window size to skip.: 2
		8. A hyper-parameter: 0.1
		9. A hyper-parameter: 0.1
		10. A hyper-parameter: 0.1
		11. Shuffle the dataset? True
		12. Split size: 0.8
		13. Number of examples to use in each iteration: 1000
		14. The number of inner loops for the generator: 5
		15. The number of inner loops for the discriminator: 5
		16. Number of loops over training set: 5
		17. The l2 loss regulation weight for the generator: 1e-05
		18. The l2 loss regulation weight for the discriminator: 1e-05
		19. Learning rate: 0.001
		20. Sample new nodes for the discriminator for every 

	>> Building Graph...
	>> Building BFS-trees...
	>> Building GAN model...





	>> Training GAN model...
	>> Extracting clusters...00.00%...
	>> Perform splitting (extreme)...
		--> Starting score: 66
		--> Splitting progress: 100.00%; score: 9.86


In [22]:
model_name = "gan2split"
data_properties(y=y.toarray(), selected_examples=training_idx, num_tails=1,
                display_full_properties=True, dataset_name=dsname,
                model_name=model_name, split_set_name="training",
                rspath=RESULT_PATH)
data_properties(y=y.toarray(), selected_examples=test_idx, num_tails=1,
                display_full_properties=False, dataset_name=dsname,
                model_name=model_name, split_set_name="test",
                rspath=RESULT_PATH, mode="a")

## DATA PROPERTIES for birds...
	>> Number of examples: 351
	>> Number of labels: 654
	>> Label cardinality: 1.863248
	>> Label density: 0.002849
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.054131
	>> Number of tail labels of size 1: 0
	>> Number of dominant labels of size 2: 19
## SELECTED (training set) DATA PROPERTIES for birds...
	>> Number of examples: 285
	>> Number of labels: 533
	>> Label cardinality: 1.870175
	>> Label density: 0.003509
	>> Distinct label sets: 19
	>> Proportion of distinct label sets: 0.066667
	>> KL difference between two full and selected examples labels distributions: 0.002121
## SELECTED (test set) DATA PROPERTIES for birds...
	>> Number of examples: 66
	>> Number of labels: 121
	>> Label cardinality: 1.833333
	>> Label density: 0.015152
	>> Distinct label sets: 18
	>> Proportion of distinct label sets: 0.272727
	>> KL difference between two full and selected examples labels distributions: 0.043681
