In [1]:
import pandas as pd
import numpy as np
from collections import defaultdict

from sdv.tabular import CTGAN, GaussianCopula, TVAE
from sdv.evaluation import evaluate

from sklearn.ensemble import ExtraTreesClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report, precision_recall_fscore_support, matthews_corrcoef, f1_score

In [2]:
import torch

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
device

device(type='cuda')

In [5]:
import warnings
warnings.filterwarnings("ignore")

In [6]:
train = pd.read_csv('../data/binned_train.csv')
test = pd.read_csv('../data/binned_test.csv')

x_test = test.drop('G3', axis=1)
y_test = test['G3']

In [7]:
train['G3'].value_counts()

high       313
upper       98
passing     63
fail        12
Name: G3, dtype: int64

In [8]:
copulas = {'fail': GaussianCopula(),
           'passing': GaussianCopula(),
           'high': GaussianCopula(),
           'upper': GaussianCopula()}

In [9]:
class SyntheticTrainingSet():
    def __init__(self, data=None, target_label=None, classes=None, random_state=1, up_sample=0, synthetic_only=False, model='gaussian_copula'):
        self.data = data.copy()
        self.target_label = target_label
        self.classes = classes
        self.random_state = random_state
        self.up_sample = up_sample
        self.synthetic_only = synthetic_only
        self._synthetic_data = dict()
        self.models = dict()
        
    def fit(self):
        """
        fit does not conform to the standard scikit-learn api call.
        Instead, it is called as a stand alone post-initiation to create a set of pandas dataframes that can be used for other learning
        """
        for _class in self.classes:
            self.models[_class] = GaussianCopula()
        self.class_data = {_class:self.data.loc[self.data[self.target_label] == _class] for _class in self.classes}
        
        up_sample = self.data[self.target_label].value_counts()
        self._up_sample = up_sample.max() - up_sample + self.up_sample
        
        self.X = self.data.drop(self.target_label, axis=1)
        self.y = self.data[self.target_label]
        
        self.sample_data()
        
        if self.synthetic_only:
            self.df = self.synthetic_df
        else:
            self.df = pd.concat([self.data, self.synthetic_df])
 
    
    def sample_data(self):
        """
        sample_data is called to construct the internal self.synthetic_df variable for repeated access in the class
        """
        self.synthetic_df = list()
        for _class, _model in self.models.items():
            class_data = self.data[self.data[self.target_label] == _class]
            _model.fit(class_data)
            _sample_size = self._up_sample.loc[_class]
            print(_sample_size)
            if _sample_size > 0:
                synth_data = self.sample_copulas(_model, _sample_size, self.random_state)
                self._synthetic_data[_class] = synth_data
                self.synthetic_df.append(synth_data)
            else:
                self._synthetic_data[_class] = pd.DataFrame(columns = self.data.columns)
                
        self.synthetic_df = pd.concat(self.synthetic_df)

    def resample_data(self, sample_size, balance_data):
        """
        Allows resampling of the models after they've already been fitted and set to the same random state
        sample_size: how many samples to generage
        balance_data: whether or not to create a balanced dataset with the original data added in
        """
        resample_df = list()
        if balance_data:
            resample = self.data[self.target_label].value_counts()
            resample = resample.max() - resample + sample_size
        else:
            resample = pd.Series([sample_size] * len(self.classes), index=self.classes)
            
        for _class, _model in self.models.items():
            resample_size = resample.loc[_class]
            if resample_size > 0:
                synth_data = self.sample_copulas(_model, resample_size, self.random_state)
                resample_df.append(synth_data)
        
        if balance_data:
            resample_df.append(self.data)
        resample_df = pd.concat(resample_df)
        return resample_df
        
    
    def sample_copulas(self, _model, sample_size, random_state, offset_sample=0):
        """
        Sample from a dataset, original or synthetic, to produce a fixed data size
        sample_size: The desired output when combined with any prior data
        random_state: Seed value to set the random state
        offset_sample: The offset provided to adjust to any existing data
        """
        if offset_sample == sample_size:
            return pd.DataFrame(columns=sample.columns)
        _to_sample = sample_size - offset_sample
        print(_to_sample)
        np.random.seed(random_state)
        return _model.sample(_to_sample)
    
    def split_synthetic_into_train(self, _data, target_label):
        """
        _data: data with both X and y in it
        target_label: the label to separate X from y for training
        """
        return self.X, self.y

        
    def evaluate(self):
        """
        evaluate calls the existing evaluate() function from Synthetic Data Vault and returns the verbose results from it rather than an aggregate score
        """
        for _class in self.classes:
            class_data = self.data.loc[ self.data[self.target_label] == _class]
            synthetic_data = self._synthetic_data[_class]
            report = evaluate(synthetic_data, class_data, aggregate=False).round(3)
            print(f'{_class}\n{report}\n---------------\n')

