In [1]:
import sys
sys.executable

'C:\\ProgramData\\Anaconda3\\envs\\btc2\\python.exe'

In [2]:
import os
import argparse
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from utils import *
from base import ModelBase
import statistics as sta
from types import SimpleNamespace

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class GIT(ModelBase):

    def __init__(self, args, **kwargs):

        super(GIT, self).__init__(args, **kwargs)
        self.build()

    def build(self):

        self.layer_sga_emb = nn.Embedding(
            num_embeddings=self.sga_size+1,
            embedding_dim=self.embedding_size,
            padding_idx=0)

        self.layer_can_emb = nn.Embedding(
            num_embeddings=self.can_size+1,
            embedding_dim=self.embedding_size,
            padding_idx=0)

        self.layer_w_0 = nn.Linear(
            in_features=self.embedding_size,
            out_features=self.attention_size,
            bias=True)

        self.layer_beta = nn.Linear(
            in_features=self.attention_size,
            out_features=self.attention_head,
            bias=True)

        self.layer_dropout_1 = nn.Dropout(p=self.dropout_rate)

        self.layer_w_1 = nn.Linear(
            in_features=self.embedding_size,
            out_features=self.hidden_size,
            bias=True)

        self.layer_dropout_2 = nn.Dropout(p=self.dropout_rate)

        self.layer_w_2 = nn.Linear(
            in_features=self.hidden_size,
            out_features=self.deg_size,
            bias=True)-

        if self.initializtion:
            gene_emb_pretrain = np.load(os.path.join(self.input_dir, "gene_emb_pretrain.npy"))
            self.layer_sga_emb.weight.data.copy_(torch.from_numpy(gene_emb_pretrain))

        self.optimizer = optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay)


    def forward(self, sga_index, can_index):
   
        # cancer type embedding
        emb_can = self.layer_can_emb(can_index)
        emb_can = emb_can.view(-1,self.embedding_size)

        # gene embedings
        E_t = self.layer_sga_emb(sga_index)

        # squeeze and tanh-curve the gene embeddings
        E_t_flatten = E_t.view(-1, self.embedding_size)
        E_t1_flatten = torch.tanh( self.layer_w_0(E_t_flatten) )

        # multiplied by attention heads
        E_t2_flatten = self.layer_beta(E_t1_flatten)
        E_t2 = E_t2_flatten.view(-1, self.num_max_sga, self.attention_head)

        # normalize by softmax
        E_t2 = E_t2.permute(1,0,2)
        A = F.softmax(E_t2)
        A = A.permute(1,0,2)

        if self.attention:
          # multi-head attention weighted sga embedding:
            emb_sga = torch.sum( torch.bmm( A.permute(0,2,1), E_t ), dim=1)
            emb_sga = emb_sga.view(-1,self.embedding_size)
        else:
          # if not using attention, simply sum up SGA embeddings
            emb_sga = torch.sum(E_t, dim=1)
            emb_sga = emb_sga.view(-1, self.embedding_size)

        # if use cancer type input, add cancer type embedding
        if self.cancer_type:
            emb_tmr = emb_can+emb_sga
        else:
            emb_tmr = emb_sga

        # MLP decoder
        emb_tmr_relu = self.layer_dropout_1(emb_tmr)
        hid_tmr = self.layer_w_1(emb_tmr_relu)
        hid_tmr_relu = self.layer_dropout_2(hid_tmr)
        
        preds = F.tanh(self.layer_w_2(hid_tmr_relu))

        # attention weights
        attn_wt = torch.sum(A, dim=2)
        attn_wt = attn_wt.view(-1, self.num_max_sga)

        return preds, hid_tmr, emb_tmr, emb_sga, attn_wt


    def train(self, train_set, test_set,
            batch_size=None, test_batch_size=None,
            max_iter=None, max_fscore=None,
            test_inc_size=None, **kwargs):

        for iter_train in range(0, max_iter+1, batch_size):
            batch_set = get_minibatch(train_set, iter_train, batch_size,batch_type="train")
            preds, _, _, _, _ = self.forward(batch_set["sga"].to(device), batch_set["can"].to(device))
            labels = batch_set["deg"].to(device)

            self.optimizer.zero_grad()
            loss = -torch.log( self.epsilon + 1 - torch.abs(preds - labels) / 2 ).mean()
            loss.backward()
            self.optimizer.step()

            if test_inc_size and (iter_train % test_inc_size == 0):
                labels, preds, _, _, _, _, _ = self.test(test_set, test_batch_size)
                precision, recall, f1score, accuracy = evaluate(
                    labels, preds, epsilon=self.epsilon)
                print("[%d,%d], precision: %.3f, acc: %.3f"% (iter_train//len(train_set["can"]),
                      iter_train%len(train_set["can"]), precision, accuracy))

                if f1score >= max_fscore:
                    break

        #self.save_model(os.path.join(self.output_dir, "trained_model.pth"))


    def test(self, test_set, test_batch_size, **kwargs):

        labels, preds, hid_tmr, emb_tmr, emb_sga, attn_wt, tmr = [], [], [], [], [], [], []

        for iter_test in range(0, len(test_set["can"]), test_batch_size):
            batch_set = get_minibatch(test_set, iter_test, test_batch_size, batch_type="test")
            batch_preds, batch_hid_tmr, batch_emb_tmr, batch_emb_sga, batch_attn_wt = self.forward(
                batch_set["sga"].to(device), batch_set["can"].to(device))
            batch_labels = batch_set["deg"].to(device)

            labels.append(batch_labels.data.to(torch.device("cpu")).numpy())
            preds.append(batch_preds.data.to(torch.device("cpu")).numpy())
            hid_tmr.append(batch_hid_tmr.data.to(torch.device("cpu")).numpy())
            emb_tmr.append(batch_emb_tmr.data.to(torch.device("cpu")).numpy())
            emb_sga.append(batch_emb_sga.data.to(torch.device("cpu")).numpy())
            attn_wt.append(batch_attn_wt.data.to(torch.device("cpu")).numpy())
            tmr = tmr + batch_set["tmr"]

        labels = np.concatenate(labels,axis=0)
        preds = np.concatenate(preds,axis=0)
        hid_tmr = np.concatenate(hid_tmr,axis=0)
        emb_tmr = np.concatenate(emb_tmr,axis=0)
        emb_sga = np.concatenate(emb_sga,axis=0)
        attn_wt = np.concatenate(attn_wt,axis=0)

        return labels, preds, hid_tmr, emb_tmr, emb_sga, attn_wt, tmr


