In [None]:
from torch.utils.data import Dataset
import os
import pandas as pd
import torch
import numpy as np

class WineDataset(Dataset):
    """Map-style PyTorch dataset over one wine-quality CSV file.

    The file ``<root>/winequality-<wine_type>.csv`` is read with ``;`` as separator.
    Every column except the last becomes a float feature; the last column (the
    quality score) is bucketed into ``n_classes`` integer labels by splitting the
    range [0, 10] into equal-width, closed buckets (a score on a shared edge goes
    to the lower bucket).

    Args:
        root: Directory containing the CSV file; ``~`` is expanded.
        wine_type: File-name suffix, e.g. ``"red"`` or ``"white"``.
        n_classes: Number of equal-width buckets used for the labels.
    """

    def __init__(self, root, wine_type, n_classes=10):
        self.root = os.path.expanduser(root)
        self.wine_type = wine_type
        self.n_classes = n_classes

        frame = pd.read_csv(os.path.join(self.root, "winequality-{}.csv".format(self.wine_type)),
                            sep=";",
                            # Column names are on the first line; header=1 silently dropped the first sample.
                            header=0)
        self.data = torch.Tensor(frame.values)
        self.X = self.data[:, :-1]
        # self.y is a view into self.data, so the relabelling below also updates self.data's last column.
        self.y = self.data[:, -1]

        # Build every bucket mask from a snapshot of the raw scores: relabelling in place let a
        # freshly assigned label fall into a later bucket and be relabelled again (n_classes > 10).
        raw_scores = self.y.clone()
        assigned = torch.zeros_like(raw_scores, dtype=torch.bool)
        for i in range(n_classes):
            in_bucket = (raw_scores >= i * 10 / n_classes) & (raw_scores <= (i + 1) * 10 / n_classes)
            # Buckets are closed on both ends; the first (lowest) matching bucket wins.
            in_bucket &= ~assigned
            self.y[in_bucket] = i
            assigned |= in_bucket

    def __getitem__(self, index):
        """Return ``(features, label)`` for one row; the label is an int64 tensor."""
        Xi = self.X[index]
        yi = self.y[index].type(torch.LongTensor)

        return Xi, yi

    def __len__(self):
        """Number of rows read from the CSV file."""
        return self.data.shape[0]

In [None]:
import torch.nn as nn