In [10]:
trainer = SyntheticTrainingSet(train, 'G3', ['upper','passing','fail','high'])

In [11]:
trainer.fit()

215
215
250
250
301
301
0


In [12]:
balanced_df = trainer.df

In [13]:
over_sampled = trainer.resample_data(1000, True)

1215
1250
1301
1000


In [15]:
synthetic_only = trainer.resample_data(1313, False)

1313
1313
1313
1313


In [16]:
large_oversampled = trainer.resample_data(4000, True)

4215
4250
4301
4000


In [17]:
large_synthetic = trainer.resample_data(4313, False)

4313
4313
4313
4313


In [18]:
params = {'n_estimators':[50,100,150,200,250,300,350,400],
          'max_depth':[None],
          'random_state':[42],
          'class_weight':['balanced','balanced_subsample', None],
          'max_features':['auto','sqrt','log2']}

In [19]:
original_etc = ExtraTreesClassifier()
balanced_etc = ExtraTreesClassifier()
oversampled_etc = ExtraTreesClassifier()
synthetic_etc = ExtraTreesClassifier()
large_oversampled_etc = ExtraTreesClassifier()
large_synthetic_etc = ExtraTreesClassifier()

In [20]:
original_clf = GridSearchCV(original_etc, params, n_jobs=-1, verbose=4, scoring='f1_weighted')
balanced_clf = GridSearchCV(balanced_etc, params, n_jobs=-1, verbose=4, scoring='f1_weighted')
oversampled_clf = GridSearchCV(oversampled_etc, params, n_jobs=-1, verbose=4, scoring='f1_weighted')
synthetic_clf = GridSearchCV(synthetic_etc, params, n_jobs=-1, verbose=4, scoring='f1_weighted')
large_oversampled_clf = GridSearchCV(large_oversampled_etc, params, n_jobs=-1, verbose=4, scoring='f1_weighted')
large_synthetic_clf = GridSearchCV(large_synthetic_etc, params, n_jobs=-1, verbose=4, scoring='f1_weighted')

In [21]:
original_clf.fit(train.drop('G3', axis=1), train['G3'])

Fitting 5 folds for each of 72 candidates, totalling 360 fits


GridSearchCV(estimator=ExtraTreesClassifier(), n_jobs=-1,
             param_grid={'class_weight': ['balanced', 'balanced_subsample',
                                          None],
                         'max_depth': [None],
                         'max_features': ['auto', 'sqrt', 'log2'],
                         'n_estimators': [50, 100, 150, 200, 250, 300, 350,
                                          400],
                         'random_state': [42]},
             scoring='f1_weighted', verbose=4)

In [22]:
balanced_clf.fit(balanced_df.drop('G3', axis=1), balanced_df['G3'])

Fitting 5 folds for each of 72 candidates, totalling 360 fits


GridSearchCV(estimator=ExtraTreesClassifier(), n_jobs=-1,
             param_grid={'class_weight': ['balanced', 'balanced_subsample',
                                          None],
                         'max_depth': [None],
                         'max_features': ['auto', 'sqrt', 'log2'],
                         'n_estimators': [50, 100, 150, 200, 250, 300, 350,
                                          400],
                         'random_state': [42]},
             scoring='f1_weighted', verbose=4)

In [23]:
oversampled_clf.fit(over_sampled.drop('G3', axis=1), over_sampled['G3'])

Fitting 5 folds for each of 72 candidates, totalling 360 fits