### Non Binary Target

In [7]:
# Parse arguments
args_nb = SimpleNamespace()

args_nb.train_model=True

args_nb.input_dir="data_noBin"
args_nb.output_dir="data_noBin"

args_nb.embedding_size=512
args_nb.hidden_size=1024
args_nb.attention_size=400
args_nb.attention_head=128

args_nb.max_fscore=0.7
args_nb.batch_size=16
args_nb.test_batch_size=512
args_nb.test_inc_size=256
args_nb.dropout_rate=0.5
args_nb.weight_decay=1e-5

args_nb.deg_shuffle=False
args_nb.nonbinary=True

# Load data
dataset_nb = load_dataset(input_dir=args_nb.input_dir, deg_shuffle=args_nb.deg_shuffle)
train_set_nb, test_set_nb = split_dataset(dataset_nb, ratio=0.66)

args_nb.can_size = dataset_nb["can"].max()        # cancer type dimension
args_nb.sga_size = max(dataset_nb["sga"].max(), 19781)        # SGA dimension
args_nb.deg_size = dataset_nb["deg"].shape[1]     # DEG output dimension
args_nb.num_max_sga = dataset_nb["sga"].shape[1]  # maximum number of SGAs in a tumor

In [12]:
precision_nb_GIT, recall_nb_GIT, f1score_nb_GIT, accuracy_nb_GIT = [], [], [], []

# GIT variants:
# args.initializtion=False -> GIT-init
args_nb.initializtion=True
# args.attention=False -> GIT-attn
args_nb.attention=True
# args.cancer_type=False -> GIT-can
args_nb.cancer_type=True

args_nb.max_iter=3072*20
args_nb.learning_rate=1e-4

if args_nb.cancer_type == False:
    args_nb.max_iter = 3072*40
elif args_nb.attention == False:
    args_nb.max_iter = 3072*40
    args_nb.learning_rate = 0.0003

for i in range(5):
    
    # Init model with single hidden layer
    model = GIT(args_nb).to(device)

    # Train MLP model
    model.train(train_set_nb, test_set_nb,
          batch_size=args_nb.batch_size,
          test_batch_size=args_nb.test_batch_size,
          max_iter=args_nb.max_iter,
          max_fscore=args_nb.max_fscore,
          test_inc_size=args_nb.test_inc_size)

    print("Evaluating...")
    labels, preds, _, _, _, _, _ = model.test(test_set_nb, test_batch_size=512)
    precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
    print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%(precision, recall, f1score, accuracy))
    
    # Release memory
    del model
    torch.cuda.empty_cache()
    
    precision_nb_GIT.append(precision)
    recall_nb_GIT.append(recall)
    f1score_nb_GIT.append(f1score)
    accuracy_nb_GIT.append(accuracy)



