In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from utils import *
import statistics as sta

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [26]:
class MLP(nn.Module):
    """MLP mapping summed SGA embeddings (+ cancer-type embedding) to 2207 sigmoid outputs.

    A sample is encoded by summing the 512-d embeddings of its SGA indices
    (index 0 is the embedding's padding index, so it receives no gradient).
    If ``cancer_type`` is set, a 512-d cancer-type embedding is added to that sum.
    The resulting vector goes through ``hidden_layer`` hidden layers of width
    1024, with a ReLU applied before every linear layer. A final linear layer
    and a sigmoid produce the 2207 outputs.

    Args:
        hidden_layer: number of hidden layers. Values <= 0 use a single linear
            map 512 -> 2207; values above 3 behave like 3.
        initializtion: if True, copy pretrained SGA embeddings from
            ``data/gene_emb_pretrain.npy`` into the embedding table
            (misspelled name kept for compatibility with existing callers).
        cancer_type: if True, add the cancer-type embedding to the SGA sum.
    """

    def __init__(self, hidden_layer = 1, initializtion = True, cancer_type = True):
        super().__init__()

        self.hidden_layer = hidden_layer
        self.initializtion = initializtion
        self.cancer_type = cancer_type

        # NOTE: the layers below are created in a fixed order. Reordering them
        # changes both the random init drawn for a given seed and nothing else,
        # but it would break bit-for-bit reproducibility of earlier runs.

        # Presumably 19781 SGA ids, plus row 0 reserved for padding.
        self.layer_sga_emb = nn.Embedding(
            num_embeddings=19781+1,
            embedding_dim=512,
            padding_idx=0)

        # Presumably 16 cancer types, plus row 0 reserved for padding.
        self.layer_can_emb = nn.Embedding(
            num_embeddings=16+1,
            embedding_dim=512,
            padding_idx=0)

        # First hidden layer (used when hidden_layer >= 1).
        self.layer_w_1 = nn.Linear(
            in_features=512,
            out_features=1024,
            bias=True)

        # Direct embedding -> output map (used when hidden_layer <= 0).
        self.layer_w_0 = nn.Linear(
            in_features=512,
            out_features=2207,
            bias=True)

        # Second hidden layer (used when hidden_layer >= 2).
        self.layer_w_h_2 = nn.Linear(
            in_features=1024,
            out_features=1024,
            bias=True)

        # Third hidden layer (used when hidden_layer >= 3).
        self.layer_w_h_3 = nn.Linear(
            in_features=1024,
            out_features=1024,
            bias=True)

        # Output layer after the last hidden layer.
        self.layer_w_2 = nn.Linear(
            in_features=1024,
            out_features=2207,
            bias=True)

        # The optimizer is built before the caller's .to(device). This relies on
        # Module.to() updating parameters in place (PyTorch's default), so the
        # optimizer keeps pointing at the live parameters.
        self.optimizer = optim.Adam(
            self.parameters(),
            lr=1e-4,
            weight_decay=1e-5)

        if self.initializtion:
            # Expected to be shaped like the embedding table, (19781+1, 512).
            gene_emb_pretrain = np.load(os.path.join("data", "gene_emb_pretrain.npy"))
            self.layer_sga_emb.weight.data.copy_(torch.from_numpy(gene_emb_pretrain))

    def load_model(self, path="data/trained_model.pth"):
        """Load a state dict written by ``save_model`` into this model.

        Tensors are mapped onto the device the model's parameters currently
        live on. This lets a checkpoint saved on a GPU be restored on a
        CPU-only machine, and vice versa.
        """
        print("Loading model from "+path)
        target_device = next(self.parameters()).device
        self.load_state_dict(torch.load(path, map_location=target_device))

    def save_model(self, path="data/trained_model.pth"):
        """Save this model's state dict (parameters only, not the optimizer) to ``path``."""
        print("Saving model to "+path)
        torch.save(self.state_dict(), path)

    def forward(self, sga_index, can_index):
        """Return sigmoid outputs of shape (batch, 2207).

        Args:
            sga_index: integer tensor of SGA indices, summed over dim 1
                (assumed (batch, n_sga) with 0 as padding — TODO confirm).
            can_index: integer cancer-type index per sample; only used when
                ``self.cancer_type`` is True.
        """
        # Sum the gene embeddings of each sample into one 512-d vector.
        emb_sga = torch.sum(self.layer_sga_emb(sga_index), dim=1).view(-1, 512)

        if self.cancer_type:
            emb_can = self.layer_can_emb(can_index).view(-1, 512)
            emb_tmr = emb_can + emb_sga
        else:
            emb_tmr = emb_sga

        if self.hidden_layer > 0:
            hid_tmr = self.layer_w_1(F.relu(emb_tmr))
            if self.hidden_layer > 1:
                hid_tmr = self.layer_w_h_2(F.relu(hid_tmr))
            if self.hidden_layer > 2:
                hid_tmr = self.layer_w_h_3(F.relu(hid_tmr))
            logits = self.layer_w_2(F.relu(hid_tmr))
        else:
            logits = self.layer_w_0(F.relu(emb_tmr))

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return torch.sigmoid(logits)

    def train(self, train_set=True, test_set=None,
        batch_size=None, test_batch_size=None,
        max_iter=None, max_fscore=None,
        test_inc_size=None):
        """Train with Adam on mini-batches and evaluate periodically on ``test_set``.

        This method shadows ``nn.Module.train(mode)``. When called with a bool or
        with no arguments (as ``model.eval()`` / ``model.train()`` do), it only
        sets the training mode and returns ``self``, as PyTorch expects.

        Args:
            train_set: dataset with at least a "can" entry (used for its length);
                mini-batches are drawn from it via ``get_minibatch``.
            test_set: dataset passed to ``self.test`` for periodic evaluation.
            batch_size: samples per optimisation step; the sample counter
                advances by this amount each step. Required.
            test_batch_size: batch size used for periodic evaluation.
            max_iter: inclusive upper bound of the sample counter. Required.
            max_fscore: stop once a periodic test F1 reaches this value;
                None disables early stopping.
            test_inc_size: evaluate whenever the sample counter is a multiple
                of this; None or 0 disables periodic evaluation.

        Raises:
            TypeError: if ``batch_size`` or ``max_iter`` is missing when
                training on a dataset.
        """
        if isinstance(train_set, bool):
            # Invoked as nn.Module.train(mode), e.g. via model.eval().
            return super().train(train_set)
        if batch_size is None or max_iter is None:
            raise TypeError("train() requires batch_size and max_iter when training on a dataset")

        super().train(True)
        n_train = len(train_set["can"])

        for iter_train in range(0, max_iter+1, batch_size):

            batch_set = get_minibatch(train_set, iter_train, batch_size, batch_type="train")
            preds = self.forward(batch_set["sga"].to(device), batch_set["can"].to(device))
            labels = batch_set["deg"].to(device)

            self.optimizer.zero_grad()
            # 1 - |preds - labels| is the likelihood of the label; for 0/1 labels
            # this is binary cross-entropy. The 1e-4 keeps the log finite.
            loss = -torch.log( 1e-4 +1-torch.abs(preds-labels) ).mean()
            loss.backward()
            self.optimizer.step()

            if test_inc_size and (iter_train % test_inc_size == 0):
                labels, preds = self.test(test_set, test_batch_size)
                precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
                # Printed as [completed passes over train_set, offset in current pass].
                print("[%d,%d], f1_score: %.3f, acc: %.3f"% (iter_train//n_train,
                                                             iter_train%n_train, f1score, accuracy))

                if max_fscore is not None and f1score >= max_fscore:
                    break

    def test(self, test_set, test_batch_size):
        """Predict on ``test_set`` batch by batch and return ``(labels, preds)`` as numpy arrays.

        Runs under ``torch.no_grad()`` with the model in eval mode. The previous
        training mode is restored afterwards, even if an exception is raised.

        Args:
            test_set: dataset with at least a "can" entry (used for its length);
                batches are drawn via ``get_minibatch``.
            test_batch_size: number of samples per evaluation batch.

        Returns:
            Tuple of numpy arrays ``(labels, preds)``, concatenated along axis 0.
        """
        labels, preds = [], []
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for iter_test in range(0, len(test_set["can"]), test_batch_size):
                    batch_set = get_minibatch(test_set, iter_test, test_batch_size, batch_type="test")
                    batch_preds = self.forward(batch_set["sga"].to(device), batch_set["can"].to(device))
                    labels.append(batch_set["deg"].cpu().numpy())
                    preds.append(batch_preds.cpu().numpy())
        finally:
            super().train(was_training)

        return np.concatenate(labels, axis=0), np.concatenate(preds, axis=0)

In [17]:
# Load the dataset from data/ (deg_shuffle=False — presumably DEG order is kept
# as-is; see utils.load_dataset), then split it with ratio=0.66 (presumably the
# training fraction — confirm in utils.split_dataset).
dataset = load_dataset(deg_shuffle=False, input_dir="data")
train_set, test_set = split_dataset(dataset, ratio=0.66)

### MLP with 1 hidden layer

In [20]:
# Train the 1-hidden-layer MLP five times and keep each run's final test metrics.
precision_h1, recall_h1, f1score_h1, accuracy_h1 = [], [], [], []

for i in range(5):
    # Seed every repetition with its run index. Runs still differ from each
    # other, but each one can be reproduced after a kernel restart.
    # `np` comes from `from utils import *`, which the MLP class also relies on.
    # NOTE(review): if get_minibatch shuffles with Python's `random`, seed it too.
    torch.manual_seed(i)
    np.random.seed(i)

    # Init model with a single hidden layer
    model = MLP(hidden_layer=1).to(device)

    # Train MLP model. It evaluates on test_set every 256 samples and stops early
    # once the test F1 reaches 0.7.
    # NOTE(review): test_set drives early stopping AND the final report below.
    # If early stopping triggers, the reported metrics are optimistic; a separate
    # validation split would avoid that.
    model.train(train_set, test_set,
          batch_size=16,
          test_batch_size=512,
          max_iter=3072*40,
          max_fscore=0.7,
          test_inc_size=256)

    print("Evaluating...")
    labels, preds = model.test(test_set, test_batch_size=512)
    precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
    print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%(precision, recall, f1score, accuracy))

    # Release memory before the next repetition
    del model
    torch.cuda.empty_cache()

    precision_h1.append(precision)
    recall_h1.append(recall)
    f1score_h1.append(f1score)
    accuracy_h1.append(accuracy)