GridSearchCV(estimator=ExtraTreesClassifier(), n_jobs=-1,
             param_grid={'class_weight': ['balanced', 'balanced_subsample',
                                          None],
                         'max_depth': [None],
                         'max_features': ['auto', 'sqrt', 'log2'],
                         'n_estimators': [50, 100, 150, 200, 250, 300, 350,
                                          400],
                         'random_state': [42]},
             scoring='f1_weighted', verbose=4)

In [24]:
synthetic_clf.fit(synthetic_only.drop('G3', axis=1), synthetic_only['G3'])

Fitting 5 folds for each of 72 candidates, totalling 360 fits


GridSearchCV(estimator=ExtraTreesClassifier(), n_jobs=-1,
             param_grid={'class_weight': ['balanced', 'balanced_subsample',
                                          None],
                         'max_depth': [None],
                         'max_features': ['auto', 'sqrt', 'log2'],
                         'n_estimators': [50, 100, 150, 200, 250, 300, 350,
                                          400],
                         'random_state': [42]},
             scoring='f1_weighted', verbose=4)

In [25]:
large_oversampled_clf.fit(large_oversampled.drop('G3', axis=1), large_oversampled['G3'])

Fitting 5 folds for each of 72 candidates, totalling 360 fits


GridSearchCV(estimator=ExtraTreesClassifier(), n_jobs=-1,
             param_grid={'class_weight': ['balanced', 'balanced_subsample',
                                          None],
                         'max_depth': [None],
                         'max_features': ['auto', 'sqrt', 'log2'],
                         'n_estimators': [50, 100, 150, 200, 250, 300, 350,
                                          400],
                         'random_state': [42]},
             scoring='f1_weighted', verbose=4)

In [26]:
large_synthetic_clf.fit(large_synthetic.drop('G3', axis=1), large_synthetic['G3'])

Fitting 5 folds for each of 72 candidates, totalling 360 fits


GridSearchCV(estimator=ExtraTreesClassifier(), n_jobs=-1,
             param_grid={'class_weight': ['balanced', 'balanced_subsample',
                                          None],
                         'max_depth': [None],
                         'max_features': ['auto', 'sqrt', 'log2'],
                         'n_estimators': [50, 100, 150, 200, 250, 300, 350,
                                          400],
                         'random_state': [42]},
             scoring='f1_weighted', verbose=4)

# Evaluation

original
0.1547459945488794
              precision    recall  f1-score   support

        fail       0.00      0.00      0.00         4
        high       0.67      0.93      0.78       105
     passing       0.29      0.10      0.14        21
       upper       0.50      0.15      0.23        33

    accuracy                           0.64       163
   macro avg       0.36      0.30      0.29       163
weighted avg       0.57      0.64      0.57       163

-----------------------------------

Combined Data Based on Full Data Set

-----------------------------------
original
0.5298076142632129
              precision    recall  f1-score   support

        fail       0.00      0.00      0.00         4
        high       0.76      0.98      0.85       105
     passing       0.64      0.33      0.44        21
       upper       1.00      0.48      0.65        33

    accuracy                           0.77       163
   macro avg       0.60      0.45      0.49       163
weighted avg       0.77      0.77      0.74       163

-----------------------------------
large
0.4555326806930739
              precision    recall  f1-score   support

        fail       0.00      0.00      0.00         4
        high       0.73      0.99      0.84       105
     passing       0.62      0.24      0.34        21
       upper       1.00      0.36      0.53        33

    accuracy                           0.74       163
   macro avg       0.59      0.40      0.43       163
weighted avg       0.75      0.74      0.69       163

-----------------------------------
very_large
0.34866945512163566
              precision    recall  f1-score   support

        fail       0.00      0.00      0.00         4
        high       0.70      0.98      0.81       105
     passing       0.67      0.19      0.30        21
       upper       0.89      0.24      0.38        33

    accuracy                           0.71       163
   macro avg       0.56      0.35      0.37       163
weighted avg       0.71      0.71      0.64       163

-----------------------------------
extra_large
0.3388472987253949
              precision    recall  f1-score   support

        fail       0.00      0.00      0.00         4
        high       0.69      1.00      0.81       105
     passing       0.75      0.14      0.24        21
       upper       1.00      0.18      0.31        33

    accuracy                           0.70       163
   macro avg       0.61      0.33      0.34       163
weighted avg       0.74      0.70      0.62       163

-----------------------------------

balanced with large resampling