[0,0], precision: 0.172, acc: 0.227
[0,256], precision: 0.294, acc: 0.313
[0,512], precision: 0.309, acc: 0.315
[0,768], precision: 0.310, acc: 0.313
[0,1024], precision: 0.310, acc: 0.313
[0,1280], precision: 0.310, acc: 0.313
[0,1536], precision: 0.311, acc: 0.313
[0,1792], precision: 0.311, acc: 0.314
[0,2048], precision: 0.311, acc: 0.314
[0,2304], precision: 0.311, acc: 0.314
[0,2560], precision: 0.311, acc: 0.314
[0,2816], precision: 0.311, acc: 0.314
[1,124], precision: 0.311, acc: 0.315
[1,380], precision: 0.311, acc: 0.315
[1,636], precision: 0.311, acc: 0.316
[1,892], precision: 0.311, acc: 0.317
[1,1148], precision: 0.311, acc: 0.318
[1,1404], precision: 0.312, acc: 0.319
[1,1660], precision: 0.312, acc: 0.320
[1,1916], precision: 0.312, acc: 0.321
[1,2172], precision: 0.312, acc: 0.322
[1,2428], precision: 0.312, acc: 0.324
[1,2684], precision: 0.313, acc: 0.326
[1,2940], precision: 0.313, acc: 0.329
[2,248], precision: 0.313, acc: 0.332
[2,504], precision: 0.313, acc: 0.33

[18,696], precision: 0.684, acc: 0.771
[18,952], precision: 0.680, acc: 0.771
[18,1208], precision: 0.669, acc: 0.770
[18,1464], precision: 0.690, acc: 0.773
[18,1720], precision: 0.685, acc: 0.772
[18,1976], precision: 0.676, acc: 0.771
[18,2232], precision: 0.679, acc: 0.772
[18,2488], precision: 0.660, acc: 0.769
[18,2744], precision: 0.676, acc: 0.772
[19,52], precision: 0.680, acc: 0.772
[19,308], precision: 0.674, acc: 0.771
[19,564], precision: 0.675, acc: 0.770
[19,820], precision: 0.678, acc: 0.771
[19,1076], precision: 0.674, acc: 0.771
[19,1332], precision: 0.672, acc: 0.771
[19,1588], precision: 0.694, acc: 0.773
[19,1844], precision: 0.688, acc: 0.773
[19,2100], precision: 0.668, acc: 0.770
[19,2356], precision: 0.677, acc: 0.771
[19,2612], precision: 0.667, acc: 0.771
[19,2868], precision: 0.684, acc: 0.773
[20,176], precision: 0.688, acc: 0.773
[20,432], precision: 0.685, acc: 0.773
[20,688], precision: 0.684, acc: 0.772
[20,944], precision: 0.684, acc: 0.772
[20,1200], 

[36,624], precision: 0.680, acc: 0.769
[36,880], precision: 0.667, acc: 0.767
[36,1136], precision: 0.665, acc: 0.767
[36,1392], precision: 0.672, acc: 0.768
[36,1648], precision: 0.670, acc: 0.767
[36,1904], precision: 0.681, acc: 0.768
[36,2160], precision: 0.686, acc: 0.770
[36,2416], precision: 0.675, acc: 0.769
[36,2672], precision: 0.659, acc: 0.766
[36,2928], precision: 0.675, acc: 0.768
[37,236], precision: 0.689, acc: 0.770
[37,492], precision: 0.667, acc: 0.767
[37,748], precision: 0.683, acc: 0.769
[37,1004], precision: 0.661, acc: 0.766
[37,1260], precision: 0.677, acc: 0.769
[37,1516], precision: 0.669, acc: 0.767
[37,1772], precision: 0.680, acc: 0.769
[37,2028], precision: 0.679, acc: 0.769
[37,2284], precision: 0.667, acc: 0.768
[37,2540], precision: 0.685, acc: 0.771
[37,2796], precision: 0.652, acc: 0.764
[38,104], precision: 0.682, acc: 0.770
[38,360], precision: 0.667, acc: 0.766
[38,616], precision: 0.678, acc: 0.769
[38,872], precision: 0.678, acc: 0.769
[38,1128]