[0,0], f1_score: 0.386, acc: 0.520
[0,256], f1_score: 0.383, acc: 0.671
[0,512], f1_score: 0.396, acc: 0.661
[0,768], f1_score: 0.409, acc: 0.675
[0,1024], f1_score: 0.425, acc: 0.682
[0,1280], f1_score: 0.401, acc: 0.694
[0,1536], f1_score: 0.422, acc: 0.700
[0,1792], f1_score: 0.393, acc: 0.710
[0,2048], f1_score: 0.447, acc: 0.707
[0,2304], f1_score: 0.400, acc: 0.718
[0,2560], f1_score: 0.456, acc: 0.715
[0,2816], f1_score: 0.420, acc: 0.723
[1,124], f1_score: 0.450, acc: 0.724
[1,380], f1_score: 0.452, acc: 0.725
[1,636], f1_score: 0.436, acc: 0.730
[1,892], f1_score: 0.470, acc: 0.730
[1,1148], f1_score: 0.455, acc: 0.733
[1,1404], f1_score: 0.450, acc: 0.736
[1,1660], f1_score: 0.457, acc: 0.737
[1,1916], f1_score: 0.489, acc: 0.736
[1,2172], f1_score: 0.427, acc: 0.739
[1,2428], f1_score: 0.485, acc: 0.741
[1,2684], f1_score: 0.484, acc: 0.743
[1,2940], f1_score: 0.476, acc: 0.746
[2,248], f1_score: 0.473, acc: 0.745
[2,504], f1_score: 0.493, acc: 0.746
[2,760], f1_score: 0.500

[18,2232], f1_score: 0.620, acc: 0.779
[18,2488], f1_score: 0.609, acc: 0.776
[18,2744], f1_score: 0.609, acc: 0.780
[19,52], f1_score: 0.581, acc: 0.780
[19,308], f1_score: 0.627, acc: 0.770
[19,564], f1_score: 0.550, acc: 0.775
[19,820], f1_score: 0.625, acc: 0.769
[19,1076], f1_score: 0.587, acc: 0.778
[19,1332], f1_score: 0.597, acc: 0.778
[19,1588], f1_score: 0.588, acc: 0.777
[19,1844], f1_score: 0.617, acc: 0.781
[19,2100], f1_score: 0.616, acc: 0.781
[19,2356], f1_score: 0.581, acc: 0.778
[19,2612], f1_score: 0.625, acc: 0.775
[19,2868], f1_score: 0.584, acc: 0.780
[20,176], f1_score: 0.608, acc: 0.782
[20,432], f1_score: 0.622, acc: 0.780
[20,688], f1_score: 0.606, acc: 0.779
[20,944], f1_score: 0.612, acc: 0.777
[20,1200], f1_score: 0.610, acc: 0.780
[20,1456], f1_score: 0.567, acc: 0.778
[20,1712], f1_score: 0.628, acc: 0.779
[20,1968], f1_score: 0.588, acc: 0.779
[20,2224], f1_score: 0.626, acc: 0.777
[20,2480], f1_score: 0.581, acc: 0.780
[20,2736], f1_score: 0.628, acc: 0

[37,748], f1_score: 0.613, acc: 0.773
[37,1004], f1_score: 0.602, acc: 0.779
[37,1260], f1_score: 0.606, acc: 0.778
[37,1516], f1_score: 0.615, acc: 0.775
[37,1772], f1_score: 0.577, acc: 0.776
[37,2028], f1_score: 0.622, acc: 0.772
[37,2284], f1_score: 0.603, acc: 0.779
[37,2540], f1_score: 0.615, acc: 0.777
[37,2796], f1_score: 0.610, acc: 0.775
[38,104], f1_score: 0.599, acc: 0.778
[38,360], f1_score: 0.612, acc: 0.775
[38,616], f1_score: 0.610, acc: 0.779
[38,872], f1_score: 0.612, acc: 0.778
[38,1128], f1_score: 0.601, acc: 0.779
[38,1384], f1_score: 0.611, acc: 0.779
[38,1640], f1_score: 0.607, acc: 0.776
[38,1896], f1_score: 0.585, acc: 0.774
[38,2152], f1_score: 0.622, acc: 0.773
[38,2408], f1_score: 0.601, acc: 0.779
[38,2664], f1_score: 0.606, acc: 0.778
[38,2920], f1_score: 0.617, acc: 0.772
[39,228], f1_score: 0.585, acc: 0.775
[39,484], f1_score: 0.617, acc: 0.776
[39,740], f1_score: 0.612, acc: 0.773
[39,996], f1_score: 0.606, acc: 0.779
[39,1252], f1_score: 0.610, acc: 0

[14,200], f1_score: 0.627, acc: 0.775
[14,456], f1_score: 0.570, acc: 0.780
[14,712], f1_score: 0.588, acc: 0.781
[14,968], f1_score: 0.621, acc: 0.779
[14,1224], f1_score: 0.614, acc: 0.782
[14,1480], f1_score: 0.602, acc: 0.781
[14,1736], f1_score: 0.600, acc: 0.783
[14,1992], f1_score: 0.627, acc: 0.755
[14,2248], f1_score: 0.572, acc: 0.779
[14,2504], f1_score: 0.588, acc: 0.774
[14,2760], f1_score: 0.616, acc: 0.779
[15,68], f1_score: 0.603, acc: 0.778
[15,324], f1_score: 0.612, acc: 0.780
[15,580], f1_score: 0.630, acc: 0.773
[15,836], f1_score: 0.573, acc: 0.782
[15,1092], f1_score: 0.603, acc: 0.780
[15,1348], f1_score: 0.586, acc: 0.781
[15,1604], f1_score: 0.627, acc: 0.775
[15,1860], f1_score: 0.557, acc: 0.778
[15,2116], f1_score: 0.612, acc: 0.776
[15,2372], f1_score: 0.621, acc: 0.769
[15,2628], f1_score: 0.589, acc: 0.779
[15,2884], f1_score: 0.576, acc: 0.776
[16,192], f1_score: 0.621, acc: 0.772
[16,448], f1_score: 0.567, acc: 0.779
[16,704], f1_score: 0.629, acc: 0.76

[32,1664], f1_score: 0.560, acc: 0.776
[32,1920], f1_score: 0.628, acc: 0.771
[32,2176], f1_score: 0.602, acc: 0.777
[32,2432], f1_score: 0.611, acc: 0.780
[32,2688], f1_score: 0.616, acc: 0.774
[32,2944], f1_score: 0.607, acc: 0.778
[33,252], f1_score: 0.613, acc: 0.778
[33,508], f1_score: 0.605, acc: 0.779
[33,764], f1_score: 0.623, acc: 0.772
[33,1020], f1_score: 0.578, acc: 0.776
[33,1276], f1_score: 0.625, acc: 0.770
[33,1532], f1_score: 0.593, acc: 0.774
[33,1788], f1_score: 0.593, acc: 0.781
[33,2044], f1_score: 0.626, acc: 0.775
[33,2300], f1_score: 0.601, acc: 0.778
[33,2556], f1_score: 0.614, acc: 0.781
[33,2812], f1_score: 0.613, acc: 0.774
[34,120], f1_score: 0.597, acc: 0.777
[34,376], f1_score: 0.620, acc: 0.777
[34,632], f1_score: 0.607, acc: 0.779
[34,888], f1_score: 0.616, acc: 0.778
[34,1144], f1_score: 0.583, acc: 0.777
[34,1400], f1_score: 0.629, acc: 0.770
[34,1656], f1_score: 0.562, acc: 0.772
[34,1912], f1_score: 0.621, acc: 0.774
[34,2168], f1_score: 0.614, acc:

