# Stratified Splitting

This notebook walks through each algorithm in the **straSplit** package for
splitting a multi-label dataset using the less explored
[stratified strategy](https://bit.ly/3s3IDA8). Before running the cells, install
[Anaconda](https://www.anaconda.com/) and the modules listed
in [requirements.txt](../../requirements.txt).

## Naive splitting strategy

The naive strategy neither addresses the class-imbalance problem nor takes
label correlations into account when splitting a dataset. Run the following
cell to apply it:

In [None]:
import os
import pickle as pkl
from scipy.sparse import lil_matrix
from src.utility.file_path import DATASET_PATH
from src.model.naive2split import NaiveStratification

y_name = "Ybirds_train.pkl"

file_path = os.path.join(DATASET_PATH, y_name)
with open(file_path, mode="rb") as f_in:
    y = pkl.load(f_in)
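    # Drop examples without any label and labels without any example.
    y = lil_matrix(y[y.getnnz(axis=1) != 0][:, y.getnnz(axis=0) != 0].toarray())

st = NaiveStratification(shuffle=True, split_size=0.8, batch_size=1000, num_jobs=10)
# First split: separate the test set from the remaining examples.
training_idx, test_idx = st.fit(y=y)
# Second split: carve a validation set out of the remaining examples.
training_idx, dev_idx = st.fit(y=y[training_idx])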

In [None]:
print("\t>> Training set size: {0}".format(len(training_idx)))
print("\t>> Validation set size: {0}".format(len(dev_idx)))
print("\t>> Test set size: {0}".format(len(test_idx)))
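
To see what ignoring class imbalance can look like in practice, one option is to compare per-label frequencies between the resulting splits. The sketch below does this for a plain random split of a synthetic label matrix. It is only an illustration: `label_frequencies`, `y_demo`, and the random split are hypothetical stand-ins and are not part of straSplit.

```python
import numpy as np
from scipy.sparse import csr_matrix


def label_frequencies(y, idx):
    """Return the fraction of examples in rows `idx` that carry each label."""
    subset = y[idx]
    return np.asarray(subset.sum(axis=0)).ravel() / subset.shape[0]


# Synthetic, imbalanced binary label matrix (illustrative data only).
rng = np.random.default_rng(0)
label_prior = np.array([0.5, 0.2, 0.05, 0.01])
y_demo = csr_matrix((rng.random((1000, 4)) < label_prior).astype(int))

# Plain random 80/20 split that ignores label frequencies.
perm = rng.permutation(y_demo.shape[0])
train_demo, test_demo = perm[:800], perm[800:]

freq_train = label_frequencies(y_demo, train_demo)
freq_test = label_frequencies(y_demo, test_demo)
for j, (p_tr, p_te) in enumerate(zip(freq_train, freq_test)):
    print("label {0}: train={1:.3f} test={2:.3f}".format(j, p_tr, p_te))
```

The same helper can be applied to the index lists returned by any strategy in this notebook, provided those indices refer to rows of the matrix passed in. Large gaps between the two columns for rare labels suggest that the split did not preserve their proportions.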

## Extreme splitting strategy

In [None]:
import os
import pickle as pkl
from scipy.sparse import lil_matrix
from src.utility.file_path import DATASET_PATH
from src.model.extreme2split import ExtremeStratification

X_name = "Xbirds_train.pkl"
y_name = "Ybirds_train.pkl"

file_path = os.path.join(DATASET_PATH, y_name)
with open(file_path, mode="rb") as f_in:
    y = pkl.load(f_in)
    # Keep only the examples that have at least one label.
    idx = list(set(y.nonzero()[0]))
    y = y[idx]

file_path = os.path.join(DATASET_PATH, X_name)
with open(file_path, mode="rb") as f_in:
    X = pkl.load(f_in)
    X = X[idx]

st = ExtremeStratification(swap_probability=0.1, threshold_proportion=0.1, decay=0.1,
                            shuffle=True, split_size=0.75, num_epochs=50)
training_idx, test_idx = st.fit(X=X, y=y)
training_idx, dev_idx = st.fit(X=X[training_idx], y=y[training_idx])

In [None]:
print("\t>> Training set size: {0}".format(len(training_idx)))
print("\t>> Validation set size: {0}".format(len(dev_idx)))
print("\t>> Test set size: {0}".format(len(test_idx)))
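
The cell above calls `fit` a second time on the rows selected by the first `training_idx`. If `fit` returns positions within the array it was given (an assumption about straSplit that this notebook does not confirm), the second pair of index lists refers to the subset rather than to the full dataset. A minimal sketch of mapping them back, using made-up index values and a hypothetical helper `to_original_indices`:

```python
import numpy as np


def to_original_indices(parent_idx, relative_idx):
    """Map positions within a subset back to row indices of the full dataset."""
    return np.asarray(parent_idx)[np.asarray(relative_idx)]


# Hypothetical result of a first split over 10 examples.
train_dev_idx = np.array([0, 2, 3, 5, 6, 8, 9])
test_idx_demo = np.array([1, 4, 7])

# Hypothetical result of a second split run on the 7-row subset: these are
# positions 0..6 within train_dev_idx, not rows of the full dataset.
rel_train = np.array([0, 1, 3, 4, 6])
rel_dev = np.array([2, 5])

train_idx_demo = to_original_indices(train_dev_idx, rel_train)
dev_idx_demo = to_original_indices(train_dev_idx, rel_dev)
print(train_idx_demo, dev_idx_demo, test_idx_demo)
```

After the mapping, the three index arrays are disjoint and together cover every row of the full dataset, so they can be used to slice the original `X` and `y` directly.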

## Community based splitting strategy

In [None]:
import os
import pickle as pkl
from scipy.sparse import lil_matrix
from src.utility.file_path import DATASET_PATH
from src.model.comm2split import CommunityStratification

X_name = "Xbirds_train.pkl"
y_name = "Ybirds_train.pkl"
use_extreme = True

file_path = os.path.join(DATASET_PATH, y_name)
with open(file_path, mode="rb") as f_in:
    y = pkl.load(f_in)
    # Keep only the examples that have at least one label.
    idx = list(set(y.nonzero()[0]))
    y = y[idx]

file_path = os.path.join(DATASET_PATH, X_name)
with open(file_path, mode="rb") as f_in:
    X = pkl.load(f_in)
    X = X[idx]

st = CommunityStratification(num_subsamples=10000, num_communities=5, walk_size=4, sigma=2,
                                swap_probability=0.1, threshold_proportion=0.1, decay=0.1,
                                shuffle=True, split_size=0.75, batch_size=100, num_epochs=50,
                                num_jobs=2)
training_idx, test_idx = st.fit(y=y, X=X, use_extreme=use_extreme)
training_idx, dev_idx = st.fit(y=y[training_idx], X=X[training_idx], use_extreme=use_extreme)

In [None]:
print("\t>> Training set size: {0}".format(len(training_idx)))
print("\t>> Validation set size: {0}".format(len(dev_idx)))
print("\t>> Test set size: {0}".format(len(test_idx)))
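
Because the same `split_size` is applied twice, the three printed sizes do not follow `split_size` directly. Assuming `split_size` is the fraction of examples assigned to the first returned set (an assumption, not something stated in this notebook), the nominal share of each set can be worked out as below. `nested_split_fractions` is a hypothetical helper, and real sizes may differ through rounding or the strategy's own adjustments.

```python
from fractions import Fraction


def nested_split_fractions(split_size):
    """Nominal train/dev/test fractions after applying the same split twice."""
    s = Fraction(str(split_size))
    return s * s, s * (1 - s), 1 - s


for size in (0.8, 0.75):
    train, dev, test = nested_split_fractions(size)
    print("split_size={0}: train={1:.4f} dev={2:.4f} test={3:.4f}".format(
        size, float(train), float(dev), float(test)))
```

Under that assumption, `split_size=0.75` corresponds to 9/16 of the data for training, 3/16 for validation, and 1/4 for testing.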

## Clustering based splitting strategy

In [None]:
import os
import pickle as pkl
from scipy.sparse import lil_matrix
from src.utility.file_path import DATASET_PATH
from src.model.plssvd2split import ClusterStratification

X_name = "Xbirds_train.pkl"
y_name = "Ybirds_train.pkl"
use_extreme = True

file_path = os.path.join(DATASET_PATH, y_name)
with open(file_path, mode="rb") as f_in:
    y = pkl.load(f_in)
    # Keep only the examples that have at least one label.
    idx = list(set(y.nonzero()[0]))
    y = y[idx]

file_path = os.path.join(DATASET_PATH, X_name)
with open(file_path, mode="rb") as f_in:
    X = pkl.load(f_in)
    X = X[idx]
    
st = ClusterStratification(num_clusters=5, swap_probability=0.1, threshold_proportion=0.1,
                           decay=0.1, shuffle=True, split_size=0.75, batch_size=100,
                           num_epochs=5, lr=0.0001, num_jobs=2)
training_idx, test_idx = st.fit(y=y, X=X, use_extreme=use_extreme)
training_idx, dev_idx = st.fit(y=y[training_idx], X=X[training_idx], use_extreme=use_extreme)

In [None]:
print("\t>> Training set size: {0}".format(len(training_idx)))
print("\t>> Validation set size: {0}".format(len(dev_idx)))
print("\t>> Test set size: {0}".format(len(test_idx)))

## Clustering eigenvalues based splitting strategy

In [None]:
import os
import pickle as pkl
from scipy.sparse import lil_matrix
from src.utility.file_path import DATASET_PATH
from src.model.eigencluster2split import ClusteringEigenStratification

X_name = "Xbirds_train.pkl"
y_name = "Ybirds_train.pkl"
use_extreme = True

file_path = os.path.join(DATASET_PATH, y_name)
with open(file_path, mode="rb") as f_in:
    y = pkl.load(f_in)
    # Keep only the examples that have at least one label.
    idx = list(set(y.nonzero()[0]))
    y = y[idx]

file_path = os.path.join(DATASET_PATH, X_name)
with open(file_path, mode="rb") as f_in:
    X = pkl.load(f_in)
    X = X[idx]

st = ClusteringEigenStratification(num_subsamples=10000, num_clusters=5, sigma=2, swap_probability=0.1,
                                   threshold_proportion=0.1, decay=0.1, shuffle=True, split_size=0.75,
                                   batch_size=100, num_epochs=50, num_jobs=2)
training_idx, test_idx = st.fit(y=y, X=X, use_extreme=use_extreme)
training_idx, dev_idx = st.fit(y=y[training_idx], X=X[training_idx], use_extreme=use_extreme)

In [None]:
print("\t>> Training set size: {0}".format(len(training_idx)))
print("\t>> Validation set size: {0}".format(len(dev_idx)))
print("\t>> Test set size: {0}".format(len(test_idx)))

## Label Enhancement based splitting strategy

In [None]:
import os
import pickle as pkl
from scipy.sparse import lil_matrix
from src.utility.file_path import DATASET_PATH
from src.model.enhance2split import LabelEnhancementStratification

X_name = "Xbirds_train.pkl"
y_name = "Ybirds_train.pkl"
use_extreme = True

file_path = os.path.join(DATASET_PATH, y_name)
with open(file_path, mode="rb") as f_in:
    y = pkl.load(f_in)
    # Keep only the examples that have at least one label.
    idx = list(set(y.nonzero()[0]))
    y = y[idx]

file_path = os.path.join(DATASET_PATH, X_name)
with open(file_path, mode="rb") as f_in:
    X = pkl.load(f_in)
    X = X[idx]

st = LabelEnhancementStratification(num_subsamples=10000, num_communities=5, walk_size=4, sigma=2, alpha=0.2,
                                    swap_probability=0.1, threshold_proportion=0.1, decay=0.1, shuffle=True,
                                    split_size=0.75, batch_size=100, num_epochs=50, num_jobs=2)
training_idx, test_idx = st.fit(y=y, X=X, use_extreme=use_extreme)
training_idx, dev_idx = st.fit(y=y[training_idx], X=X[training_idx], use_extreme=use_extreme)

In [None]:
print("\t>> Training set size: {0}".format(len(training_idx)))
print("\t>> Validation set size: {0}".format(len(dev_idx)))
print("\t>> Test set size: {0}".format(len(test_idx)))

## Active learning based splitting strategy

In [None]:
import os
import pickle as pkl
from scipy.sparse import lil_matrix
from src.utility.file_path import DATASET_PATH
from src.model.active2split import ActiveStratification

X_name = "Xbirds_train.pkl"
y_name = "Ybirds_train.pkl"

file_path = os.path.join(DATASET_PATH, y_name)
with open(file_path, mode="rb") as f_in:
    y = pkl.load(f_in)
    # Keep only the examples that have at least one label.
    idx = list(set(y.nonzero()[0]))
    y = y[idx]

file_path = os.path.join(DATASET_PATH, X_name)
with open(file_path, mode="rb") as f_in:
    X = pkl.load(f_in)
    X = X[idx]

st = ActiveStratification(subsample_labels_size=10, acquisition_type="psp", top_k=5, swap_probability=0.1, 
                          threshold_proportion=0.1, decay=0.1, penalty='l21', alpha_elastic=0.0001, 
                          l1_ratio=0.65, alpha_l21=0.01, loss_threshold=0.05, shuffle=True,
                          split_size=0.75, batch_size=100, num_epochs=50, lr=1e-3,
                          display_interval=2, num_jobs=2)
training_idx, test_idx = st.fit(y=y, X=X)
training_idx, dev_idx = st.fit(y=y[training_idx], X=X[training_idx])

In [None]:
print("\t>> Training set size: {0}".format(len(training_idx)))
print("\t>> Validation set size: {0}".format(len(dev_idx)))
print("\t>> Test set size: {0}".format(len(test_idx)))

## GAN learning based splitting strategy

In [None]:
import os
import pickle as pkl
from scipy.sparse import lil_matrix
from src.utility.file_path import DATASET_PATH
from src.model.gan2split import GANStratification

X_name = "Xbirds_train.pkl"
y_name = "Ybirds_train.pkl"
use_extreme = True

file_path = os.path.join(DATASET_PATH, y_name)
with open(file_path, mode="rb") as f_in:
    y = pkl.load(f_in)
    # Keep only the examples that have at least one label.
    idx = list(set(y.nonzero()[0]))
    y = y[idx]

file_path = os.path.join(DATASET_PATH, X_name)
with open(file_path, mode="rb") as f_in:
    X = pkl.load(f_in)
    X = X[idx]

st = GANStratification(dimension_size=50, num_examples2gen=20, update_ratio=1, window_size=2,
                       num_subsamples=10000, num_clusters=5, sigma=2, swap_probability=0.1,
                       threshold_proportion=0.1, decay=0.1, shuffle=True, split_size=0.75,
                       batch_size=100, max_iter_gen=30, max_iter_dis=30, num_epochs=5, lambda_gen=1e-5,
                       lambda_dis=1e-5, lr=1e-3, display_interval=30, num_jobs=2)
training_idx, test_idx = st.fit(y=y, X=X, use_extreme=use_extreme)
training_idx, dev_idx = st.fit(y=y[training_idx], X=X[training_idx], use_extreme=use_extreme)

In [None]:
print("\t>> Training set size: {0}".format(len(training_idx)))
print("\t>> Validation set size: {0}".format(len(dev_idx)))
print("\t>> Test set size: {0}".format(len(test_idx)))