In [56]:
import torch
import pandas as pd
iris_main = pd.read_csv('https://raw.githubusercontent.com/pandas-dev/pandas/master/pandas/tests/data/iris.csv')

In [63]:
iris = iris_main.copy()
iris.dtypes

SepalLength    float64
SepalWidth     float64
PetalLength    float64
PetalWidth     float64
Name            object
dtype: object

In [59]:
from sklearn import preprocessing
iris['SepalLength'].mean(), iris['SepalLength'].std()

(5.843333333333335, 0.8280661279778629)

In [60]:
def mean_std_table(dataframe):
    mean_std_dict = {}
    for each_column, each_dtype in zip(iris.columns, iris.dtypes):
        if each_dtype == 'float64':
            column_mean = iris[each_column].mean()
            column_std = iris[each_column].std()
            mean_std_dict[each_column] = column_mean, column_std
            dataframe[each_column] = (dataframe[each_column]- column_mean)/(column_std**2)
    return mean_std_dict

In [65]:
mean_std_table(iris)

{'SepalLength': (5.843333333333335, 0.8280661279778629),
 'SepalWidth': (3.0540000000000007, 0.4335943113621737),
 'PetalLength': (3.7586666666666693, 1.7644204199522617),
 'PetalWidth': (1.1986666666666672, 0.7631607417008414)}

In [66]:
iris

Unnamed: 0,SepalLength,SepalWidth,PetalLength,PetalWidth,Name
0,-1.084061,2.372290,-0.757639,-1.714701,Iris-setosa
1,-1.375736,-0.287228,-0.757639,-1.714701,Iris-setosa
2,-1.667412,0.776579,-0.789761,-1.714701,Iris-setosa
3,-1.813249,0.244676,-0.725518,-1.714701,Iris-setosa
4,-1.229898,2.904193,-0.757639,-1.714701,Iris-setosa
...,...,...,...,...,...
145,1.249343,-0.287228,0.462978,1.890979,Iris-virginica
146,0.665992,-2.946745,0.398735,1.204183,Iris-virginica
147,0.957668,-0.287228,0.462978,1.375882,Iris-virginica
148,0.520155,1.840386,0.527221,1.890979,Iris-virginica


In [28]:
from torch.utils import data

class DataSpliter(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, dataframe, features, labels):
        'Initialization'
        super(DataSpliter, self).__init__()
        self.dataframe = dataframe
        self.labels = labels
        self.list_IDs = features

    def __len__(self):
        'Denotes the total number of samples'
        return self.dataframe.shape[0]

    def __getitem__(self, index):
        'Generates one sample of data'
        # Load data and get label
        X = torch.FloatTensor(self.dataframe[self.list_IDs].to_numpy())[index]
        y = (torch.LongTensor(self.dataframe[self.labels].to_numpy())).squeeze(1)[index]

        return X, y

In [92]:
def train_test_split(dataframe, shuffle=False, test_decimal=.2):
    assert test_decimal < 1
    assert test_decimal > 0
    if shuffle:
        train_df = dataframe.sample(frac=(1-test_decimal))
    else:
        tt_boarder = int(dataframe.shape[0]*(1-test_decimal))
        train_df = dataframe[:tt_boarder]
    test_df = dataframe.drop(train_df.index)
    return train_df, test_df

In [96]:
iris['Name'].unique()

array(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'], dtype=object)

In [102]:
iris[iris['Name'].unique()[0] == iris['Name']].shape

(50, 5)

In [106]:
type_count = []
for each_type in iris['Name'].unique():
    type_count.append(iris[each_type == iris['Name']].shape)
min_sample = type_count.index(min(type_count))
min_sample

0

In [376]:
def train_test_split_even(dataframe, label_name, shuffle=True, test_decimal=0.2, even_num_labels=True):
    assert test_decimal < 1
    assert test_decimal > 0
    tt_boarder = int(dataframe.shape[0]*(1-test_decimal))
    if even_num_labels:
        type_count = []
        for each_type in dataframe[label_name].unique():
             type_count.append(dataframe[each_type == dataframe[label_name]].shape[0])
        min_sample = min(type_count)
        print(min_sample)
        
        train_df = pd.DataFrame(columns=dataframe.columns)
        print(tt_boarder, min_sample*len(type_count)*(1-test_decimal))
        if tt_boarder >= min_sample*len(type_count)*(1-test_decimal):
            if shuffle:
                for each_type in dataframe[label_name].unique():
                    train_df = train_df.append(dataframe[each_type == dataframe[label_name]].sample(frac=(1-test_decimal)))     
            else:
                for each_type in dataframe[label_name].unique():
                    train_df = train_df.append(dataframe[each_type == dataframe[label_name]][:tt_boarder//len(type_count)])
        else:
            if shuffle:
                for each_type in dataframe[label_name].unique():
                    train_df = train_df.append(dataframe[each_type == dataframe[label_name]].sample(frac=(1-test_decimal)))
            else:
                for each_type in dataframe[label_name].unique():
                    train_df = train_df.append(dataframe[each_type == dataframe[label_name]][:tt_boarder//len(type_count)])
        test_df = dataframe.drop(train_df.index)
    return train_df, test_df

In [379]:
x1, y1 = train_test_split_even(iris, 'Name',shuffle=True, test_decimal=0.2)

50
120 120.0


In [380]:
x1

Unnamed: 0,SepalLength,SepalWidth,PetalLength,PetalWidth,Name
26,-1.229898,1.840386,-0.693396,-1.371303,Iris-setosa
40,-1.229898,2.372290,-0.789761,-1.543002,Iris-setosa
19,-1.084061,3.968000,-0.725518,-1.543002,Iris-setosa
16,-0.646547,4.499904,-0.789761,-1.371303,Iris-setosa
8,-2.104925,-0.819131,-0.757639,-1.714701,Iris-setosa
...,...,...,...,...,...
133,0.665992,-1.351035,0.430856,0.517387,Iris-virginica
148,0.520155,1.840386,0.527221,1.890979,Iris-virginica
141,1.541019,0.244676,0.430856,1.890979,Iris-virginica
107,2.124370,-0.819131,0.816314,1.032484,Iris-virginica


In [381]:
x2, y2 = train_test_split_even(iris, 'Name',shuffle=False, test_decimal=.9)

50
14 14.999999999999996


In [382]:
x2

Unnamed: 0,SepalLength,SepalWidth,PetalLength,PetalWidth,Name
0,-1.084061,2.37229,-0.757639,-1.714701,Iris-setosa
1,-1.375736,-0.287228,-0.757639,-1.714701,Iris-setosa
2,-1.667412,0.776579,-0.789761,-1.714701,Iris-setosa
3,-1.813249,0.244676,-0.725518,-1.714701,Iris-setosa
50,1.686857,0.776579,0.30237,0.345687,Iris-versicolor
51,0.81183,0.776579,0.238127,0.517387,Iris-versicolor
52,1.541019,0.244676,0.366613,0.517387,Iris-versicolor
53,-0.50071,-4.010552,0.07752,0.173988,Iris-versicolor
100,0.665992,1.308483,0.71995,2.234377,Iris-virginica
101,-0.063196,-1.882938,0.430856,1.204183,Iris-virginica