[9,1116], f1_score: 0.556, acc: 0.775
[9,1372], f1_score: 0.548, acc: 0.778
[9,1628], f1_score: 0.579, acc: 0.782
[9,1884], f1_score: 0.607, acc: 0.781
[9,2140], f1_score: 0.553, acc: 0.778
[9,2396], f1_score: 0.605, acc: 0.781
[9,2652], f1_score: 0.604, acc: 0.781
[9,2908], f1_score: 0.566, acc: 0.780
[10,216], f1_score: 0.588, acc: 0.782
[10,472], f1_score: 0.594, acc: 0.777
[10,728], f1_score: 0.608, acc: 0.781
[10,984], f1_score: 0.617, acc: 0.776
[10,1240], f1_score: 0.598, acc: 0.781
[10,1496], f1_score: 0.599, acc: 0.782
[10,1752], f1_score: 0.560, acc: 0.779
[10,2008], f1_score: 0.609, acc: 0.782
[10,2264], f1_score: 0.602, acc: 0.782
[10,2520], f1_score: 0.594, acc: 0.783
[10,2776], f1_score: 0.572, acc: 0.782
[11,84], f1_score: 0.619, acc: 0.777
[11,340], f1_score: 0.592, acc: 0.783
[11,596], f1_score: 0.607, acc: 0.781
[11,852], f1_score: 0.594, acc: 0.782
[11,1108], f1_score: 0.608, acc: 0.780
[11,1364], f1_score: 0.592, acc: 0.784
[11,1620], f1_score: 0.626, acc: 0.774
[11

[27,2580], f1_score: 0.618, acc: 0.777
[27,2836], f1_score: 0.609, acc: 0.777
[28,144], f1_score: 0.589, acc: 0.778
[28,400], f1_score: 0.619, acc: 0.774
[28,656], f1_score: 0.598, acc: 0.774
[28,912], f1_score: 0.592, acc: 0.771
[28,1168], f1_score: 0.624, acc: 0.772
[28,1424], f1_score: 0.565, acc: 0.776
[28,1680], f1_score: 0.612, acc: 0.776
[28,1936], f1_score: 0.609, acc: 0.776
[28,2192], f1_score: 0.600, acc: 0.778
[28,2448], f1_score: 0.617, acc: 0.778
[28,2704], f1_score: 0.608, acc: 0.779
[29,12], f1_score: 0.617, acc: 0.774
[29,268], f1_score: 0.587, acc: 0.779
[29,524], f1_score: 0.621, acc: 0.772
[29,780], f1_score: 0.573, acc: 0.769
[29,1036], f1_score: 0.610, acc: 0.767
[29,1292], f1_score: 0.604, acc: 0.773
[29,1548], f1_score: 0.585, acc: 0.779
[29,1804], f1_score: 0.602, acc: 0.779
[29,2060], f1_score: 0.618, acc: 0.763
[29,2316], f1_score: 0.583, acc: 0.778
[29,2572], f1_score: 0.622, acc: 0.776
[29,2828], f1_score: 0.596, acc: 0.778
[30,136], f1_score: 0.614, acc: 0.

[4,1520], f1_score: 0.531, acc: 0.769
[4,1776], f1_score: 0.530, acc: 0.768
[4,2032], f1_score: 0.544, acc: 0.772
[4,2288], f1_score: 0.549, acc: 0.773
[4,2544], f1_score: 0.571, acc: 0.774
[4,2800], f1_score: 0.543, acc: 0.773
[5,108], f1_score: 0.580, acc: 0.772
[5,364], f1_score: 0.553, acc: 0.773
[5,620], f1_score: 0.559, acc: 0.775
[5,876], f1_score: 0.552, acc: 0.774
[5,1132], f1_score: 0.561, acc: 0.772
[5,1388], f1_score: 0.573, acc: 0.776
[5,1644], f1_score: 0.576, acc: 0.775
[5,1900], f1_score: 0.604, acc: 0.768
[5,2156], f1_score: 0.588, acc: 0.777
[5,2412], f1_score: 0.600, acc: 0.775
[5,2668], f1_score: 0.592, acc: 0.777
[5,2924], f1_score: 0.578, acc: 0.777
[6,232], f1_score: 0.574, acc: 0.779
[6,488], f1_score: 0.575, acc: 0.777
[6,744], f1_score: 0.581, acc: 0.779
[6,1000], f1_score: 0.594, acc: 0.777
[6,1256], f1_score: 0.572, acc: 0.777
[6,1512], f1_score: 0.587, acc: 0.779
[6,1768], f1_score: 0.573, acc: 0.778
[6,2024], f1_score: 0.564, acc: 0.777
[6,2280], f1_score:

[23,292], f1_score: 0.624, acc: 0.777
[23,548], f1_score: 0.598, acc: 0.777
[23,804], f1_score: 0.606, acc: 0.776
[23,1060], f1_score: 0.620, acc: 0.775
[23,1316], f1_score: 0.598, acc: 0.781
[23,1572], f1_score: 0.589, acc: 0.780
[23,1828], f1_score: 0.624, acc: 0.776
[23,2084], f1_score: 0.592, acc: 0.781
[23,2340], f1_score: 0.626, acc: 0.779
[23,2596], f1_score: 0.581, acc: 0.779
[23,2852], f1_score: 0.627, acc: 0.776
[24,160], f1_score: 0.566, acc: 0.778
[24,416], f1_score: 0.626, acc: 0.773
[24,672], f1_score: 0.548, acc: 0.771
[24,928], f1_score: 0.622, acc: 0.763
[24,1184], f1_score: 0.597, acc: 0.778
[24,1440], f1_score: 0.606, acc: 0.782
[24,1696], f1_score: 0.580, acc: 0.778
[24,1952], f1_score: 0.622, acc: 0.770
[24,2208], f1_score: 0.591, acc: 0.778
[24,2464], f1_score: 0.617, acc: 0.778
[24,2720], f1_score: 0.611, acc: 0.779
[25,28], f1_score: 0.595, acc: 0.775
[25,284], f1_score: 0.612, acc: 0.779
[25,540], f1_score: 0.617, acc: 0.777
[25,796], f1_score: 0.572, acc: 0.77

[41,1756], f1_score: 0.613, acc: 0.774
[41,2012], f1_score: 0.579, acc: 0.777
Evaluating...
prec=0.691, recall=0.499, F1=0.579, acc=0.777
[0,0], f1_score: 0.382, acc: 0.522
[0,256], f1_score: 0.383, acc: 0.670
[0,512], f1_score: 0.404, acc: 0.662
[0,768], f1_score: 0.406, acc: 0.677
[0,1024], f1_score: 0.426, acc: 0.682
[0,1280], f1_score: 0.400, acc: 0.695
[0,1536], f1_score: 0.426, acc: 0.699
[0,1792], f1_score: 0.393, acc: 0.709
[0,2048], f1_score: 0.451, acc: 0.706
[0,2304], f1_score: 0.399, acc: 0.717
[0,2560], f1_score: 0.454, acc: 0.714
[0,2816], f1_score: 0.420, acc: 0.722
[1,124], f1_score: 0.449, acc: 0.723
[1,380], f1_score: 0.451, acc: 0.724
[1,636], f1_score: 0.441, acc: 0.729
[1,892], f1_score: 0.466, acc: 0.727
[1,1148], f1_score: 0.448, acc: 0.732
[1,1404], f1_score: 0.451, acc: 0.735
[1,1660], f1_score: 0.457, acc: 0.736
[1,1916], f1_score: 0.486, acc: 0.735
[1,2172], f1_score: 0.423, acc: 0.738
[1,2428], f1_score: 0.487, acc: 0.739
[1,2684], f1_score: 0.487, acc: 0.74

[18,1208], f1_score: 0.622, acc: 0.779
[18,1464], f1_score: 0.555, acc: 0.779
[18,1720], f1_score: 0.625, acc: 0.778
[18,1976], f1_score: 0.583, acc: 0.781
[18,2232], f1_score: 0.619, acc: 0.780
[18,2488], f1_score: 0.611, acc: 0.779
[18,2744], f1_score: 0.615, acc: 0.780
[19,52], f1_score: 0.579, acc: 0.779
[19,308], f1_score: 0.625, acc: 0.775
[19,564], f1_score: 0.551, acc: 0.775
[19,820], f1_score: 0.627, acc: 0.768
[19,1076], f1_score: 0.584, acc: 0.781
[19,1332], f1_score: 0.603, acc: 0.781
[19,1588], f1_score: 0.598, acc: 0.777
[19,1844], f1_score: 0.595, acc: 0.778
[19,2100], f1_score: 0.621, acc: 0.778
[19,2356], f1_score: 0.580, acc: 0.778
[19,2612], f1_score: 0.624, acc: 0.776
[19,2868], f1_score: 0.585, acc: 0.781
[20,176], f1_score: 0.613, acc: 0.779
[20,432], f1_score: 0.610, acc: 0.781
[20,688], f1_score: 0.604, acc: 0.781
[20,944], f1_score: 0.615, acc: 0.778
[20,1200], f1_score: 0.614, acc: 0.781
[20,1456], f1_score: 0.557, acc: 0.777
[20,1712], f1_score: 0.630, acc: 0

[36,2672], f1_score: 0.609, acc: 0.777
[36,2928], f1_score: 0.602, acc: 0.776
[37,236], f1_score: 0.606, acc: 0.770
[37,492], f1_score: 0.604, acc: 0.778
[37,748], f1_score: 0.611, acc: 0.773
[37,1004], f1_score: 0.608, acc: 0.779
[37,1260], f1_score: 0.596, acc: 0.777
[37,1516], f1_score: 0.620, acc: 0.775
[37,1772], f1_score: 0.589, acc: 0.777
[37,2028], f1_score: 0.612, acc: 0.777
[37,2284], f1_score: 0.615, acc: 0.775
[37,2540], f1_score: 0.602, acc: 0.777
[37,2796], f1_score: 0.615, acc: 0.772
[38,104], f1_score: 0.584, acc: 0.776
[38,360], f1_score: 0.618, acc: 0.770
[38,616], f1_score: 0.595, acc: 0.774
[38,872], f1_score: 0.612, acc: 0.773
[38,1128], f1_score: 0.602, acc: 0.779
[38,1384], f1_score: 0.608, acc: 0.777
[38,1640], f1_score: 0.615, acc: 0.777
[38,1896], f1_score: 0.586, acc: 0.776
[38,2152], f1_score: 0.621, acc: 0.770
[38,2408], f1_score: 0.603, acc: 0.779
[38,2664], f1_score: 0.609, acc: 0.778
[38,2920], f1_score: 0.617, acc: 0.772
[39,228], f1_score: 0.576, acc: 

In [21]:
# Mean of each metric over all repetitions of the 1-hidden-layer model.
print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"
      % tuple(map(sta.mean, (precision_h1, recall_h1, f1score_h1, accuracy_h1))))

prec=0.682, recall=0.513, F1=0.585, acc=0.776


### MLP with 2 hidden layers

In [22]:
# Train the 2-hidden-layer MLP five times and keep each run's final test metrics.
precision_h2, recall_h2, f1score_h2, accuracy_h2 = [], [], [], []

for i in range(5):
    # Seed every repetition with its run index. Runs still differ from each
    # other, but each one can be reproduced after a kernel restart.
    # `np` comes from `from utils import *`, which the MLP class also relies on.
    # NOTE(review): if get_minibatch shuffles with Python's `random`, seed it too.
    torch.manual_seed(i)
    np.random.seed(i)

    # Init model with two hidden layers
    model = MLP(hidden_layer = 2).to(device)

    # Train MLP model. It evaluates on test_set every 256 samples and stops early
    # once the test F1 reaches 0.7.
    # NOTE(review): test_set drives early stopping AND the final report below.
    # If early stopping triggers, the reported metrics are optimistic; a separate
    # validation split would avoid that.
    model.train(train_set, test_set,
          batch_size=16,
          test_batch_size=512,
          max_iter=3072*40,
          max_fscore=0.7,
          test_inc_size=256)

    print("Evaluating...")
    labels, preds = model.test(test_set, test_batch_size=512)
    precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
    print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%(precision, recall, f1score, accuracy))

    # Release memory before the next repetition
    del model
    torch.cuda.empty_cache()

    precision_h2.append(precision)
    recall_h2.append(recall)
    f1score_h2.append(f1score)
    accuracy_h2.append(accuracy)

[0,0], f1_score: 0.388, acc: 0.538
[0,256], f1_score: 0.374, acc: 0.694
[0,512], f1_score: 0.381, acc: 0.703
[0,768], f1_score: 0.400, acc: 0.707
[0,1024], f1_score: 0.410, acc: 0.709
[0,1280], f1_score: 0.393, acc: 0.714
[0,1536], f1_score: 0.402, acc: 0.717
[0,1792], f1_score: 0.367, acc: 0.722
[0,2048], f1_score: 0.410, acc: 0.723
[0,2304], f1_score: 0.418, acc: 0.723
[0,2560], f1_score: 0.421, acc: 0.726
[0,2816], f1_score: 0.369, acc: 0.728
[1,124], f1_score: 0.447, acc: 0.727
[1,380], f1_score: 0.410, acc: 0.731
[1,636], f1_score: 0.449, acc: 0.733
[1,892], f1_score: 0.464, acc: 0.733
[1,1148], f1_score: 0.460, acc: 0.738
[1,1404], f1_score: 0.468, acc: 0.738
[1,1660], f1_score: 0.429, acc: 0.741
[1,1916], f1_score: 0.496, acc: 0.740
[1,2172], f1_score: 0.445, acc: 0.744
[1,2428], f1_score: 0.473, acc: 0.747
[1,2684], f1_score: 0.485, acc: 0.749
[1,2940], f1_score: 0.464, acc: 0.751
[2,248], f1_score: 0.444, acc: 0.747
[2,504], f1_score: 0.484, acc: 0.752
[2,760], f1_score: 0.491

[18,2232], f1_score: 0.619, acc: 0.781
[18,2488], f1_score: 0.610, acc: 0.782
[18,2744], f1_score: 0.614, acc: 0.779
[19,52], f1_score: 0.600, acc: 0.778
[19,308], f1_score: 0.620, acc: 0.775
[19,564], f1_score: 0.584, acc: 0.775
[19,820], f1_score: 0.625, acc: 0.769
[19,1076], f1_score: 0.591, acc: 0.779
[19,1332], f1_score: 0.586, acc: 0.780
[19,1588], f1_score: 0.623, acc: 0.778
[19,1844], f1_score: 0.612, acc: 0.779
[19,2100], f1_score: 0.614, acc: 0.779
[19,2356], f1_score: 0.609, acc: 0.781
[19,2612], f1_score: 0.631, acc: 0.772
[19,2868], f1_score: 0.561, acc: 0.774
[20,176], f1_score: 0.628, acc: 0.775
[20,432], f1_score: 0.595, acc: 0.778
[20,688], f1_score: 0.606, acc: 0.777
[20,944], f1_score: 0.613, acc: 0.776
[20,1200], f1_score: 0.620, acc: 0.774
[20,1456], f1_score: 0.568, acc: 0.779
[20,1712], f1_score: 0.613, acc: 0.779
[20,1968], f1_score: 0.613, acc: 0.770
[20,2224], f1_score: 0.610, acc: 0.779
[20,2480], f1_score: 0.591, acc: 0.777
[20,2736], f1_score: 0.626, acc: 0