-----------------------------------
0.4891806574022487
              precision    recall  f1-score   support

        fail       0.12      0.25      0.17         4
        high       0.79      0.84      0.81       105
     passing       0.47      0.43      0.45        21
       upper       0.88      0.67      0.76        33

    accuracy                           0.74       163
   macro avg       0.57      0.55      0.55       163
weighted avg       0.75      0.74      0.74       163

-----------------------------------

Class Specific Sampling

-----------------------------------

original
0.1547459945488794
              precision    recall  f1-score   support

        fail       0.00      0.00      0.00         4
        high       0.67      0.93      0.78       105
     passing       0.29      0.10      0.14        21
       upper       0.50      0.15      0.23        33

    accuracy                           0.64       163
   macro avg       0.36      0.30      0.29       163
weighted avg       0.57      0.64      0.57       163

-----------------------------------
balanced
0.2777622216346843
              precision    recall  f1-score   support

        fail       1.00      0.50      0.67         4
        high       0.71      0.71      0.71       105
     passing       0.44      0.52      0.48        21
       upper       0.42      0.39      0.41        33

    accuracy                           0.62       163
   macro avg       0.64      0.53      0.57       163
weighted avg       0.63      0.62      0.62       163

-----------------------------------
oversampled
0.29021646897293163
              precision    recall  f1-score   support

        fail       0.67      0.50      0.57         4
        high       0.73      0.67      0.70       105
     passing       0.52      0.62      0.57        21
       upper       0.36      0.42      0.39        33

    accuracy                           0.61       163
   macro avg       0.57      0.55      0.56       163
weighted avg       0.63      0.61      0.61       163

-----------------------------------
large_oversampled
0.31025806701850306
              precision    recall  f1-score   support

        fail       0.67      0.50      0.57         4
        high       0.76      0.59      0.66       105
     passing       0.33      0.62      0.43        21
       upper       0.46      0.55      0.50        33

    accuracy                           0.58       163
   macro avg       0.55      0.56      0.54       163
weighted avg       0.64      0.58      0.60       163

-----------------------------------
synthetic
0.23533468901341467
              precision    recall  f1-score   support

        fail       0.75      0.75      0.75         4
        high       0.75      0.47      0.58       105
     passing       0.25      0.62      0.36        21
       upper       0.36      0.45      0.40        33

    accuracy                           0.49       163
   macro avg       0.53      0.57      0.52       163
weighted avg       0.61      0.49      0.52       163

-----------------------------------
large_synthetic
0.20655751873176692
              precision    recall  f1-score   support

        fail       0.67      0.50      0.57         4
        high       0.73      0.43      0.54       105
     passing       0.25      0.62      0.35        21
       upper       0.36      0.48      0.41        33

    accuracy                           0.47       163
   macro avg       0.50      0.51      0.47       163
weighted avg       0.59      0.47      0.49       163

-----------------------------------

In [27]:
for _clf_name, _clf in zip(['original','balanced','oversampled','large_oversampled','synthetic','large_synthetic'],
                           [original_clf, balanced_clf, oversampled_clf, large_oversampled_clf, synthetic_clf, large_synthetic_clf]):
    print(_clf_name)
    _preds = _clf.predict(x_test)
    print(matthews_corrcoef(y_test, _preds))
    print(classification_report(y_test, _preds))
    print('-----------------------------------')

original
0.1547459945488794
              precision    recall  f1-score   support

        fail       0.00      0.00      0.00         4
        high       0.67      0.93      0.78       105
     passing       0.29      0.10      0.14        21
       upper       0.50      0.15      0.23        33

    accuracy                           0.64       163
   macro avg       0.36      0.30      0.29       163
weighted avg       0.57      0.64      0.57       163

-----------------------------------
balanced
0.2777622216346843
              precision    recall  f1-score   support

        fail       1.00      0.50      0.67         4
        high       0.71      0.71      0.71       105
     passing       0.44      0.52      0.48        21
       upper       0.42      0.39      0.41        33

    accuracy                           0.62       163
   macro avg       0.64      0.53      0.57       163
weighted avg       0.63      0.62      0.62       163

-----------------------------------
ov

In [28]:
train['G3'].value_counts()

high       313
upper       98
passing     63
fail        12
Name: G3, dtype: int64