[12,1744], precision: 0.663, acc: 0.766
[12,2000], precision: 0.636, acc: 0.760
[12,2256], precision: 0.656, acc: 0.764
[12,2512], precision: 0.638, acc: 0.761
[12,2768], precision: 0.645, acc: 0.763
[13,76], precision: 0.655, acc: 0.765
[13,332], precision: 0.658, acc: 0.766
[13,588], precision: 0.658, acc: 0.766
[13,844], precision: 0.658, acc: 0.767
[13,1100], precision: 0.655, acc: 0.766
[13,1356], precision: 0.671, acc: 0.770
[13,1612], precision: 0.663, acc: 0.768
[13,1868], precision: 0.668, acc: 0.769
[13,2124], precision: 0.654, acc: 0.766
[13,2380], precision: 0.660, acc: 0.767
[13,2636], precision: 0.653, acc: 0.766
[13,2892], precision: 0.662, acc: 0.769
[14,200], precision: 0.666, acc: 0.769
[14,456], precision: 0.666, acc: 0.769
[14,712], precision: 0.679, acc: 0.771
[14,968], precision: 0.660, acc: 0.767
[14,1224], precision: 0.653, acc: 0.766
[14,1480], precision: 0.684, acc: 0.771
[14,1736], precision: 0.673, acc: 0.770
[14,1992], precision: 0.666, acc: 0.769
[14,2248]

[30,1672], precision: 0.680, acc: 0.770
[30,1928], precision: 0.684, acc: 0.771
[30,2184], precision: 0.677, acc: 0.770
[30,2440], precision: 0.654, acc: 0.766
[30,2696], precision: 0.672, acc: 0.769
[31,4], precision: 0.675, acc: 0.769
[31,260], precision: 0.668, acc: 0.768
[31,516], precision: 0.682, acc: 0.770
[31,772], precision: 0.672, acc: 0.769
[31,1028], precision: 0.669, acc: 0.769
[31,1284], precision: 0.674, acc: 0.770
[31,1540], precision: 0.681, acc: 0.770
[31,1796], precision: 0.685, acc: 0.771
[31,2052], precision: 0.683, acc: 0.771
[31,2308], precision: 0.678, acc: 0.771
[31,2564], precision: 0.668, acc: 0.769
[31,2820], precision: 0.665, acc: 0.768
[32,128], precision: 0.681, acc: 0.770
[32,384], precision: 0.671, acc: 0.768
[32,640], precision: 0.688, acc: 0.772
[32,896], precision: 0.673, acc: 0.769
[32,1152], precision: 0.674, acc: 0.770
[32,1408], precision: 0.677, acc: 0.770
[32,1664], precision: 0.678, acc: 0.771
[32,1920], precision: 0.694, acc: 0.773
[32,2176],

[6,2536], precision: 0.365, acc: 0.508
[6,2792], precision: 0.367, acc: 0.509
[7,100], precision: 0.366, acc: 0.524
[7,356], precision: 0.370, acc: 0.521
[7,612], precision: 0.373, acc: 0.531
[7,868], precision: 0.376, acc: 0.535
[7,1124], precision: 0.379, acc: 0.544
[7,1380], precision: 0.381, acc: 0.545
[7,1636], precision: 0.384, acc: 0.556
[7,1892], precision: 0.390, acc: 0.558
[7,2148], precision: 0.393, acc: 0.569
[7,2404], precision: 0.396, acc: 0.570
[7,2660], precision: 0.399, acc: 0.575
[7,2916], precision: 0.403, acc: 0.589
[8,224], precision: 0.405, acc: 0.592
[8,480], precision: 0.408, acc: 0.590
[8,736], precision: 0.412, acc: 0.601
[8,992], precision: 0.419, acc: 0.604
[8,1248], precision: 0.425, acc: 0.618
[8,1504], precision: 0.430, acc: 0.623
[8,1760], precision: 0.436, acc: 0.629
[8,2016], precision: 0.439, acc: 0.630
[8,2272], precision: 0.449, acc: 0.644
[8,2528], precision: 0.450, acc: 0.643
[8,2784], precision: 0.457, acc: 0.652
[9,92], precision: 0.462, acc: 0.