[37,748], f1_score: 0.615, acc: 0.761
[37,1004], f1_score: 0.594, acc: 0.772
[37,1260], f1_score: 0.611, acc: 0.765
[37,1516], f1_score: 0.592, acc: 0.769
[37,1772], f1_score: 0.573, acc: 0.768
[37,2028], f1_score: 0.612, acc: 0.764
[37,2284], f1_score: 0.600, acc: 0.764
[37,2540], f1_score: 0.614, acc: 0.772
[37,2796], f1_score: 0.599, acc: 0.769
[38,104], f1_score: 0.590, acc: 0.771
[38,360], f1_score: 0.606, acc: 0.765
[38,616], f1_score: 0.593, acc: 0.770
[38,872], f1_score: 0.610, acc: 0.764
[38,1128], f1_score: 0.586, acc: 0.769
[38,1384], f1_score: 0.619, acc: 0.762
[38,1640], f1_score: 0.571, acc: 0.768
[38,1896], f1_score: 0.593, acc: 0.757
[38,2152], f1_score: 0.585, acc: 0.764
[38,2408], f1_score: 0.610, acc: 0.760
[38,2664], f1_score: 0.595, acc: 0.772
[38,2920], f1_score: 0.612, acc: 0.766
[39,228], f1_score: 0.587, acc: 0.770
[39,484], f1_score: 0.609, acc: 0.768
[39,740], f1_score: 0.601, acc: 0.767
[39,996], f1_score: 0.605, acc: 0.768
[39,1252], f1_score: 0.589, acc: 0

[14,200], f1_score: 0.633, acc: 0.770
[14,456], f1_score: 0.573, acc: 0.782
[14,712], f1_score: 0.617, acc: 0.777
[14,968], f1_score: 0.624, acc: 0.769
[14,1224], f1_score: 0.575, acc: 0.777
[14,1480], f1_score: 0.584, acc: 0.774
[14,1736], f1_score: 0.602, acc: 0.781
[14,1992], f1_score: 0.627, acc: 0.768
[14,2248], f1_score: 0.564, acc: 0.778
[14,2504], f1_score: 0.630, acc: 0.773
[14,2760], f1_score: 0.604, acc: 0.784
[15,68], f1_score: 0.575, acc: 0.779
[15,324], f1_score: 0.621, acc: 0.776
[15,580], f1_score: 0.623, acc: 0.782
[15,836], f1_score: 0.602, acc: 0.784
[15,1092], f1_score: 0.617, acc: 0.777
[15,1348], f1_score: 0.586, acc: 0.779
[15,1604], f1_score: 0.629, acc: 0.780
[15,1860], f1_score: 0.564, acc: 0.775
[15,2116], f1_score: 0.628, acc: 0.775
[15,2372], f1_score: 0.603, acc: 0.783
[15,2628], f1_score: 0.625, acc: 0.781
[15,2884], f1_score: 0.615, acc: 0.782
[16,192], f1_score: 0.621, acc: 0.779
[16,448], f1_score: 0.577, acc: 0.777
[16,704], f1_score: 0.625, acc: 0.76

[32,1664], f1_score: 0.559, acc: 0.770
[32,1920], f1_score: 0.625, acc: 0.765
[32,2176], f1_score: 0.596, acc: 0.772
[32,2432], f1_score: 0.614, acc: 0.773
[32,2688], f1_score: 0.610, acc: 0.768
[32,2944], f1_score: 0.592, acc: 0.775
[33,252], f1_score: 0.592, acc: 0.767
[33,508], f1_score: 0.617, acc: 0.765
[33,764], f1_score: 0.582, acc: 0.768
[33,1020], f1_score: 0.601, acc: 0.770
[33,1276], f1_score: 0.617, acc: 0.767
[33,1532], f1_score: 0.584, acc: 0.771
[33,1788], f1_score: 0.577, acc: 0.769
[33,2044], f1_score: 0.604, acc: 0.767
[33,2300], f1_score: 0.609, acc: 0.771
[33,2556], f1_score: 0.605, acc: 0.774
[33,2812], f1_score: 0.608, acc: 0.772
[34,120], f1_score: 0.616, acc: 0.770
[34,376], f1_score: 0.578, acc: 0.771
[34,632], f1_score: 0.616, acc: 0.766
[34,888], f1_score: 0.593, acc: 0.768
[34,1144], f1_score: 0.596, acc: 0.767
[34,1400], f1_score: 0.599, acc: 0.770
[34,1656], f1_score: 0.597, acc: 0.770
[34,1912], f1_score: 0.618, acc: 0.769
[34,2168], f1_score: 0.584, acc:

[9,1116], f1_score: 0.634, acc: 0.776
[9,1372], f1_score: 0.606, acc: 0.782
[9,1628], f1_score: 0.602, acc: 0.783
[9,1884], f1_score: 0.618, acc: 0.782
[9,2140], f1_score: 0.586, acc: 0.784
[9,2396], f1_score: 0.609, acc: 0.782
[9,2652], f1_score: 0.601, acc: 0.782
[9,2908], f1_score: 0.557, acc: 0.779
[10,216], f1_score: 0.583, acc: 0.782
[10,472], f1_score: 0.595, acc: 0.780
[10,728], f1_score: 0.577, acc: 0.781
[10,984], f1_score: 0.582, acc: 0.782
[10,1240], f1_score: 0.616, acc: 0.783
[10,1496], f1_score: 0.622, acc: 0.776
[10,1752], f1_score: 0.570, acc: 0.779
[10,2008], f1_score: 0.612, acc: 0.786
[10,2264], f1_score: 0.594, acc: 0.785
[10,2520], f1_score: 0.608, acc: 0.779
[10,2776], f1_score: 0.594, acc: 0.783
[11,84], f1_score: 0.631, acc: 0.761
[11,340], f1_score: 0.578, acc: 0.783
[11,596], f1_score: 0.608, acc: 0.781
[11,852], f1_score: 0.627, acc: 0.776
[11,1108], f1_score: 0.587, acc: 0.783
[11,1364], f1_score: 0.538, acc: 0.774
[11,1620], f1_score: 0.604, acc: 0.781
[11

[27,2580], f1_score: 0.600, acc: 0.769
[27,2836], f1_score: 0.618, acc: 0.773
[28,144], f1_score: 0.552, acc: 0.765
[28,400], f1_score: 0.613, acc: 0.759
[28,656], f1_score: 0.580, acc: 0.771
[28,912], f1_score: 0.606, acc: 0.774
[28,1168], f1_score: 0.617, acc: 0.773
[28,1424], f1_score: 0.598, acc: 0.772
[28,1680], f1_score: 0.594, acc: 0.773
[28,1936], f1_score: 0.623, acc: 0.768
[28,2192], f1_score: 0.596, acc: 0.776
[28,2448], f1_score: 0.620, acc: 0.778
[28,2704], f1_score: 0.610, acc: 0.774
[29,12], f1_score: 0.612, acc: 0.773
[29,268], f1_score: 0.597, acc: 0.773
[29,524], f1_score: 0.613, acc: 0.770
[29,780], f1_score: 0.577, acc: 0.772
[29,1036], f1_score: 0.621, acc: 0.758
[29,1292], f1_score: 0.591, acc: 0.774
[29,1548], f1_score: 0.610, acc: 0.777
[29,1804], f1_score: 0.601, acc: 0.774
[29,2060], f1_score: 0.612, acc: 0.767
[29,2316], f1_score: 0.605, acc: 0.770
[29,2572], f1_score: 0.602, acc: 0.768
[29,2828], f1_score: 0.610, acc: 0.771
[30,136], f1_score: 0.588, acc: 0.

[4,1520], f1_score: 0.552, acc: 0.774
[4,1776], f1_score: 0.584, acc: 0.777
[4,2032], f1_score: 0.602, acc: 0.778
[4,2288], f1_score: 0.583, acc: 0.780
[4,2544], f1_score: 0.610, acc: 0.769
[4,2800], f1_score: 0.586, acc: 0.778
[5,108], f1_score: 0.597, acc: 0.775
[5,364], f1_score: 0.547, acc: 0.774
[5,620], f1_score: 0.559, acc: 0.777
[5,876], f1_score: 0.584, acc: 0.781
[5,1132], f1_score: 0.581, acc: 0.778
[5,1388], f1_score: 0.602, acc: 0.774
[5,1644], f1_score: 0.583, acc: 0.780
[5,1900], f1_score: 0.613, acc: 0.772
[5,2156], f1_score: 0.601, acc: 0.782
[5,2412], f1_score: 0.618, acc: 0.779
[5,2668], f1_score: 0.620, acc: 0.775
[5,2924], f1_score: 0.599, acc: 0.783
[6,232], f1_score: 0.584, acc: 0.781
[6,488], f1_score: 0.612, acc: 0.780
[6,744], f1_score: 0.620, acc: 0.781
[6,1000], f1_score: 0.624, acc: 0.780
[6,1256], f1_score: 0.617, acc: 0.781
[6,1512], f1_score: 0.609, acc: 0.779
[6,1768], f1_score: 0.583, acc: 0.780
[6,2024], f1_score: 0.567, acc: 0.780
[6,2280], f1_score:

[23,292], f1_score: 0.629, acc: 0.773
[23,548], f1_score: 0.579, acc: 0.775
[23,804], f1_score: 0.624, acc: 0.765
[23,1060], f1_score: 0.585, acc: 0.773
[23,1316], f1_score: 0.621, acc: 0.776
[23,1572], f1_score: 0.598, acc: 0.779
[23,1828], f1_score: 0.603, acc: 0.776
[23,2084], f1_score: 0.624, acc: 0.768
[23,2340], f1_score: 0.592, acc: 0.777
[23,2596], f1_score: 0.630, acc: 0.775
[23,2852], f1_score: 0.587, acc: 0.773
[24,160], f1_score: 0.605, acc: 0.777
[24,416], f1_score: 0.616, acc: 0.779
[24,672], f1_score: 0.594, acc: 0.775
[24,928], f1_score: 0.620, acc: 0.771
[24,1184], f1_score: 0.601, acc: 0.776
[24,1440], f1_score: 0.579, acc: 0.773
[24,1696], f1_score: 0.628, acc: 0.764
[24,1952], f1_score: 0.570, acc: 0.769
[24,2208], f1_score: 0.623, acc: 0.765
[24,2464], f1_score: 0.571, acc: 0.771
[24,2720], f1_score: 0.622, acc: 0.763
[25,28], f1_score: 0.568, acc: 0.766
[25,284], f1_score: 0.625, acc: 0.770
[25,540], f1_score: 0.563, acc: 0.773
[25,796], f1_score: 0.628, acc: 0.77

[41,1756], f1_score: 0.577, acc: 0.768
[41,2012], f1_score: 0.602, acc: 0.762
Evaluating...
prec=0.620, recall=0.586, F1=0.602, acc=0.762
[0,0], f1_score: 0.388, acc: 0.535
[0,256], f1_score: 0.366, acc: 0.694
[0,512], f1_score: 0.379, acc: 0.702
[0,768], f1_score: 0.396, acc: 0.707
[0,1024], f1_score: 0.410, acc: 0.709
[0,1280], f1_score: 0.390, acc: 0.713
[0,1536], f1_score: 0.409, acc: 0.716
[0,1792], f1_score: 0.372, acc: 0.721
[0,2048], f1_score: 0.413, acc: 0.722
[0,2304], f1_score: 0.419, acc: 0.723
[0,2560], f1_score: 0.432, acc: 0.726
[0,2816], f1_score: 0.386, acc: 0.728
[1,124], f1_score: 0.448, acc: 0.727
[1,380], f1_score: 0.423, acc: 0.732
[1,636], f1_score: 0.455, acc: 0.733
[1,892], f1_score: 0.460, acc: 0.734
[1,1148], f1_score: 0.459, acc: 0.738
[1,1404], f1_score: 0.470, acc: 0.738
[1,1660], f1_score: 0.446, acc: 0.742
[1,1916], f1_score: 0.500, acc: 0.740
[1,2172], f1_score: 0.440, acc: 0.744
[1,2428], f1_score: 0.476, acc: 0.747
[1,2684], f1_score: 0.489, acc: 0.74

[18,1208], f1_score: 0.619, acc: 0.777
[18,1464], f1_score: 0.610, acc: 0.775
[18,1720], f1_score: 0.599, acc: 0.783
[18,1976], f1_score: 0.615, acc: 0.776
[18,2232], f1_score: 0.616, acc: 0.780
[18,2488], f1_score: 0.618, acc: 0.781
[18,2744], f1_score: 0.617, acc: 0.781
[19,52], f1_score: 0.616, acc: 0.777
[19,308], f1_score: 0.619, acc: 0.774
[19,564], f1_score: 0.607, acc: 0.779
[19,820], f1_score: 0.625, acc: 0.781
[19,1076], f1_score: 0.611, acc: 0.779
[19,1332], f1_score: 0.564, acc: 0.774
[19,1588], f1_score: 0.626, acc: 0.779
[19,1844], f1_score: 0.613, acc: 0.782
[19,2100], f1_score: 0.619, acc: 0.779
[19,2356], f1_score: 0.602, acc: 0.780
[19,2612], f1_score: 0.625, acc: 0.770
[19,2868], f1_score: 0.549, acc: 0.773
[20,176], f1_score: 0.624, acc: 0.776
[20,432], f1_score: 0.590, acc: 0.779
[20,688], f1_score: 0.616, acc: 0.775
[20,944], f1_score: 0.603, acc: 0.779
[20,1200], f1_score: 0.627, acc: 0.777
[20,1456], f1_score: 0.578, acc: 0.775
[20,1712], f1_score: 0.613, acc: 0

[36,2672], f1_score: 0.613, acc: 0.770
[36,2928], f1_score: 0.603, acc: 0.768
[37,236], f1_score: 0.607, acc: 0.766
[37,492], f1_score: 0.593, acc: 0.771
[37,748], f1_score: 0.616, acc: 0.767
[37,1004], f1_score: 0.596, acc: 0.774
[37,1260], f1_score: 0.609, acc: 0.764
[37,1516], f1_score: 0.593, acc: 0.765
[37,1772], f1_score: 0.570, acc: 0.768
[37,2028], f1_score: 0.612, acc: 0.770
[37,2284], f1_score: 0.611, acc: 0.755
[37,2540], f1_score: 0.573, acc: 0.771
[37,2796], f1_score: 0.611, acc: 0.766
[38,104], f1_score: 0.589, acc: 0.772
[38,360], f1_score: 0.621, acc: 0.763
[38,616], f1_score: 0.595, acc: 0.772
[38,872], f1_score: 0.609, acc: 0.771
[38,1128], f1_score: 0.604, acc: 0.767
[38,1384], f1_score: 0.609, acc: 0.770
[38,1640], f1_score: 0.568, acc: 0.766
[38,1896], f1_score: 0.603, acc: 0.759
[38,2152], f1_score: 0.585, acc: 0.768
[38,2408], f1_score: 0.616, acc: 0.754
[38,2664], f1_score: 0.582, acc: 0.771
[38,2920], f1_score: 0.601, acc: 0.766
[39,228], f1_score: 0.588, acc: 

In [23]:
# Mean of each metric over all repetitions of the 2-hidden-layer model.
print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"
      % tuple(map(sta.mean, (precision_h2, recall_h2, f1score_h2, accuracy_h2))))