# Neural Network Model (1 hidden layer)
class MLP(nn.Module):
    """Perceptron with a single hidden layer: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits (no softmax is applied).
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super(MLP, self).__init__()
        # The attribute names determine the state_dict keys, so they must not change.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw class logits for the batch ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
import numpy as np
import torch
from torch import nn
import pandas as pd
from mlp import MLP
from sklearn.model_selection import train_test_split
from skorch import NeuralNetClassifier
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from scipy.stats import uniform
import json
import os
import copy
from sklearn.utils import resample

root = "./../data"

# Both wine-quality files are ';'-separated with the column names on their first line.
white_wine_data, red_wine_data = (
    pd.read_csv("{}/winequality/winequality-{}.csv".format(root, colour), sep=";", header=0)
    for colour in ("white", "red")
)

# "both_wine" stacks the white rows first, then the red rows (original row indices are kept).
both_wine_data = pd.concat((white_wine_data, red_wine_data))

datasets = {
    "red_wine": red_wine_data,
    "white_wine": white_wine_data,
    "both_wine": both_wine_data,
}

# Feature count handed to the MLP; presumably the 11 non-"quality" columns — TODO confirm.
input_size = 11

import datetime

os.makedirs("./results", exist_ok=True)
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d %H-%M-%S")
# One log file per run; kept open for the whole sweep and written to by train_and_test.
results = open("./results/wine_mlp_results_{}".format(run_stamp), "w+")


def convert_ylabel_into_k_categories_equilibre(wine_data, equilibre, k_categ=3):
    """Bucket the ``quality`` column into coarse classes and optionally balance them.

    Labelling (every mask is computed from the original scores):
      * ``k_categ == 2``: quality < 6 -> 0, quality >= 6 -> 1
      * otherwise:        quality < 6 -> 0, quality == 6 -> 1, quality > 6 -> 2

    When ``equilibre`` is true, every class other than the smallest one is
    downsampled without replacement (``random_state=123``) to the size of the
    smallest class; the smallest class is kept whole and placed last.

    Args:
        wine_data: DataFrame with a ``quality`` column. It is not modified.
        equilibre: Balance the classes by downsampling when true.
        k_categ: 2 for a binary split; any other value gives the 3-class split.

    Returns:
        A new DataFrame with the relabelled ``quality`` column, balanced if requested.
    """
    quality = wine_data["quality"].to_numpy()

    # Masks come from the untouched scores, so one relabelling step cannot re-match another's output.
    labels = quality.copy()
    labels[quality < 6] = 0
    if k_categ == 2:
        labels[quality >= 6] = 1
    else:
        labels[quality == 6] = 1
        labels[quality > 6] = 2

    # Work on a copy: writing through `.values` mutated the caller's frame and raises under
    # pandas copy-on-write, where `.values` is read-only.
    relabeled = wine_data.copy()
    relabeled["quality"] = labels

    # `equilibre` is honoured for both splits (the binary split used to downsample unconditionally).
    if not equilibre or len(relabeled) == 0:
        return relabeled

    unique_label, counts_elements = np.unique(labels, return_counts=True)
    my_dict = dict(zip(unique_label, counts_elements))
    print(my_dict)
    label_of_minor_class = min(my_dict, key=my_dict.get)
    size_of_minor_class = my_dict[label_of_minor_class]

    print("label_of_minor_class :", label_of_minor_class, "size_of_minor_class", size_of_minor_class)

    df_minority = relabeled[labels == label_of_minor_class]

    # Downsample every non-minority class (whichever label the minority has) to the minority size.
    downsampled = [
        resample(relabeled[labels == label],
                 replace=False,  # sample without replacement
                 n_samples=size_of_minor_class,  # to match minority class
                 random_state=123)  # reproducible results
        for label in unique_label
        if label != label_of_minor_class
    ]

    # Downsampled classes in ascending label order, minority class last.
    return pd.concat(downsampled + [df_minority])


def train_and_test(params):
    """Run a randomized hyper-parameter search for one experiment and log its scores.

    Args:
        params: ``[num_classes, equilibre, dataset_name]``. ``dataset_name`` is a key of
            the module-level ``datasets`` dict. When ``num_classes < 10`` the ``quality``
            column is relabelled by ``convert_ylabel_into_k_categories_equilibre``;
            otherwise the raw quality scores are used as class labels.

    Side effects:
        Writes the best CV score/params and the held-out test score to the module-level
        ``results`` file and dumps the best params to ``./results/<name>/hyperparams.json``.
    """
    print(params)
    num_classes = params[0]
    equilibre = params[1]
    dataset_name = params[2]

    dataset = datasets[dataset_name].copy()
    if num_classes < 10:
        dataset = convert_ylabel_into_k_categories_equilibre(dataset, equilibre, num_classes)
    values = dataset.values
    X = values[:, :-1].astype(np.float32)
    y = values[:, -1].astype(np.int64)

    train_x, test_x, train_y, test_y = train_test_split(X, y, test_size=0.33, random_state=2)

    # Fall back to CPU instead of crashing on machines without a CUDA device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = NeuralNetClassifier(
        MLP,
        criterion=nn.CrossEntropyLoss,
        max_epochs=30,
        # Derived from the data so it always matches the feature matrix.
        module__input_size=X.shape[1],
        module__num_classes=num_classes,
        device=device,
    )
    # Named `search_space` so it no longer shadows the `params` argument.
    search_space = {
        'net__lr': uniform(loc=0, scale=0.2),
        'net__module__hidden_size': randint(100, 1000),
        'net__optimizer__weight_decay': uniform(loc=0, scale=0.2),
        'net__batch_size': randint(10, 200)
    }

    # The scaler lives inside the pipeline so each CV fold is standardised with its own statistics.
    model = Pipeline(steps=[("scaler", StandardScaler()), ("net", net)])
    rs = RandomizedSearchCV(model, search_space, refit=True, cv=3, scoring='accuracy', n_iter=100, n_jobs=1)
    rs.fit(train_x, train_y)

    name = "{}_{}_{}".format(num_classes, "equilibre" if equilibre else "desequilibre", dataset_name)
    os.makedirs("./results/{}".format(name), exist_ok=True)
    results.write("{}\n".format(name))
    results.write("train: {}\n{}\n".format(rs.best_score_, rs.best_params_))
    print(rs.best_score_, rs.best_params_)

    # refit=True retrained the best configuration on the whole training split; score it on held-out data.
    test_score = rs.score(test_x, test_y)
    results.write("test: {}\n\n".format(test_score))
    # Flush per experiment so finished results survive a later crash in the sweep.
    results.flush()
    print(test_score)

    with open("./results/{}/hyperparams.json".format(name), "w+") as f:
        json.dump(rs.best_params_, f)

    # Drop references to the fitted models before releasing cached GPU memory.
    del rs, model, net
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Experiment grid; each entry is [num_classes, equilibre, dataset_name].
params = [[2,True,"red_wine"], [2,True,"white_wine"], [2,True,"both_wine"],
 [3,True,"red_wine"], [3,True,"white_wine"], [3,True,"both_wine"],
 [3,False,"red_wine"], [3,False,"white_wine"], [3,False,"both_wine"],
 [10, False, "red_wine"], [10, False, "white_wine"], [10, False, "both_wine"]]

try:
    for param in params:
        train_and_test(param)
finally:
    # Close (and flush) the results log even if an experiment raises, so completed results are kept.
    results.close()

     27        0.6095       0.7820        0.5832  0.0151
     28        0.6095       0.7820        0.5832  0.0139
     29        0.6095       0.7820        0.5832  0.0138
     30        0.6096       0.7820        0.5832  0.0134
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6374[0m       [32m0.7836[0m        [35m0.4973[0m  0.0140
      2        [36m0.5382[0m       [32m0.7910[0m        [35m0.4853[0m  0.0165
      3        0.5404       0.7910        0.5026  0.0165
      4        0.5511       0.7910        0.5218  0.0161
      5        0.5628       [32m0.8060[0m        0.5377  0.0161
      6        0.5729       [32m0.8134[0m        0.5497  0.0163
      7        0.5808       [32m0.8209[0m        0.5581  0.0157
      8        0.5868       0.8209        0.5641  0.0157
      9        0.5910       0.8134        0.5683  0.0184
     10        0.5941       0.8060        0.5712  0.0181
     11

     23        [36m0.4847[0m       0.7910        [35m0.4264[0m  0.0109
     24        [36m0.4836[0m       0.7985        [35m0.4258[0m  0.0110
     25        [36m0.4826[0m       0.7985        [35m0.4252[0m  0.0148
     26        [36m0.4816[0m       0.7985        [35m0.4247[0m  0.0112
     27        [36m0.4807[0m       0.7985        [35m0.4242[0m  0.0114
     28        [36m0.4798[0m       0.7985        [35m0.4238[0m  0.0113
     29        [36m0.4789[0m       0.7985        [35m0.4235[0m  0.0102
     30        [36m0.4781[0m       0.7985        [35m0.4231[0m  0.0098
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6451[0m       [32m0.7537[0m        [35m0.5859[0m  0.0093
      2        [36m0.5865[0m       [32m0.7836[0m        [35m0.5422[0m  0.0104
      3        [36m0.5600[0m       0.7761        [35m0.5183[0m  0.0098
      4        [36m0.5451[0m       0.7836

      9        [36m0.4781[0m       0.7687        [35m0.4657[0m  0.0086
     10        [36m0.4757[0m       0.7687        [35m0.4656[0m  0.0108
     11        [36m0.4735[0m       0.7687        [35m0.4655[0m  0.0111
     12        [36m0.4714[0m       0.7687        [35m0.4654[0m  0.0104
     13        [36m0.4696[0m       0.7612        [35m0.4654[0m  0.0092
     14        [36m0.4679[0m       0.7612        [35m0.4653[0m  0.0090
     15        [36m0.4663[0m       0.7612        [35m0.4653[0m  0.0095
     16        [36m0.4648[0m       0.7687        [35m0.4653[0m  0.0099
     17        [36m0.4634[0m       0.7687        [35m0.4652[0m  0.0084
     18        [36m0.4620[0m       0.7687        [35m0.4652[0m  0.0081
     19        [36m0.4608[0m       0.7687        [35m0.4651[0m  0.0084
     20        [36m0.4596[0m       0.7761        [35m0.4651[0m  0.0089
     21        [36m0.4586[0m       0.7761        [35m0.4651[0m  0.0088
     22        [36m0.457

     10        [36m0.5126[0m       0.7970        0.4834  0.0082
     11        [36m0.5125[0m       0.7970        0.4843  0.0088
     12        0.5126       0.7970        0.4853  0.0084
     13        0.5130       0.7970        0.4864  0.0090
     14        0.5135       0.7970        0.4875  0.0092
     15        0.5142       0.7970        0.4888  0.0089
     16        0.5150       0.7970        0.4901  0.0083
     17        0.5159       0.7895        0.4914  0.0088
     18        0.5169       0.7895        0.4928  0.0079
     19        0.5180       0.7895        0.4942  0.0090
     20        0.5192       0.7895        0.4956  0.0083
     21        0.5204       0.7895        0.4970  0.0084
     22        0.5216       0.7895        0.4984  0.0081
     23        0.5229       0.7895        0.4998  0.0093
     24        0.5242       0.7895        0.5012  0.0084
     25        0.5255       0.7895        0.5025  0.0083
     26        0.5268       0.7895        0.5039  0.0087
     27      

     15        0.5478       0.8060        0.5106  0.0213
     16        0.5491       0.8060        0.5117  0.0123
     17        0.5501       0.8060        0.5124  0.0151
     18        0.5510       0.8060        0.5130  0.0125
     19        0.5517       0.8060        0.5134  0.0121
     20        0.5523       0.8060        0.5137  0.0125
     21        0.5527       0.8060        0.5139  0.0134
     22        0.5531       0.8060        0.5141  0.0126
     23        0.5533       0.8060        0.5142  0.0128
     24        0.5536       0.8060        0.5143  0.0122
     25        0.5537       0.8060        0.5143  0.0123
     26        0.5539       0.8060        0.5144  0.0126
     27        0.5540       0.8060        0.5145  0.0142
     28        0.5541       0.8060        0.5145  0.0125
     29        0.5542       0.8060        0.5146  0.0119
     30        0.5543       0.8060        0.5146  0.0123
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  --------

     16        [36m0.4610[0m       0.7687        0.4671  0.0074
     17        [36m0.4603[0m       0.7761        0.4672  0.0088
     18        [36m0.4597[0m       0.7761        0.4671  0.0079
     19        [36m0.4592[0m       0.7761        0.4670  0.0092
     20        [36m0.4587[0m       0.7761        0.4667  0.0084
     21        [36m0.4583[0m       0.7761        0.4664  0.0079
     22        [36m0.4580[0m       0.7761        0.4661  0.0095
     23        [36m0.4577[0m       0.7761        0.4658  0.0073
     24        [36m0.4574[0m       0.7761        0.4655  0.0070
     25        [36m0.4572[0m       0.7761        0.4652  0.0071
     26        [36m0.4570[0m       0.7761        0.4649  0.0074
     27        [36m0.4569[0m       0.7761        0.4647  0.0075
     28        [36m0.4568[0m       0.7761        0.4645  0.0079
     29        [36m0.4568[0m       0.7761        0.4643  0.0079
     30        [36m0.4568[0m       0.7761        0.4641  0.0097
  epoch   

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.8042[0m       [32m0.6917[0m        [35m0.5357[0m  0.0092
      2        [36m0.5514[0m       [32m0.7820[0m        [35m0.4739[0m  0.0092
      3        0.5525       0.7669        [35m0.4724[0m  0.0106
      4        [36m0.5373[0m       0.7594        [35m0.4660[0m  0.0103
      5        [36m0.5293[0m       [32m0.7895[0m        [35m0.4632[0m  0.0106
      6        [36m0.5218[0m       [32m0.7970[0m        [35m0.4617[0m  0.0093
      7        [36m0.5156[0m       0.7970        [35m0.4609[0m  0.0101
      8        [36m0.5105[0m       0.7970        [35m0.4603[0m  0.0113
      9        [36m0.5065[0m       0.7970        [35m0.4600[0m  0.0103
     10        [36m0.5032[0m       0.7895        [35m0.4597[0m  0.0097
     11        [36m0.5006[0m       0.7895        [35m0.4596[0m  0.0093
     12        [36m0.4984[0m    

     22        [36m0.5107[0m       0.7970        [35m0.4872[0m  0.0398
     23        [36m0.5102[0m       0.7970        [35m0.4870[0m  0.0461
     24        [36m0.5098[0m       0.7970        [35m0.4868[0m  0.0404
     25        [36m0.5095[0m       0.7970        [35m0.4866[0m  0.0438
     26        [36m0.5091[0m       0.7970        [35m0.4865[0m  0.0439
     27        [36m0.5088[0m       0.7970        [35m0.4864[0m  0.0433
     28        [36m0.5086[0m       0.7970        [35m0.4863[0m  0.0429
     29        [36m0.5084[0m       0.7970        [35m0.4862[0m  0.0406
     30        [36m0.5082[0m       0.7970        [35m0.4862[0m  0.0403
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6676[0m       [32m0.7687[0m        [35m0.6068[0m  0.0410
      2        [36m0.6049[0m       [32m0.7910[0m        [35m0.5591[0m  0.0403
      3        [36m0.5741[0m       0.7910

     14        0.5690       0.7836        0.5332  0.0188
     15        0.5695       0.7836        0.5336  0.0185
     16        0.5698       0.7836        0.5338  0.0186
     17        0.5700       0.7836        0.5340  0.0206
     18        0.5701       0.7836        0.5341  0.0201
     19        0.5702       0.7761        0.5342  0.0213
     20        0.5703       0.7761        0.5344  0.0186
     21        0.5704       0.7761        0.5344  0.0195
     22        0.5704       0.7761        0.5345  0.0185
     23        0.5704       0.7687        0.5345  0.0225
     24        0.5704       0.7687        0.5346  0.0182
     25        0.5704       0.7687        0.5346  0.0204
     26        0.5704       0.7687        0.5346  0.0192
     27        0.5705       0.7687        0.5345  0.0198
     28        0.5705       0.7687        0.5346  0.0183
     29        0.5705       0.7687        0.5346  0.0209
     30        0.5705       0.7687        0.5346  0.0177
  epoch    train_loss    valid_

     16        0.5378       [32m0.7239[0m        [35m0.5543[0m  0.0105
     17        0.5386       [32m0.7388[0m        [35m0.5520[0m  0.0116
     18        0.5397       0.7313        [35m0.5504[0m  0.0126
     19        0.5409       0.7388        [35m0.5496[0m  0.0117
     20        0.5422       [32m0.7463[0m        [35m0.5495[0m  0.0126
     21        0.5437       [32m0.7612[0m        [35m0.5495[0m  0.0114
     22        0.5451       0.7612        0.5499  0.0131
     23        0.5466       0.7612        0.5505  0.0115
     24        0.5480       0.7612        0.5511  0.0114
     25        0.5494       0.7612        0.5518  0.0116
     26        0.5507       0.7612        0.5527  0.0133
     27        0.5520       0.7537        0.5535  0.0120
     28        0.5533       0.7537        0.5543  0.0138
     29        0.5544       0.7537        0.5550  0.0115
     30        0.5556       0.7463        0.5558  0.0127
  epoch    train_loss    valid_acc    valid_loss     du

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.5988[0m       [32m0.7970[0m        [35m0.5009[0m  0.0096
      2        [36m0.5403[0m       0.7820        [35m0.4926[0m  0.0107
      3        [36m0.5294[0m       0.7820        [35m0.4910[0m  0.0103
      4        [36m0.5255[0m       0.7744        0.4916  0.0097
      5        [36m0.5244[0m       0.7744        0.4935  0.0102
      6        0.5249       0.7820        0.4960  0.0109
      7        0.5264       0.7820        0.4991  0.0114
      8        0.5286       0.7820        0.5026  0.0113
      9        0.5312       0.7820        0.5062  0.0110
     10        0.5342       0.7895        0.5100  0.0104
     11        0.5373       0.7895        0.5137  0.0108
     12        0.5404       0.7895        0.5172  0.0118
     13        0.5435       0.7895        0.5206  0.0102
     14        0.5465       0.7895        0.5238  0.0101
     15

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6667[0m       [32m0.6667[0m        [35m0.6213[0m  0.0176
      2        [36m0.6062[0m       [32m0.7041[0m        [35m0.5925[0m  0.0168
      3        [36m0.5818[0m       [32m0.7109[0m        [35m0.5783[0m  0.0170
      4        [36m0.5682[0m       0.7075        [35m0.5699[0m  0.0172
      5        [36m0.5595[0m       [32m0.7143[0m        [35m0.5644[0m  0.0172
      6        [36m0.5535[0m       0.7143        [35m0.5605[0m  0.0171
      7        [36m0.5490[0m       0.7143        [35m0.5575[0m  0.0172
      8        [36m0.5456[0m       0.7143        [35m0.5552[0m  0.0174
      9        [36m0.5428[0m       0.7109        [35m0.5534[0m  0.0165
     10        [36m0.5406[0m       0.7075        [35m0.5518[0m  0.0191
     11        [36m0.5388[0m       0.7075        [35m0.5506[0m  0.0183
     12        [36m0.537

     13        [36m0.5369[0m       0.7041        [35m0.5548[0m  0.0163
     14        [36m0.5354[0m       0.7041        [35m0.5536[0m  0.0180
     15        [36m0.5341[0m       [32m0.7075[0m        [35m0.5525[0m  0.0179
     16        [36m0.5329[0m       [32m0.7109[0m        [35m0.5516[0m  0.0172
     17        [36m0.5319[0m       [32m0.7143[0m        [35m0.5508[0m  0.0164
     18        [36m0.5311[0m       0.7143        [35m0.5501[0m  0.0172
     19        [36m0.5303[0m       0.7143        [35m0.5494[0m  0.0180
     20        [36m0.5296[0m       0.7143        [35m0.5489[0m  0.0162
     21        [36m0.5290[0m       [32m0.7177[0m        [35m0.5483[0m  0.0185
     22        [36m0.5285[0m       0.7177        [35m0.5479[0m  0.0177
     23        [36m0.5280[0m       0.7143        [35m0.5475[0m  0.0167
     24        [36m0.5276[0m       0.7143        [35m0.5471[0m  0.0177
     25        [36m0.5272[0m       0.7109        [35m0.5468[

     30        0.5857       0.7177        0.5916  0.0172
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6301[0m       [32m0.7474[0m        [35m0.5473[0m  0.0164
      2        [36m0.5629[0m       0.7474        [35m0.5245[0m  0.0170
      3        [36m0.5519[0m       [32m0.7645[0m        [35m0.5156[0m  0.0183
      4        [36m0.5465[0m       [32m0.7816[0m        [35m0.5117[0m  0.0157
      5        [36m0.5438[0m       [32m0.7918[0m        [35m0.5102[0m  0.0170
      6        [36m0.5425[0m       0.7918        [35m0.5102[0m  0.0166
      7        [36m0.5424[0m       0.7918        0.5111  0.0162
      8        0.5429       0.7850        0.5127  0.0182
      9        0.5440       0.7884        0.5148  0.0180
     10        0.5456       0.7850        0.5174  0.0186
     11        0.5475       0.7816        0.5202  0.0161
     12        0.5496       0.7850        0.5234 