[24,2720], precision: 0.669, acc: 0.770
[25,28], precision: 0.686, acc: 0.772
[25,284], precision: 0.682, acc: 0.771
[25,540], precision: 0.678, acc: 0.771
[25,796], precision: 0.682, acc: 0.772
[25,1052], precision: 0.667, acc: 0.769
[25,1308], precision: 0.688, acc: 0.772
[25,1564], precision: 0.669, acc: 0.769
[25,1820], precision: 0.694, acc: 0.772
[25,2076], precision: 0.677, acc: 0.771
[25,2332], precision: 0.666, acc: 0.769
[25,2588], precision: 0.681, acc: 0.772
[25,2844], precision: 0.672, acc: 0.770
[26,152], precision: 0.683, acc: 0.771
[26,408], precision: 0.679, acc: 0.770
[26,664], precision: 0.684, acc: 0.772
[26,920], precision: 0.673, acc: 0.770
[26,1176], precision: 0.676, acc: 0.771
[26,1432], precision: 0.693, acc: 0.773
[26,1688], precision: 0.674, acc: 0.770
[26,1944], precision: 0.692, acc: 0.773
[26,2200], precision: 0.661, acc: 0.767
[26,2456], precision: 0.676, acc: 0.771
[26,2712], precision: 0.669, acc: 0.769
[27,20], precision: 0.685, acc: 0.771
[27,276], p

[1,124], precision: 0.311, acc: 0.315
[1,380], precision: 0.311, acc: 0.316
[1,636], precision: 0.311, acc: 0.317
[1,892], precision: 0.311, acc: 0.318
[1,1148], precision: 0.311, acc: 0.319
[1,1404], precision: 0.311, acc: 0.319
[1,1660], precision: 0.312, acc: 0.321
[1,1916], precision: 0.312, acc: 0.322
[1,2172], precision: 0.312, acc: 0.323
[1,2428], precision: 0.312, acc: 0.326
[1,2684], precision: 0.313, acc: 0.328
[1,2940], precision: 0.313, acc: 0.331
[2,248], precision: 0.313, acc: 0.334
[2,504], precision: 0.314, acc: 0.336
[2,760], precision: 0.314, acc: 0.337
[2,1016], precision: 0.315, acc: 0.339
[2,1272], precision: 0.315, acc: 0.341
[2,1528], precision: 0.315, acc: 0.344
[2,1784], precision: 0.316, acc: 0.348
[2,2040], precision: 0.316, acc: 0.348
[2,2296], precision: 0.317, acc: 0.351
[2,2552], precision: 0.317, acc: 0.350
[2,2808], precision: 0.317, acc: 0.356
[3,116], precision: 0.317, acc: 0.357
[3,372], precision: 0.318, acc: 0.357
[3,628], precision: 0.319, acc: 0.

[19,820], precision: 0.678, acc: 0.771
[19,1076], precision: 0.667, acc: 0.770
[19,1332], precision: 0.676, acc: 0.771
[19,1588], precision: 0.686, acc: 0.772
[19,1844], precision: 0.694, acc: 0.773
[19,2100], precision: 0.663, acc: 0.769
[19,2356], precision: 0.681, acc: 0.772
[19,2612], precision: 0.661, acc: 0.769
[19,2868], precision: 0.682, acc: 0.772
[20,176], precision: 0.688, acc: 0.772
[20,432], precision: 0.675, acc: 0.771
[20,688], precision: 0.692, acc: 0.772
[20,944], precision: 0.673, acc: 0.771
[20,1200], precision: 0.678, acc: 0.772
[20,1456], precision: 0.681, acc: 0.771
[20,1712], precision: 0.689, acc: 0.771
[20,1968], precision: 0.679, acc: 0.771
[20,2224], precision: 0.678, acc: 0.771
[20,2480], precision: 0.675, acc: 0.771
[20,2736], precision: 0.658, acc: 0.768
[21,44], precision: 0.685, acc: 0.773
[21,300], precision: 0.669, acc: 0.769
[21,556], precision: 0.677, acc: 0.770
[21,812], precision: 0.687, acc: 0.772
[21,1068], precision: 0.666, acc: 0.769
[21,1324],