prec=0.612, recall=0.606, F1=0.608, acc=0.760


### MLP with 3 hidden layers

In [29]:
# Train the 3-hidden-layer MLP five times and keep each run's final test metrics.
precision_h3, recall_h3, f1score_h3, accuracy_h3 = [], [], [], []

for i in range(5):
    # Seed every repetition with its run index. Runs still differ from each
    # other, but each one can be reproduced after a kernel restart.
    # `np` comes from `from utils import *`, which the MLP class also relies on.
    # NOTE(review): if get_minibatch shuffles with Python's `random`, seed it too.
    torch.manual_seed(i)
    np.random.seed(i)

    # Init model with three hidden layers
    model = MLP(hidden_layer = 3).to(device)

    # Train MLP model. It evaluates on test_set every 256 samples and stops early
    # once the test F1 reaches 0.7.
    # NOTE(review): test_set drives early stopping AND the final report below.
    # If early stopping triggers, the reported metrics are optimistic; a separate
    # validation split would avoid that.
    model.train(train_set, test_set,
          batch_size=16,
          test_batch_size=512,
          max_iter=3072*40,
          max_fscore=0.7,
          test_inc_size=256)

    print("Evaluating...")
    labels, preds = model.test(test_set, test_batch_size=512)
    precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
    print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%(precision, recall, f1score, accuracy))

    # Release memory before the next repetition
    del model
    torch.cuda.empty_cache()

    precision_h3.append(precision)
    recall_h3.append(recall)
    f1score_h3.append(f1score)
    accuracy_h3.append(accuracy)

[0,0], f1_score: 0.389, acc: 0.551
[0,256], f1_score: 0.369, acc: 0.703
[0,512], f1_score: 0.377, acc: 0.706
[0,768], f1_score: 0.401, acc: 0.710
[0,1024], f1_score: 0.407, acc: 0.713
[0,1280], f1_score: 0.385, acc: 0.715
[0,1536], f1_score: 0.395, acc: 0.719
[0,1792], f1_score: 0.376, acc: 0.723
[0,2048], f1_score: 0.395, acc: 0.724
[0,2304], f1_score: 0.395, acc: 0.726
[0,2560], f1_score: 0.411, acc: 0.728
[0,2816], f1_score: 0.400, acc: 0.732
[1,124], f1_score: 0.453, acc: 0.729
[1,380], f1_score: 0.442, acc: 0.735
[1,636], f1_score: 0.453, acc: 0.735
[1,892], f1_score: 0.458, acc: 0.737
[1,1148], f1_score: 0.477, acc: 0.739
[1,1404], f1_score: 0.464, acc: 0.739
[1,1660], f1_score: 0.441, acc: 0.744
[1,1916], f1_score: 0.486, acc: 0.743
[1,2172], f1_score: 0.451, acc: 0.746
[1,2428], f1_score: 0.497, acc: 0.750
[1,2684], f1_score: 0.518, acc: 0.745
[1,2940], f1_score: 0.450, acc: 0.750
[2,248], f1_score: 0.461, acc: 0.749
[2,504], f1_score: 0.476, acc: 0.753
[2,760], f1_score: 0.508

[18,2232], f1_score: 0.612, acc: 0.780
[18,2488], f1_score: 0.622, acc: 0.776
[18,2744], f1_score: 0.613, acc: 0.780
[19,52], f1_score: 0.608, acc: 0.780
[19,308], f1_score: 0.622, acc: 0.782
[19,564], f1_score: 0.580, acc: 0.775
[19,820], f1_score: 0.615, acc: 0.778
[19,1076], f1_score: 0.624, acc: 0.778
[19,1332], f1_score: 0.590, acc: 0.781
[19,1588], f1_score: 0.620, acc: 0.776
[19,1844], f1_score: 0.601, acc: 0.777
[19,2100], f1_score: 0.611, acc: 0.769
[19,2356], f1_score: 0.594, acc: 0.783
[19,2612], f1_score: 0.628, acc: 0.774
[19,2868], f1_score: 0.579, acc: 0.775
[20,176], f1_score: 0.623, acc: 0.770
[20,432], f1_score: 0.604, acc: 0.783
[20,688], f1_score: 0.623, acc: 0.779
[20,944], f1_score: 0.605, acc: 0.783
[20,1200], f1_score: 0.615, acc: 0.781
[20,1456], f1_score: 0.585, acc: 0.778
[20,1712], f1_score: 0.617, acc: 0.778
[20,1968], f1_score: 0.601, acc: 0.769
[20,2224], f1_score: 0.608, acc: 0.772
[20,2480], f1_score: 0.606, acc: 0.780
[20,2736], f1_score: 0.609, acc: 0

[37,748], f1_score: 0.598, acc: 0.767
[37,1004], f1_score: 0.598, acc: 0.768
[37,1260], f1_score: 0.611, acc: 0.760
[37,1516], f1_score: 0.594, acc: 0.775
[37,1772], f1_score: 0.598, acc: 0.771
[37,2028], f1_score: 0.609, acc: 0.768
[37,2284], f1_score: 0.606, acc: 0.760
[37,2540], f1_score: 0.602, acc: 0.773
[37,2796], f1_score: 0.593, acc: 0.764
[38,104], f1_score: 0.604, acc: 0.768
[38,360], f1_score: 0.596, acc: 0.769
[38,616], f1_score: 0.608, acc: 0.765
[38,872], f1_score: 0.563, acc: 0.764
[38,1128], f1_score: 0.607, acc: 0.759
[38,1384], f1_score: 0.600, acc: 0.770
[38,1640], f1_score: 0.592, acc: 0.771
[38,1896], f1_score: 0.593, acc: 0.764
[38,2152], f1_score: 0.592, acc: 0.760
[38,2408], f1_score: 0.614, acc: 0.763
[38,2664], f1_score: 0.605, acc: 0.770
[38,2920], f1_score: 0.572, acc: 0.764
[39,228], f1_score: 0.599, acc: 0.765
[39,484], f1_score: 0.610, acc: 0.770
[39,740], f1_score: 0.600, acc: 0.764
[39,996], f1_score: 0.584, acc: 0.770
[39,1252], f1_score: 0.607, acc: 0

[14,200], f1_score: 0.629, acc: 0.777
[14,456], f1_score: 0.588, acc: 0.782
[14,712], f1_score: 0.570, acc: 0.778
[14,968], f1_score: 0.625, acc: 0.782
[14,1224], f1_score: 0.570, acc: 0.776
[14,1480], f1_score: 0.595, acc: 0.780
[14,1736], f1_score: 0.584, acc: 0.778
[14,1992], f1_score: 0.635, acc: 0.773
[14,2248], f1_score: 0.596, acc: 0.781
[14,2504], f1_score: 0.632, acc: 0.774
[14,2760], f1_score: 0.614, acc: 0.781
[15,68], f1_score: 0.565, acc: 0.776
[15,324], f1_score: 0.601, acc: 0.777
[15,580], f1_score: 0.629, acc: 0.769
[15,836], f1_score: 0.584, acc: 0.781
[15,1092], f1_score: 0.615, acc: 0.771
[15,1348], f1_score: 0.568, acc: 0.780
[15,1604], f1_score: 0.599, acc: 0.781
[15,1860], f1_score: 0.598, acc: 0.782
[15,2116], f1_score: 0.618, acc: 0.782
[15,2372], f1_score: 0.627, acc: 0.782
[15,2628], f1_score: 0.618, acc: 0.779
[15,2884], f1_score: 0.621, acc: 0.777
[16,192], f1_score: 0.608, acc: 0.776
[16,448], f1_score: 0.528, acc: 0.770
[16,704], f1_score: 0.625, acc: 0.76

[32,1664], f1_score: 0.584, acc: 0.768
[32,1920], f1_score: 0.604, acc: 0.772
[32,2176], f1_score: 0.592, acc: 0.770
[32,2432], f1_score: 0.609, acc: 0.767
[32,2688], f1_score: 0.598, acc: 0.769
[32,2944], f1_score: 0.607, acc: 0.774
[33,252], f1_score: 0.563, acc: 0.767
[33,508], f1_score: 0.619, acc: 0.757
[33,764], f1_score: 0.546, acc: 0.765
[33,1020], f1_score: 0.607, acc: 0.769
[33,1276], f1_score: 0.612, acc: 0.762
[33,1532], f1_score: 0.577, acc: 0.772
[33,1788], f1_score: 0.607, acc: 0.767
[33,2044], f1_score: 0.611, acc: 0.759
[33,2300], f1_score: 0.583, acc: 0.772
[33,2556], f1_score: 0.609, acc: 0.771
[33,2812], f1_score: 0.606, acc: 0.766
[34,120], f1_score: 0.584, acc: 0.772
[34,376], f1_score: 0.611, acc: 0.772
[34,632], f1_score: 0.601, acc: 0.764
[34,888], f1_score: 0.590, acc: 0.769
[34,1144], f1_score: 0.597, acc: 0.766
[34,1400], f1_score: 0.605, acc: 0.767
[34,1656], f1_score: 0.566, acc: 0.769
[34,1912], f1_score: 0.617, acc: 0.771
[34,2168], f1_score: 0.609, acc:

[9,1116], f1_score: 0.626, acc: 0.778
[9,1372], f1_score: 0.629, acc: 0.780
[9,1628], f1_score: 0.628, acc: 0.782
[9,1884], f1_score: 0.614, acc: 0.779
[9,2140], f1_score: 0.587, acc: 0.783
[9,2396], f1_score: 0.616, acc: 0.784
[9,2652], f1_score: 0.626, acc: 0.782
[9,2908], f1_score: 0.567, acc: 0.780
[10,216], f1_score: 0.586, acc: 0.782
[10,472], f1_score: 0.608, acc: 0.784
[10,728], f1_score: 0.572, acc: 0.780
[10,984], f1_score: 0.572, acc: 0.776
[10,1240], f1_score: 0.561, acc: 0.777
[10,1496], f1_score: 0.611, acc: 0.780
[10,1752], f1_score: 0.596, acc: 0.784
[10,2008], f1_score: 0.624, acc: 0.785
[10,2264], f1_score: 0.610, acc: 0.785
[10,2520], f1_score: 0.616, acc: 0.782
[10,2776], f1_score: 0.602, acc: 0.783
[11,84], f1_score: 0.624, acc: 0.772
[11,340], f1_score: 0.596, acc: 0.778
[11,596], f1_score: 0.603, acc: 0.780
[11,852], f1_score: 0.617, acc: 0.770
[11,1108], f1_score: 0.613, acc: 0.782
[11,1364], f1_score: 0.568, acc: 0.780
[11,1620], f1_score: 0.578, acc: 0.781
[11

[27,2580], f1_score: 0.580, acc: 0.768
[27,2836], f1_score: 0.617, acc: 0.768
[28,144], f1_score: 0.500, acc: 0.754
[28,400], f1_score: 0.619, acc: 0.750
[28,656], f1_score: 0.541, acc: 0.765
[28,912], f1_score: 0.618, acc: 0.754
[28,1168], f1_score: 0.595, acc: 0.774
[28,1424], f1_score: 0.595, acc: 0.776
[28,1680], f1_score: 0.599, acc: 0.772
[28,1936], f1_score: 0.613, acc: 0.764
[28,2192], f1_score: 0.583, acc: 0.772
[28,2448], f1_score: 0.618, acc: 0.776
[28,2704], f1_score: 0.602, acc: 0.766
[29,12], f1_score: 0.610, acc: 0.772
[29,268], f1_score: 0.588, acc: 0.767
[29,524], f1_score: 0.621, acc: 0.765
[29,780], f1_score: 0.542, acc: 0.763
[29,1036], f1_score: 0.615, acc: 0.752
[29,1292], f1_score: 0.590, acc: 0.771
[29,1548], f1_score: 0.591, acc: 0.773
[29,1804], f1_score: 0.606, acc: 0.769
[29,2060], f1_score: 0.600, acc: 0.768
[29,2316], f1_score: 0.597, acc: 0.762
[29,2572], f1_score: 0.582, acc: 0.772
[29,2828], f1_score: 0.609, acc: 0.768
[30,136], f1_score: 0.574, acc: 0.

[4,1520], f1_score: 0.560, acc: 0.774
[4,1776], f1_score: 0.583, acc: 0.780
[4,2032], f1_score: 0.592, acc: 0.777
[4,2288], f1_score: 0.591, acc: 0.780
[4,2544], f1_score: 0.611, acc: 0.766
[4,2800], f1_score: 0.570, acc: 0.775
[5,108], f1_score: 0.602, acc: 0.778
[5,364], f1_score: 0.562, acc: 0.775
[5,620], f1_score: 0.587, acc: 0.781
[5,876], f1_score: 0.594, acc: 0.778
[5,1132], f1_score: 0.587, acc: 0.776
[5,1388], f1_score: 0.592, acc: 0.778
[5,1644], f1_score: 0.594, acc: 0.782
[5,1900], f1_score: 0.612, acc: 0.771
[5,2156], f1_score: 0.599, acc: 0.783
[5,2412], f1_score: 0.616, acc: 0.781
[5,2668], f1_score: 0.616, acc: 0.779
[5,2924], f1_score: 0.582, acc: 0.782
[6,232], f1_score: 0.582, acc: 0.781
[6,488], f1_score: 0.611, acc: 0.782
[6,744], f1_score: 0.619, acc: 0.774
[6,1000], f1_score: 0.615, acc: 0.780
[6,1256], f1_score: 0.604, acc: 0.782
[6,1512], f1_score: 0.612, acc: 0.781
[6,1768], f1_score: 0.545, acc: 0.775
[6,2024], f1_score: 0.565, acc: 0.779
[6,2280], f1_score:

[23,292], f1_score: 0.616, acc: 0.770
[23,548], f1_score: 0.591, acc: 0.782
[23,804], f1_score: 0.620, acc: 0.770
[23,1060], f1_score: 0.599, acc: 0.777
[23,1316], f1_score: 0.623, acc: 0.771
[23,1572], f1_score: 0.568, acc: 0.771
[23,1828], f1_score: 0.622, acc: 0.776
[23,2084], f1_score: 0.610, acc: 0.779
[23,2340], f1_score: 0.618, acc: 0.780
[23,2596], f1_score: 0.615, acc: 0.775
[23,2852], f1_score: 0.604, acc: 0.780
[24,160], f1_score: 0.611, acc: 0.776
[24,416], f1_score: 0.624, acc: 0.775
[24,672], f1_score: 0.550, acc: 0.772
[24,928], f1_score: 0.620, acc: 0.762
[24,1184], f1_score: 0.590, acc: 0.779
[24,1440], f1_score: 0.616, acc: 0.775
[24,1696], f1_score: 0.589, acc: 0.775
[24,1952], f1_score: 0.612, acc: 0.768
[24,2208], f1_score: 0.615, acc: 0.775
[24,2464], f1_score: 0.578, acc: 0.774
[24,2720], f1_score: 0.622, acc: 0.765
[25,28], f1_score: 0.523, acc: 0.765
[25,284], f1_score: 0.621, acc: 0.769
[25,540], f1_score: 0.604, acc: 0.775
[25,796], f1_score: 0.596, acc: 0.77

[41,1756], f1_score: 0.591, acc: 0.769
[41,2012], f1_score: 0.596, acc: 0.766
Evaluating...
prec=0.638, recall=0.558, F1=0.596, acc=0.766
[0,0], f1_score: 0.385, acc: 0.550
[0,256], f1_score: 0.370, acc: 0.703
[0,512], f1_score: 0.378, acc: 0.706
[0,768], f1_score: 0.398, acc: 0.709
[0,1024], f1_score: 0.407, acc: 0.712
[0,1280], f1_score: 0.381, acc: 0.715
[0,1536], f1_score: 0.400, acc: 0.718
[0,1792], f1_score: 0.378, acc: 0.722
[0,2048], f1_score: 0.398, acc: 0.724
[0,2304], f1_score: 0.408, acc: 0.725
[0,2560], f1_score: 0.418, acc: 0.728
[0,2816], f1_score: 0.390, acc: 0.732
[1,124], f1_score: 0.450, acc: 0.729
[1,380], f1_score: 0.435, acc: 0.736
[1,636], f1_score: 0.461, acc: 0.735
[1,892], f1_score: 0.452, acc: 0.736
[1,1148], f1_score: 0.477, acc: 0.738
[1,1404], f1_score: 0.460, acc: 0.740
[1,1660], f1_score: 0.425, acc: 0.743
[1,1916], f1_score: 0.490, acc: 0.743
[1,2172], f1_score: 0.431, acc: 0.744
[1,2428], f1_score: 0.491, acc: 0.750
[1,2684], f1_score: 0.506, acc: 0.75

[18,1208], f1_score: 0.622, acc: 0.784
[18,1464], f1_score: 0.592, acc: 0.780
[18,1720], f1_score: 0.608, acc: 0.781
[18,1976], f1_score: 0.627, acc: 0.774
[18,2232], f1_score: 0.595, acc: 0.778
[18,2488], f1_score: 0.618, acc: 0.776
[18,2744], f1_score: 0.598, acc: 0.780
[19,52], f1_score: 0.618, acc: 0.781
[19,308], f1_score: 0.616, acc: 0.776
[19,564], f1_score: 0.546, acc: 0.774
[19,820], f1_score: 0.620, acc: 0.771
[19,1076], f1_score: 0.604, acc: 0.781
[19,1332], f1_score: 0.612, acc: 0.784
[19,1588], f1_score: 0.592, acc: 0.768
[19,1844], f1_score: 0.587, acc: 0.777
[19,2100], f1_score: 0.620, acc: 0.769
[19,2356], f1_score: 0.579, acc: 0.777
[19,2612], f1_score: 0.627, acc: 0.768
[19,2868], f1_score: 0.528, acc: 0.767
[20,176], f1_score: 0.608, acc: 0.775
[20,432], f1_score: 0.615, acc: 0.777
[20,688], f1_score: 0.608, acc: 0.777
[20,944], f1_score: 0.614, acc: 0.777
[20,1200], f1_score: 0.608, acc: 0.781
[20,1456], f1_score: 0.597, acc: 0.782
[20,1712], f1_score: 0.619, acc: 0

[36,2672], f1_score: 0.596, acc: 0.765
[36,2928], f1_score: 0.606, acc: 0.763
[37,236], f1_score: 0.578, acc: 0.769
[37,492], f1_score: 0.597, acc: 0.769
[37,748], f1_score: 0.619, acc: 0.758
[37,1004], f1_score: 0.569, acc: 0.771
[37,1260], f1_score: 0.609, acc: 0.759
[37,1516], f1_score: 0.593, acc: 0.775
[37,1772], f1_score: 0.597, acc: 0.768
[37,2028], f1_score: 0.602, acc: 0.763
[37,2284], f1_score: 0.600, acc: 0.762
[37,2540], f1_score: 0.599, acc: 0.772
[37,2796], f1_score: 0.582, acc: 0.760
[38,104], f1_score: 0.611, acc: 0.762
[38,360], f1_score: 0.575, acc: 0.767
[38,616], f1_score: 0.606, acc: 0.766
[38,872], f1_score: 0.599, acc: 0.768
[38,1128], f1_score: 0.599, acc: 0.764
[38,1384], f1_score: 0.607, acc: 0.764
[38,1640], f1_score: 0.577, acc: 0.773
[38,1896], f1_score: 0.592, acc: 0.748
[38,2152], f1_score: 0.590, acc: 0.761
[38,2408], f1_score: 0.596, acc: 0.766
[38,2664], f1_score: 0.610, acc: 0.770
[38,2920], f1_score: 0.577, acc: 0.765
[39,228], f1_score: 0.606, acc: 

In [30]:
# Average each metric over the repeated runs collected in the *_h3 lists and
# report them in the same format as the per-run "Evaluating..." lines.
print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"
      % tuple(map(sta.mean, (precision_h3, recall_h3, f1score_h3, accuracy_h3))))