[37,748], precision: 0.674, acc: 0.768
[37,1004], precision: 0.676, acc: 0.769
[37,1260], precision: 0.672, acc: 0.768
[37,1516], precision: 0.693, acc: 0.770
[37,1772], precision: 0.676, acc: 0.769
[37,2028], precision: 0.681, acc: 0.769
[37,2284], precision: 0.687, acc: 0.770
[37,2540], precision: 0.655, acc: 0.765
[37,2796], precision: 0.671, acc: 0.768
[38,104], precision: 0.681, acc: 0.769
[38,360], precision: 0.660, acc: 0.766
[38,616], precision: 0.676, acc: 0.767
[38,872], precision: 0.665, acc: 0.766
[38,1128], precision: 0.668, acc: 0.767
[38,1384], precision: 0.674, acc: 0.768
[38,1640], precision: 0.680, acc: 0.769
[38,1896], precision: 0.670, acc: 0.767
[38,2152], precision: 0.688, acc: 0.771
[38,2408], precision: 0.671, acc: 0.768
[38,2664], precision: 0.646, acc: 0.763
[38,2920], precision: 0.679, acc: 0.768
[39,228], precision: 0.675, acc: 0.769
[39,484], precision: 0.665, acc: 0.766
[39,740], precision: 0.689, acc: 0.770
[39,996], precision: 0.665, acc: 0.766
[39,1252]

[13,1868], precision: 0.656, acc: 0.766
[13,2124], precision: 0.670, acc: 0.769
[13,2380], precision: 0.654, acc: 0.766
[13,2636], precision: 0.650, acc: 0.766
[13,2892], precision: 0.673, acc: 0.770
[14,200], precision: 0.674, acc: 0.770
[14,456], precision: 0.659, acc: 0.767
[14,712], precision: 0.671, acc: 0.769
[14,968], precision: 0.669, acc: 0.769
[14,1224], precision: 0.665, acc: 0.769
[14,1480], precision: 0.676, acc: 0.770
[14,1736], precision: 0.672, acc: 0.768
[14,1992], precision: 0.667, acc: 0.770
[14,2248], precision: 0.675, acc: 0.771
[14,2504], precision: 0.658, acc: 0.768
[14,2760], precision: 0.652, acc: 0.766
[15,68], precision: 0.680, acc: 0.771
[15,324], precision: 0.666, acc: 0.768
[15,580], precision: 0.673, acc: 0.770
[15,836], precision: 0.667, acc: 0.769
[15,1092], precision: 0.660, acc: 0.767
[15,1348], precision: 0.675, acc: 0.770
[15,1604], precision: 0.683, acc: 0.772
[15,1860], precision: 0.680, acc: 0.772
[15,2116], precision: 0.664, acc: 0.769
[15,2372]

[31,1796], precision: 0.686, acc: 0.771
[31,2052], precision: 0.693, acc: 0.772
[31,2308], precision: 0.664, acc: 0.767
[31,2564], precision: 0.673, acc: 0.770
[31,2820], precision: 0.662, acc: 0.767
[32,128], precision: 0.684, acc: 0.770
[32,384], precision: 0.675, acc: 0.770
[32,640], precision: 0.680, acc: 0.770
[32,896], precision: 0.682, acc: 0.770
[32,1152], precision: 0.667, acc: 0.768
[32,1408], precision: 0.671, acc: 0.769
[32,1664], precision: 0.691, acc: 0.771
[32,1920], precision: 0.679, acc: 0.770
[32,2176], precision: 0.682, acc: 0.770
[32,2432], precision: 0.653, acc: 0.764
[32,2688], precision: 0.675, acc: 0.770
[32,2944], precision: 0.660, acc: 0.766
[33,252], precision: 0.682, acc: 0.770
[33,508], precision: 0.677, acc: 0.769
[33,764], precision: 0.684, acc: 0.770
[33,1020], precision: 0.673, acc: 0.769
[33,1276], precision: 0.660, acc: 0.766
[33,1532], precision: 0.675, acc: 0.768
[33,1788], precision: 0.688, acc: 0.770
[33,2044], precision: 0.679, acc: 0.770
[33,230

In [13]:
print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%
      (sta.mean(precision_nb_GIT), sta.mean(recall_nb_GIT), sta.mean(f1score_nb_GIT), sta.mean(accuracy_nb_GIT)))

prec=0.677, recall=0.501, F1=0.575, acc=0.768


In [None]:
#prec=0.696, recall=0.537, F1=0.606, acc=0.781


In [16]:
np.around(preds)

array([[-0., -0., -0., ..., -0., -0.,  1.],
       [ 1., -0., -0., ..., -0., -1., -0.],
       [-0., -0.,  1., ..., -0., -1.,  1.],
       ...,
       [-1., -0., -1., ..., -0., -0., -0.],
       [-0.,  0.,  0., ...,  1.,  0., -0.],
       [ 0.,  1.,  0., ...,  0., -0.,  1.]], dtype=float32)