prec=0.616, recall=0.592, F1=0.603, acc=0.760


### LASSO (`hidden_layer=0`, no hidden layers)

In [31]:
precision_h0, recall_h0, f1score_h0, accuracy_h0 = [], [], [], []

# Train and evaluate the hidden_layer=0 ("LASSO") model 5 times; the per-run
# metrics are collected in the *_h0 lists and averaged in the next cell.
for run_idx in range(5):
    # Seed torch's RNG per run so each run's weight initialization is
    # reproducible while still differing between runs.
    # NOTE(review): if model.train / utils shuffle batches with `random` or
    # numpy, those RNGs are not covered by this seed — confirm.
    torch.manual_seed(run_idx)

    # Init model with no hidden layer (hidden_layer=0)
    model = MLP(hidden_layer = 0).to(device)

    # Train MLP model
    model.train(train_set, test_set,
          batch_size=16,
          test_batch_size=512,
          max_iter=3072*40,
          max_fscore=0.7,
          test_inc_size=256)

    print("Evaluating...")
    labels, preds = model.test(test_set, test_batch_size=512)
    precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
    print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%(precision, recall, f1score, accuracy))

    # Release this run's model and the cached CUDA memory before the next run
    del model
    torch.cuda.empty_cache()

    precision_h0.append(precision)
    recall_h0.append(recall)
    f1score_h0.append(f1score)
    accuracy_h0.append(accuracy)

[0,0], f1_score: 0.376, acc: 0.505
[0,256], f1_score: 0.371, acc: 0.581
[0,512], f1_score: 0.377, acc: 0.625
[0,768], f1_score: 0.384, acc: 0.641
[0,1024], f1_score: 0.399, acc: 0.641
[0,1280], f1_score: 0.389, acc: 0.642
[0,1536], f1_score: 0.398, acc: 0.647
[0,1792], f1_score: 0.389, acc: 0.648
[0,2048], f1_score: 0.414, acc: 0.646
[0,2304], f1_score: 0.386, acc: 0.656
[0,2560], f1_score: 0.426, acc: 0.646
[0,2816], f1_score: 0.401, acc: 0.658
[1,124], f1_score: 0.420, acc: 0.655
[1,380], f1_score: 0.427, acc: 0.658
[1,636], f1_score: 0.412, acc: 0.659
[1,892], f1_score: 0.436, acc: 0.659
[1,1148], f1_score: 0.424, acc: 0.660
[1,1404], f1_score: 0.412, acc: 0.669
[1,1660], f1_score: 0.436, acc: 0.663
[1,1916], f1_score: 0.430, acc: 0.667
[1,2172], f1_score: 0.413, acc: 0.670
[1,2428], f1_score: 0.440, acc: 0.670
[1,2684], f1_score: 0.445, acc: 0.669
[1,2940], f1_score: 0.411, acc: 0.677
[2,248], f1_score: 0.452, acc: 0.676
[2,504], f1_score: 0.438, acc: 0.679
[2,760], f1_score: 0.451

[18,2232], f1_score: 0.575, acc: 0.777
[18,2488], f1_score: 0.585, acc: 0.778
[18,2744], f1_score: 0.580, acc: 0.778
[19,52], f1_score: 0.574, acc: 0.778
[19,308], f1_score: 0.584, acc: 0.777
[19,564], f1_score: 0.579, acc: 0.778
[19,820], f1_score: 0.572, acc: 0.777
[19,1076], f1_score: 0.583, acc: 0.777
[19,1332], f1_score: 0.573, acc: 0.777
[19,1588], f1_score: 0.560, acc: 0.777
[19,1844], f1_score: 0.585, acc: 0.778
[19,2100], f1_score: 0.589, acc: 0.777
[19,2356], f1_score: 0.581, acc: 0.779
[19,2612], f1_score: 0.584, acc: 0.778
[19,2868], f1_score: 0.578, acc: 0.778
[20,176], f1_score: 0.579, acc: 0.778
[20,432], f1_score: 0.586, acc: 0.778
[20,688], f1_score: 0.578, acc: 0.778
[20,944], f1_score: 0.575, acc: 0.778
[20,1200], f1_score: 0.589, acc: 0.778
[20,1456], f1_score: 0.565, acc: 0.778
[20,1712], f1_score: 0.570, acc: 0.778
[20,1968], f1_score: 0.595, acc: 0.777
[20,2224], f1_score: 0.582, acc: 0.779
[20,2480], f1_score: 0.587, acc: 0.780
[20,2736], f1_score: 0.585, acc: 0

[37,748], f1_score: 0.594, acc: 0.783
[37,1004], f1_score: 0.587, acc: 0.783
[37,1260], f1_score: 0.598, acc: 0.782
[37,1516], f1_score: 0.589, acc: 0.783
[37,1772], f1_score: 0.586, acc: 0.782
[37,2028], f1_score: 0.596, acc: 0.782
[37,2284], f1_score: 0.596, acc: 0.782
[37,2540], f1_score: 0.603, acc: 0.783
[37,2796], f1_score: 0.596, acc: 0.783
[38,104], f1_score: 0.590, acc: 0.783
[38,360], f1_score: 0.598, acc: 0.782
[38,616], f1_score: 0.600, acc: 0.782
[38,872], f1_score: 0.591, acc: 0.783
[38,1128], f1_score: 0.588, acc: 0.782
[38,1384], f1_score: 0.598, acc: 0.782
[38,1640], f1_score: 0.587, acc: 0.782
[38,1896], f1_score: 0.590, acc: 0.783
[38,2152], f1_score: 0.597, acc: 0.781
[38,2408], f1_score: 0.599, acc: 0.782
[38,2664], f1_score: 0.601, acc: 0.783
[38,2920], f1_score: 0.592, acc: 0.783
[39,228], f1_score: 0.592, acc: 0.783
[39,484], f1_score: 0.600, acc: 0.782
[39,740], f1_score: 0.599, acc: 0.783
[39,996], f1_score: 0.585, acc: 0.783
[39,1252], f1_score: 0.593, acc: 0

[14,200], f1_score: 0.565, acc: 0.769
[14,456], f1_score: 0.557, acc: 0.770
[14,712], f1_score: 0.557, acc: 0.770
[14,968], f1_score: 0.561, acc: 0.770
[14,1224], f1_score: 0.563, acc: 0.770
[14,1480], f1_score: 0.538, acc: 0.770
[14,1736], f1_score: 0.565, acc: 0.771
[14,1992], f1_score: 0.571, acc: 0.770
[14,2248], f1_score: 0.560, acc: 0.771
[14,2504], f1_score: 0.568, acc: 0.772
[14,2760], f1_score: 0.568, acc: 0.772
[15,68], f1_score: 0.555, acc: 0.772
[15,324], f1_score: 0.576, acc: 0.770
[15,580], f1_score: 0.555, acc: 0.772
[15,836], f1_score: 0.565, acc: 0.772
[15,1092], f1_score: 0.567, acc: 0.772
[15,1348], f1_score: 0.553, acc: 0.771
[15,1604], f1_score: 0.552, acc: 0.772
[15,1860], f1_score: 0.576, acc: 0.772
[15,2116], f1_score: 0.569, acc: 0.773
[15,2372], f1_score: 0.570, acc: 0.773
[15,2628], f1_score: 0.569, acc: 0.773
[15,2884], f1_score: 0.565, acc: 0.773
[16,192], f1_score: 0.568, acc: 0.773
[16,448], f1_score: 0.574, acc: 0.773
[16,704], f1_score: 0.560, acc: 0.77

[32,1664], f1_score: 0.587, acc: 0.782
[32,1920], f1_score: 0.587, acc: 0.782
[32,2176], f1_score: 0.596, acc: 0.781
[32,2432], f1_score: 0.604, acc: 0.781
[32,2688], f1_score: 0.588, acc: 0.782
[32,2944], f1_score: 0.592, acc: 0.782
[33,252], f1_score: 0.598, acc: 0.782
[33,508], f1_score: 0.592, acc: 0.782
[33,764], f1_score: 0.595, acc: 0.782
[33,1020], f1_score: 0.593, acc: 0.782
[33,1276], f1_score: 0.591, acc: 0.782
[33,1532], f1_score: 0.587, acc: 0.782
[33,1788], f1_score: 0.586, acc: 0.782
[33,2044], f1_score: 0.591, acc: 0.782
[33,2300], f1_score: 0.600, acc: 0.781
[33,2556], f1_score: 0.601, acc: 0.782
[33,2812], f1_score: 0.586, acc: 0.782
[34,120], f1_score: 0.597, acc: 0.782
[34,376], f1_score: 0.598, acc: 0.781
[34,632], f1_score: 0.592, acc: 0.782
[34,888], f1_score: 0.595, acc: 0.782
[34,1144], f1_score: 0.592, acc: 0.782
[34,1400], f1_score: 0.589, acc: 0.782
[34,1656], f1_score: 0.587, acc: 0.782
[34,1912], f1_score: 0.588, acc: 0.782
[34,2168], f1_score: 0.595, acc:

[9,1116], f1_score: 0.530, acc: 0.755
[9,1372], f1_score: 0.505, acc: 0.757
[9,1628], f1_score: 0.534, acc: 0.756
[9,1884], f1_score: 0.527, acc: 0.757
[9,2140], f1_score: 0.534, acc: 0.758
[9,2396], f1_score: 0.529, acc: 0.759
[9,2652], f1_score: 0.542, acc: 0.758
[9,2908], f1_score: 0.510, acc: 0.759
[10,216], f1_score: 0.546, acc: 0.758
[10,472], f1_score: 0.521, acc: 0.760
[10,728], f1_score: 0.539, acc: 0.760
[10,984], f1_score: 0.539, acc: 0.760
[10,1240], f1_score: 0.535, acc: 0.760
[10,1496], f1_score: 0.520, acc: 0.762
[10,1752], f1_score: 0.544, acc: 0.762
[10,2008], f1_score: 0.550, acc: 0.761
[10,2264], f1_score: 0.532, acc: 0.762
[10,2520], f1_score: 0.548, acc: 0.763
[10,2776], f1_score: 0.536, acc: 0.764
[11,84], f1_score: 0.539, acc: 0.763
[11,340], f1_score: 0.550, acc: 0.763
[11,596], f1_score: 0.539, acc: 0.764
[11,852], f1_score: 0.545, acc: 0.764
[11,1108], f1_score: 0.547, acc: 0.764
[11,1364], f1_score: 0.525, acc: 0.765
[11,1620], f1_score: 0.544, acc: 0.764
[11

[27,2580], f1_score: 0.588, acc: 0.782
[27,2836], f1_score: 0.591, acc: 0.781
[28,144], f1_score: 0.589, acc: 0.781
[28,400], f1_score: 0.588, acc: 0.781
[28,656], f1_score: 0.594, acc: 0.781
[28,912], f1_score: 0.588, acc: 0.781
[28,1168], f1_score: 0.584, acc: 0.781
[28,1424], f1_score: 0.589, acc: 0.781
[28,1680], f1_score: 0.581, acc: 0.780
[28,1936], f1_score: 0.584, acc: 0.781
[28,2192], f1_score: 0.600, acc: 0.779
[28,2448], f1_score: 0.597, acc: 0.781
[28,2704], f1_score: 0.587, acc: 0.782
[29,12], f1_score: 0.592, acc: 0.781
[29,268], f1_score: 0.589, acc: 0.781
[29,524], f1_score: 0.590, acc: 0.782
[29,780], f1_score: 0.595, acc: 0.781
[29,1036], f1_score: 0.585, acc: 0.781
[29,1292], f1_score: 0.588, acc: 0.781
[29,1548], f1_score: 0.587, acc: 0.781
[29,1804], f1_score: 0.581, acc: 0.781
[29,2060], f1_score: 0.593, acc: 0.781
[29,2316], f1_score: 0.602, acc: 0.780
[29,2572], f1_score: 0.591, acc: 0.782
[29,2828], f1_score: 0.590, acc: 0.782
[30,136], f1_score: 0.593, acc: 0.

[4,1520], f1_score: 0.478, acc: 0.714
[4,1776], f1_score: 0.454, acc: 0.720
[4,2032], f1_score: 0.498, acc: 0.715
[4,2288], f1_score: 0.451, acc: 0.723
[4,2544], f1_score: 0.498, acc: 0.717
[4,2800], f1_score: 0.466, acc: 0.724
[5,108], f1_score: 0.493, acc: 0.721
[5,364], f1_score: 0.484, acc: 0.725
[5,620], f1_score: 0.485, acc: 0.725
[5,876], f1_score: 0.494, acc: 0.725
[5,1132], f1_score: 0.486, acc: 0.726
[5,1388], f1_score: 0.470, acc: 0.730
[5,1644], f1_score: 0.496, acc: 0.727
[5,1900], f1_score: 0.492, acc: 0.729
[5,2156], f1_score: 0.474, acc: 0.732
[5,2412], f1_score: 0.500, acc: 0.731
[5,2668], f1_score: 0.498, acc: 0.732
[5,2924], f1_score: 0.471, acc: 0.734
[6,232], f1_score: 0.504, acc: 0.734
[6,488], f1_score: 0.490, acc: 0.736
[6,744], f1_score: 0.501, acc: 0.735
[6,1000], f1_score: 0.509, acc: 0.734
[6,1256], f1_score: 0.490, acc: 0.738
[6,1512], f1_score: 0.498, acc: 0.738
[6,1768], f1_score: 0.485, acc: 0.740
[6,2024], f1_score: 0.518, acc: 0.738
[6,2280], f1_score:

[23,292], f1_score: 0.584, acc: 0.780
[23,548], f1_score: 0.587, acc: 0.780
[23,804], f1_score: 0.581, acc: 0.779
[23,1060], f1_score: 0.582, acc: 0.780
[23,1316], f1_score: 0.587, acc: 0.779
[23,1572], f1_score: 0.564, acc: 0.779
[23,1828], f1_score: 0.581, acc: 0.780
[23,2084], f1_score: 0.597, acc: 0.778
[23,2340], f1_score: 0.585, acc: 0.781
[23,2596], f1_score: 0.587, acc: 0.781
[23,2852], f1_score: 0.585, acc: 0.780
[24,160], f1_score: 0.579, acc: 0.780
[24,416], f1_score: 0.587, acc: 0.780
[24,672], f1_score: 0.586, acc: 0.780
[24,928], f1_score: 0.580, acc: 0.780
[24,1184], f1_score: 0.589, acc: 0.780
[24,1440], f1_score: 0.579, acc: 0.779
[24,1696], f1_score: 0.567, acc: 0.779
[24,1952], f1_score: 0.594, acc: 0.780
[24,2208], f1_score: 0.591, acc: 0.779
[24,2464], f1_score: 0.587, acc: 0.781
[24,2720], f1_score: 0.589, acc: 0.781
[25,28], f1_score: 0.582, acc: 0.780
[25,284], f1_score: 0.584, acc: 0.780
[25,540], f1_score: 0.589, acc: 0.780
[25,796], f1_score: 0.585, acc: 0.78

[41,1756], f1_score: 0.588, acc: 0.782
[41,2012], f1_score: 0.584, acc: 0.782
Evaluating...
prec=0.708, recall=0.496, F1=0.584, acc=0.782
[0,0], f1_score: 0.384, acc: 0.503
[0,256], f1_score: 0.379, acc: 0.582
[0,512], f1_score: 0.380, acc: 0.625
[0,768], f1_score: 0.381, acc: 0.639
[0,1024], f1_score: 0.397, acc: 0.640
[0,1280], f1_score: 0.387, acc: 0.642
[0,1536], f1_score: 0.395, acc: 0.645
[0,1792], f1_score: 0.386, acc: 0.647
[0,2048], f1_score: 0.411, acc: 0.642
[0,2304], f1_score: 0.384, acc: 0.654
[0,2560], f1_score: 0.422, acc: 0.644
[0,2816], f1_score: 0.397, acc: 0.655
[1,124], f1_score: 0.416, acc: 0.652
[1,380], f1_score: 0.422, acc: 0.654
[1,636], f1_score: 0.407, acc: 0.655
[1,892], f1_score: 0.431, acc: 0.655
[1,1148], f1_score: 0.420, acc: 0.656
[1,1404], f1_score: 0.407, acc: 0.665
[1,1660], f1_score: 0.431, acc: 0.659
[1,1916], f1_score: 0.423, acc: 0.662
[1,2172], f1_score: 0.407, acc: 0.665
[1,2428], f1_score: 0.435, acc: 0.665
[1,2684], f1_score: 0.438, acc: 0.66

[18,1208], f1_score: 0.582, acc: 0.776
[18,1464], f1_score: 0.556, acc: 0.776
[18,1720], f1_score: 0.569, acc: 0.776
[18,1976], f1_score: 0.590, acc: 0.775
[18,2232], f1_score: 0.576, acc: 0.777
[18,2488], f1_score: 0.583, acc: 0.778
[18,2744], f1_score: 0.578, acc: 0.777
[19,52], f1_score: 0.570, acc: 0.777
[19,308], f1_score: 0.583, acc: 0.777
[19,564], f1_score: 0.580, acc: 0.777
[19,820], f1_score: 0.574, acc: 0.777
[19,1076], f1_score: 0.584, acc: 0.777
[19,1332], f1_score: 0.572, acc: 0.776
[19,1588], f1_score: 0.558, acc: 0.776
[19,1844], f1_score: 0.583, acc: 0.778
[19,2100], f1_score: 0.589, acc: 0.777
[19,2356], f1_score: 0.582, acc: 0.778
[19,2612], f1_score: 0.582, acc: 0.778
[19,2868], f1_score: 0.576, acc: 0.778
[20,176], f1_score: 0.575, acc: 0.778
[20,432], f1_score: 0.585, acc: 0.778
[20,688], f1_score: 0.578, acc: 0.778
[20,944], f1_score: 0.578, acc: 0.778
[20,1200], f1_score: 0.588, acc: 0.778
[20,1456], f1_score: 0.563, acc: 0.777
[20,1712], f1_score: 0.568, acc: 0

[36,2672], f1_score: 0.599, acc: 0.782
[36,2928], f1_score: 0.586, acc: 0.782
[37,236], f1_score: 0.591, acc: 0.782
[37,492], f1_score: 0.597, acc: 0.781
[37,748], f1_score: 0.590, acc: 0.782
[37,1004], f1_score: 0.589, acc: 0.782
[37,1260], f1_score: 0.596, acc: 0.782
[37,1516], f1_score: 0.585, acc: 0.782
[37,1772], f1_score: 0.583, acc: 0.781
[37,2028], f1_score: 0.593, acc: 0.781
[37,2284], f1_score: 0.595, acc: 0.781
[37,2540], f1_score: 0.602, acc: 0.782
[37,2796], f1_score: 0.594, acc: 0.782
[38,104], f1_score: 0.587, acc: 0.782
[38,360], f1_score: 0.594, acc: 0.781
[38,616], f1_score: 0.597, acc: 0.782
[38,872], f1_score: 0.588, acc: 0.782
[38,1128], f1_score: 0.590, acc: 0.782
[38,1384], f1_score: 0.596, acc: 0.782
[38,1640], f1_score: 0.582, acc: 0.781
[38,1896], f1_score: 0.588, acc: 0.782
[38,2152], f1_score: 0.595, acc: 0.781
[38,2408], f1_score: 0.599, acc: 0.781
[38,2664], f1_score: 0.599, acc: 0.781
[38,2920], f1_score: 0.591, acc: 0.782
[39,228], f1_score: 0.588, acc: 

In [32]:
# Average each metric over the repeated runs collected in the *_h0 lists and
# report them in the same format as the per-run "Evaluating..." lines.
print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"
      % tuple(map(sta.mean, (precision_h0, recall_h0, f1score_h0, accuracy_h0))))

prec=0.708, recall=0.496, F1=0.584, acc=0.782
