In [3]:
import os
import argparse
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from utils import *
from base import ModelBase
import statistics as sta
from types import SimpleNamespace

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class GIT(ModelBase):

    def __init__(self, args, **kwargs):

        super(GIT, self).__init__(args, **kwargs)
        self.build()

    def build(self):

        self.layer_sga_emb = nn.Embedding(
            num_embeddings=self.sga_size+1,
            embedding_dim=self.embedding_size,
            padding_idx=0)

        self.layer_can_emb = nn.Embedding(
            num_embeddings=self.can_size+1,
            embedding_dim=self.embedding_size,
            padding_idx=0)

        self.layer_w_0 = nn.Linear(
            in_features=self.embedding_size,
            out_features=self.attention_size,
            bias=True)

        self.layer_beta = nn.Linear(
            in_features=self.attention_size,
            out_features=self.attention_head,
            bias=True)

        self.layer_dropout_1 = nn.Dropout(p=self.dropout_rate)

        self.layer_w_1 = nn.Linear(
            in_features=self.embedding_size,
            out_features=self.hidden_size,
            bias=True)

        self.layer_dropout_2 = nn.Dropout(p=self.dropout_rate)

        self.layer_w_2 = nn.Linear(
            in_features=self.hidden_size,
            out_features=self.deg_size,
            bias=True)

        if self.initializtion:
            gene_emb_pretrain = np.load(os.path.join(self.input_dir, "gene_emb_pretrain.npy"))
            self.layer_sga_emb.weight.data.copy_(torch.from_numpy(gene_emb_pretrain))

        self.optimizer = optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay)


    def forward(self, sga_index, can_index):
   
        # cancer type embedding
        emb_can = self.layer_can_emb(can_index)
        emb_can = emb_can.view(-1,self.embedding_size)

        # gene embedings
        E_t = self.layer_sga_emb(sga_index)

        # squeeze and tanh-curve the gene embeddings
        E_t_flatten = E_t.view(-1, self.embedding_size)
        E_t1_flatten = torch.tanh( self.layer_w_0(E_t_flatten) )

        # multiplied by attention heads
        E_t2_flatten = self.layer_beta(E_t1_flatten)
        E_t2 = E_t2_flatten.view(-1, self.num_max_sga, self.attention_head)

        # normalize by softmax
        E_t2 = E_t2.permute(1,0,2)
        A = F.softmax(E_t2)
        A = A.permute(1,0,2)

        if self.attention:
          # multi-head attention weighted sga embedding:
            emb_sga = torch.sum( torch.bmm( A.permute(0,2,1), E_t ), dim=1)
            emb_sga = emb_sga.view(-1,self.embedding_size)
        else:
          # if not using attention, simply sum up SGA embeddings
            emb_sga = torch.sum(E_t, dim=1)
            emb_sga = emb_sga.view(-1, self.embedding_size)

        # if use cancer type input, add cancer type embedding
        if self.cancer_type:
            emb_tmr = emb_can+emb_sga
        else:
            emb_tmr = emb_sga

        # MLP decoder
        emb_tmr_relu = self.layer_dropout_1(F.relu(emb_tmr))
        hid_tmr = self.layer_w_1(emb_tmr_relu)
        hid_tmr_relu = self.layer_dropout_2(F.relu(hid_tmr))
        

        preds = F.sigmoid(self.layer_w_2(hid_tmr_relu))

        # attention weights
        attn_wt = torch.sum(A, dim=2)
        attn_wt = attn_wt.view(-1, self.num_max_sga)

        return preds, hid_tmr, emb_tmr, emb_sga, attn_wt


    def train(self, train_set, test_set,
            batch_size=None, test_batch_size=None,
            max_iter=None, max_fscore=None,
            test_inc_size=None, **kwargs):

        for iter_train in range(0, max_iter+1, batch_size):
            batch_set = get_minibatch(train_set, iter_train, batch_size,batch_type="train")
            preds, _, _, _, _ = self.forward(batch_set["sga"].to(device), batch_set["can"].to(device))
            labels = batch_set["deg"].to(device)

            self.optimizer.zero_grad()
            loss = -torch.log( self.epsilon + torch.abs(1 - torch.abs(preds - labels)) ).mean()
            loss.backward()
            self.optimizer.step()

            if test_inc_size and (iter_train % test_inc_size == 0):
                labels, preds, _, _, _, _, _ = self.test(test_set, test_batch_size)
                precision, recall, f1score, accuracy = evaluate(
                    labels, preds, epsilon=self.epsilon)
                print("[%d,%d], precision: %.3f, acc: %.3f"% (iter_train//len(train_set["can"]),
                      iter_train%len(train_set["can"]), precision, accuracy))

                if f1score >= max_fscore:
                    break

        #self.save_model(os.path.join(self.output_dir, "trained_model.pth"))


    def test(self, test_set, test_batch_size, **kwargs):

        labels, preds, hid_tmr, emb_tmr, emb_sga, attn_wt, tmr = [], [], [], [], [], [], []

        for iter_test in range(0, len(test_set["can"]), test_batch_size):
            batch_set = get_minibatch(test_set, iter_test, test_batch_size, batch_type="test")
            batch_preds, batch_hid_tmr, batch_emb_tmr, batch_emb_sga, batch_attn_wt = self.forward(
                batch_set["sga"].to(device), batch_set["can"].to(device))
            batch_labels = batch_set["deg"].to(device)

            labels.append(batch_labels.data.to(torch.device("cpu")).numpy())
            preds.append(batch_preds.data.to(torch.device("cpu")).numpy())
            hid_tmr.append(batch_hid_tmr.data.to(torch.device("cpu")).numpy())
            emb_tmr.append(batch_emb_tmr.data.to(torch.device("cpu")).numpy())
            emb_sga.append(batch_emb_sga.data.to(torch.device("cpu")).numpy())
            attn_wt.append(batch_attn_wt.data.to(torch.device("cpu")).numpy())
            tmr = tmr + batch_set["tmr"]

        labels = np.concatenate(labels,axis=0)
        preds = np.concatenate(preds,axis=0)
        hid_tmr = np.concatenate(hid_tmr,axis=0)
        emb_tmr = np.concatenate(emb_tmr,axis=0)
        emb_sga = np.concatenate(emb_sga,axis=0)
        attn_wt = np.concatenate(attn_wt,axis=0)

        return labels, preds, hid_tmr, emb_tmr, emb_sga, attn_wt, tmr


In [6]:
# Parse arguments
args = SimpleNamespace()

args.train_model=True

args.input_dir="data"
args.output_dir="data"

args.embedding_size=512
args.hidden_size=1024
args.attention_size=400
args.attention_head=128

args.max_fscore=0.7
args.batch_size=16
args.test_batch_size=512
args.test_inc_size=256
args.dropout_rate=0.5
args.weight_decay=1e-5

# GIT variants:
# args.deg_shuffle=False -> DEG-shuffled
args.deg_shuffle=False

# Load data
dataset = load_dataset(input_dir="data", deg_shuffle=False)
train_set, test_set = split_dataset(dataset, ratio=0.66)

args.can_size = dataset["can"].max()        # cancer type dimension
args.sga_size = dataset["sga"].max()        # SGA dimension
args.deg_size = dataset["deg"].shape[1]     # DEG output dimension
args.num_max_sga = dataset["sga"].shape[1]  # maximum number of SGAs in a tumor

### Full GIT Model

In [9]:
precision_GIT, recall_GIT, f1score_GIT, accuracy_GIT = [], [], [], []

# GIT variants:
# args.initializtion=False -> GIT-init
args.initializtion=True
# args.attention=False -> GIT-attn
args.attention=True
# args.cancer_type=False -> GIT-can
args.cancer_type=True

args.max_iter=3072*20
args.learning_rate=1e-4

for i in range(5):
    
    # Init model with single hidden layer
    model = GIT(args).to(device)

    # Train MLP model
    model.train(train_set, test_set,
          batch_size=args.batch_size,
          test_batch_size=args.test_batch_size,
          max_iter=args.max_iter,
          max_fscore=args.max_fscore,
          test_inc_size=args.test_inc_size)

    print("Evaluating...")
    labels, preds, _, _, _, _, _ = model.test(test_set, test_batch_size=512)
    precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
    print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%(precision, recall, f1score, accuracy))
    
    # Release memory
    del model
    torch.cuda.empty_cache()
    
    precision_GIT.append(precision)
    recall_GIT.append(recall)
    f1score_GIT.append(f1score)
    accuracy_GIT.append(accuracy)



[0,0], precision: 0.313, acc: 0.512
[0,256], precision: 0.361, acc: 0.598
[0,512], precision: 0.402, acc: 0.639
[0,768], precision: 0.434, acc: 0.660
[0,1024], precision: 0.457, acc: 0.673
[0,1280], precision: 0.476, acc: 0.682
[0,1536], precision: 0.498, acc: 0.691
[0,1792], precision: 0.519, acc: 0.698
[0,2048], precision: 0.534, acc: 0.703
[0,2304], precision: 0.550, acc: 0.707
[0,2560], precision: 0.562, acc: 0.711
[0,2816], precision: 0.579, acc: 0.715
[1,124], precision: 0.599, acc: 0.719
[1,380], precision: 0.611, acc: 0.723
[1,636], precision: 0.626, acc: 0.727
[1,892], precision: 0.635, acc: 0.730
[1,1148], precision: 0.641, acc: 0.734
[1,1404], precision: 0.654, acc: 0.737
[1,1660], precision: 0.663, acc: 0.740
[1,1916], precision: 0.670, acc: 0.743
[1,2172], precision: 0.674, acc: 0.746
[1,2428], precision: 0.673, acc: 0.748
[1,2684], precision: 0.676, acc: 0.751
[1,2940], precision: 0.684, acc: 0.753
[2,248], precision: 0.689, acc: 0.755
[2,504], precision: 0.690, acc: 0.75

[18,696], precision: 0.706, acc: 0.791
[18,952], precision: 0.702, acc: 0.791
[18,1208], precision: 0.698, acc: 0.791
[18,1464], precision: 0.706, acc: 0.792
[18,1720], precision: 0.712, acc: 0.792
[18,1976], precision: 0.705, acc: 0.791
[18,2232], precision: 0.704, acc: 0.792
[18,2488], precision: 0.694, acc: 0.791
[18,2744], precision: 0.702, acc: 0.791
[19,52], precision: 0.705, acc: 0.791
[19,308], precision: 0.701, acc: 0.791
[19,564], precision: 0.700, acc: 0.790
[19,820], precision: 0.704, acc: 0.791
[19,1076], precision: 0.705, acc: 0.791
[19,1332], precision: 0.702, acc: 0.791
[19,1588], precision: 0.704, acc: 0.791
[19,1844], precision: 0.708, acc: 0.791
[19,2100], precision: 0.701, acc: 0.791
[19,2356], precision: 0.693, acc: 0.791
[19,2612], precision: 0.692, acc: 0.790
[19,2868], precision: 0.703, acc: 0.791
[20,176], precision: 0.707, acc: 0.791
[20,432], precision: 0.700, acc: 0.791
[20,688], precision: 0.706, acc: 0.791
[20,944], precision: 0.705, acc: 0.791
[20,1200], 

[15,1348], precision: 0.703, acc: 0.791
[15,1604], precision: 0.707, acc: 0.792
[15,1860], precision: 0.708, acc: 0.791
[15,2116], precision: 0.706, acc: 0.791
[15,2372], precision: 0.701, acc: 0.791
[15,2628], precision: 0.700, acc: 0.791
[15,2884], precision: 0.703, acc: 0.792
[16,192], precision: 0.707, acc: 0.792
[16,448], precision: 0.704, acc: 0.792
[16,704], precision: 0.707, acc: 0.792
[16,960], precision: 0.705, acc: 0.791
[16,1216], precision: 0.700, acc: 0.791
[16,1472], precision: 0.706, acc: 0.792
[16,1728], precision: 0.709, acc: 0.791
[16,1984], precision: 0.707, acc: 0.792
[16,2240], precision: 0.704, acc: 0.792
[16,2496], precision: 0.699, acc: 0.792
[16,2752], precision: 0.703, acc: 0.792
[17,60], precision: 0.709, acc: 0.792
[17,316], precision: 0.703, acc: 0.792
[17,572], precision: 0.704, acc: 0.792
[17,828], precision: 0.708, acc: 0.792
[17,1084], precision: 0.704, acc: 0.792
[17,1340], precision: 0.704, acc: 0.792
[17,1596], precision: 0.707, acc: 0.792
[17,1852]

[12,2000], precision: 0.705, acc: 0.789
[12,2256], precision: 0.702, acc: 0.790
[12,2512], precision: 0.700, acc: 0.790
[12,2768], precision: 0.700, acc: 0.790
[13,76], precision: 0.705, acc: 0.790
[13,332], precision: 0.702, acc: 0.790
[13,588], precision: 0.701, acc: 0.790
[13,844], precision: 0.704, acc: 0.790
[13,1100], precision: 0.703, acc: 0.790
[13,1356], precision: 0.703, acc: 0.790
[13,1612], precision: 0.708, acc: 0.791
[13,1868], precision: 0.708, acc: 0.790
[13,2124], precision: 0.703, acc: 0.791
[13,2380], precision: 0.697, acc: 0.790
[13,2636], precision: 0.698, acc: 0.791
[13,2892], precision: 0.702, acc: 0.791
[14,200], precision: 0.705, acc: 0.791
[14,456], precision: 0.703, acc: 0.790
[14,712], precision: 0.706, acc: 0.791
[14,968], precision: 0.705, acc: 0.790
[14,1224], precision: 0.699, acc: 0.790
[14,1480], precision: 0.702, acc: 0.791
[14,1736], precision: 0.709, acc: 0.791
[14,1992], precision: 0.706, acc: 0.791
[14,2248], precision: 0.703, acc: 0.791
[14,2504]

[9,2652], precision: 0.693, acc: 0.788
[9,2908], precision: 0.701, acc: 0.788
[10,216], precision: 0.701, acc: 0.788
[10,472], precision: 0.700, acc: 0.789
[10,728], precision: 0.702, acc: 0.789
[10,984], precision: 0.700, acc: 0.788
[10,1240], precision: 0.697, acc: 0.788
[10,1496], precision: 0.701, acc: 0.789
[10,1752], precision: 0.704, acc: 0.789
[10,2008], precision: 0.701, acc: 0.789
[10,2264], precision: 0.699, acc: 0.789
[10,2520], precision: 0.697, acc: 0.789
[10,2776], precision: 0.697, acc: 0.789
[11,84], precision: 0.701, acc: 0.789
[11,340], precision: 0.701, acc: 0.789
[11,596], precision: 0.699, acc: 0.789
[11,852], precision: 0.704, acc: 0.789
[11,1108], precision: 0.700, acc: 0.789
[11,1364], precision: 0.700, acc: 0.789
[11,1620], precision: 0.705, acc: 0.790
[11,1876], precision: 0.707, acc: 0.790
[11,2132], precision: 0.702, acc: 0.789
[11,2388], precision: 0.699, acc: 0.789
[11,2644], precision: 0.697, acc: 0.789
[11,2900], precision: 0.701, acc: 0.790
[12,208], p

[7,100], precision: 0.700, acc: 0.786
[7,356], precision: 0.698, acc: 0.786
[7,612], precision: 0.697, acc: 0.786
[7,868], precision: 0.699, acc: 0.786
[7,1124], precision: 0.696, acc: 0.787
[7,1380], precision: 0.696, acc: 0.786
[7,1636], precision: 0.699, acc: 0.787
[7,1892], precision: 0.701, acc: 0.787
[7,2148], precision: 0.699, acc: 0.787
[7,2404], precision: 0.696, acc: 0.786
[7,2660], precision: 0.697, acc: 0.787
[7,2916], precision: 0.699, acc: 0.787
[8,224], precision: 0.700, acc: 0.787
[8,480], precision: 0.699, acc: 0.787
[8,736], precision: 0.700, acc: 0.787
[8,992], precision: 0.697, acc: 0.787
[8,1248], precision: 0.694, acc: 0.787
[8,1504], precision: 0.699, acc: 0.787
[8,1760], precision: 0.703, acc: 0.787
[8,2016], precision: 0.702, acc: 0.787
[8,2272], precision: 0.698, acc: 0.787
[8,2528], precision: 0.694, acc: 0.787
[8,2784], precision: 0.695, acc: 0.787
[9,92], precision: 0.699, acc: 0.788
[9,348], precision: 0.699, acc: 0.788
[9,604], precision: 0.698, acc: 0.78

In [10]:
print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%
      (sta.mean(precision_GIT), sta.mean(recall_GIT), sta.mean(f1score_GIT), sta.mean(accuracy_GIT)))

prec=0.689, recall=0.580, F1=0.630, acc=0.790


### Less Initializtion GIT Model

In [7]:
precision_LI, recall_LI, f1score_LI, accuracy_LI = [], [], [], []

# GIT variants:
# args.initializtion=False -> GIT-init
args.initializtion=False
# args.attention=False -> GIT-attn
args.attention=True
# args.cancer_type=False -> GIT-can
args.cancer_type=True

args.max_iter = 3072*20
args.learning_rate = 1e-4

for i in range(5):
    
    # Init model with single hidden layer
    torch.cuda.empty_cache()
    model = GIT(args).to(device)

    # Train MLP model
    model.train(train_set, test_set,
          batch_size=args.batch_size,
          test_batch_size=args.test_batch_size,
          max_iter=args.max_iter,
          max_fscore=args.max_fscore,
          test_inc_size=args.test_inc_size)

    print("Evaluating...")
    labels, preds, _, _, _, _, _ = model.test(test_set, test_batch_size=512)
    precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
    print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%(precision, recall, f1score, accuracy))
    
    # Release memory
    del model
    torch.cuda.empty_cache()
    
    precision_LI.append(precision)
    recall_LI.append(recall)
    f1score_LI.append(f1score)
    accuracy_LI.append(accuracy)

  dataset = {"sga": Variable(torch.LongTensor(sga)),


[0,0], precision: 0.310, acc: 0.505
[0,256], precision: 0.362, acc: 0.601
[0,512], precision: 0.413, acc: 0.648
[0,768], precision: 0.449, acc: 0.670
[0,1024], precision: 0.469, acc: 0.680
[0,1280], precision: 0.480, acc: 0.684
[0,1536], precision: 0.492, acc: 0.689
[0,1792], precision: 0.500, acc: 0.692
[0,2048], precision: 0.504, acc: 0.693
[0,2304], precision: 0.511, acc: 0.695
[0,2560], precision: 0.517, acc: 0.697
[0,2816], precision: 0.523, acc: 0.699
[1,124], precision: 0.534, acc: 0.702
[1,380], precision: 0.540, acc: 0.704
[1,636], precision: 0.544, acc: 0.705
[1,892], precision: 0.551, acc: 0.708
[1,1148], precision: 0.553, acc: 0.709
[1,1404], precision: 0.561, acc: 0.711
[1,1660], precision: 0.565, acc: 0.712
[1,1916], precision: 0.573, acc: 0.714
[1,2172], precision: 0.575, acc: 0.715
[1,2428], precision: 0.581, acc: 0.717
[1,2684], precision: 0.584, acc: 0.719
[1,2940], precision: 0.594, acc: 0.720
[2,248], precision: 0.600, acc: 0.722
[2,504], precision: 0.601, acc: 0.72

[18,696], precision: 0.696, acc: 0.784
[18,952], precision: 0.692, acc: 0.783
[18,1208], precision: 0.692, acc: 0.784
[18,1464], precision: 0.699, acc: 0.784
[18,1720], precision: 0.703, acc: 0.784
[18,1976], precision: 0.701, acc: 0.785
[18,2232], precision: 0.696, acc: 0.784
[18,2488], precision: 0.689, acc: 0.783
[18,2744], precision: 0.690, acc: 0.784
[19,52], precision: 0.698, acc: 0.784
[19,308], precision: 0.700, acc: 0.785
[19,564], precision: 0.697, acc: 0.785
[19,820], precision: 0.699, acc: 0.784
[19,1076], precision: 0.695, acc: 0.785
[19,1332], precision: 0.696, acc: 0.784
[19,1588], precision: 0.699, acc: 0.784
[19,1844], precision: 0.704, acc: 0.784
[19,2100], precision: 0.700, acc: 0.785
[19,2356], precision: 0.695, acc: 0.784
[19,2612], precision: 0.692, acc: 0.784
[19,2868], precision: 0.695, acc: 0.784
[20,176], precision: 0.700, acc: 0.785
[20,432], precision: 0.700, acc: 0.784
[20,688], precision: 0.701, acc: 0.784
[20,944], precision: 0.699, acc: 0.784
[20,1200], 

[15,1348], precision: 0.692, acc: 0.781
[15,1604], precision: 0.696, acc: 0.782
[15,1860], precision: 0.698, acc: 0.782
[15,2116], precision: 0.691, acc: 0.782
[15,2372], precision: 0.687, acc: 0.782
[15,2628], precision: 0.683, acc: 0.782
[15,2884], precision: 0.694, acc: 0.783
[16,192], precision: 0.696, acc: 0.782
[16,448], precision: 0.694, acc: 0.783
[16,704], precision: 0.694, acc: 0.782
[16,960], precision: 0.697, acc: 0.783
[16,1216], precision: 0.687, acc: 0.782
[16,1472], precision: 0.695, acc: 0.783
[16,1728], precision: 0.697, acc: 0.783
[16,1984], precision: 0.698, acc: 0.784
[16,2240], precision: 0.694, acc: 0.783
[16,2496], precision: 0.689, acc: 0.782
[16,2752], precision: 0.691, acc: 0.783
[17,60], precision: 0.699, acc: 0.783
[17,316], precision: 0.698, acc: 0.783
[17,572], precision: 0.696, acc: 0.783
[17,828], precision: 0.696, acc: 0.783
[17,1084], precision: 0.690, acc: 0.783
[17,1340], precision: 0.691, acc: 0.784
[17,1596], precision: 0.699, acc: 0.784
[17,1852]

[12,2000], precision: 0.690, acc: 0.780
[12,2256], precision: 0.689, acc: 0.780
[12,2512], precision: 0.688, acc: 0.780
[12,2768], precision: 0.690, acc: 0.780
[13,76], precision: 0.695, acc: 0.781
[13,332], precision: 0.687, acc: 0.779
[13,588], precision: 0.687, acc: 0.780
[13,844], precision: 0.690, acc: 0.780
[13,1100], precision: 0.685, acc: 0.779
[13,1356], precision: 0.688, acc: 0.780
[13,1612], precision: 0.693, acc: 0.781
[13,1868], precision: 0.693, acc: 0.781
[13,2124], precision: 0.690, acc: 0.780
[13,2380], precision: 0.691, acc: 0.781
[13,2636], precision: 0.689, acc: 0.781
[13,2892], precision: 0.693, acc: 0.781
[14,200], precision: 0.695, acc: 0.781
[14,456], precision: 0.693, acc: 0.780
[14,712], precision: 0.692, acc: 0.780
[14,968], precision: 0.690, acc: 0.782
[14,1224], precision: 0.688, acc: 0.782
[14,1480], precision: 0.694, acc: 0.781
[14,1736], precision: 0.695, acc: 0.781
[14,1992], precision: 0.694, acc: 0.782
[14,2248], precision: 0.691, acc: 0.782
[14,2504]

[9,2652], precision: 0.679, acc: 0.773
[9,2908], precision: 0.686, acc: 0.774
[10,216], precision: 0.685, acc: 0.774
[10,472], precision: 0.686, acc: 0.775
[10,728], precision: 0.687, acc: 0.775
[10,984], precision: 0.684, acc: 0.774
[10,1240], precision: 0.683, acc: 0.775
[10,1496], precision: 0.684, acc: 0.775
[10,1752], precision: 0.688, acc: 0.775
[10,2008], precision: 0.685, acc: 0.775
[10,2264], precision: 0.686, acc: 0.776
[10,2520], precision: 0.680, acc: 0.775
[10,2776], precision: 0.683, acc: 0.776
[11,84], precision: 0.695, acc: 0.777
[11,340], precision: 0.692, acc: 0.776
[11,596], precision: 0.687, acc: 0.776
[11,852], precision: 0.689, acc: 0.777
[11,1108], precision: 0.681, acc: 0.776
[11,1364], precision: 0.686, acc: 0.777
[11,1620], precision: 0.688, acc: 0.777
[11,1876], precision: 0.686, acc: 0.777
[11,2132], precision: 0.683, acc: 0.778
[11,2388], precision: 0.684, acc: 0.778
[11,2644], precision: 0.681, acc: 0.778
[11,2900], precision: 0.692, acc: 0.778
[12,208], p

[7,100], precision: 0.679, acc: 0.765
[7,356], precision: 0.673, acc: 0.765
[7,612], precision: 0.672, acc: 0.766
[7,868], precision: 0.671, acc: 0.765
[7,1124], precision: 0.670, acc: 0.766
[7,1380], precision: 0.679, acc: 0.766
[7,1636], precision: 0.681, acc: 0.767
[7,1892], precision: 0.684, acc: 0.767
[7,2148], precision: 0.674, acc: 0.767
[7,2404], precision: 0.675, acc: 0.768
[7,2660], precision: 0.673, acc: 0.768
[7,2916], precision: 0.678, acc: 0.768
[8,224], precision: 0.684, acc: 0.770
[8,480], precision: 0.675, acc: 0.769
[8,736], precision: 0.678, acc: 0.769
[8,992], precision: 0.675, acc: 0.769
[8,1248], precision: 0.679, acc: 0.770
[8,1504], precision: 0.682, acc: 0.770
[8,1760], precision: 0.684, acc: 0.771
[8,2016], precision: 0.677, acc: 0.770
[8,2272], precision: 0.676, acc: 0.771
[8,2528], precision: 0.680, acc: 0.772
[8,2784], precision: 0.677, acc: 0.771
[9,92], precision: 0.684, acc: 0.772
[9,348], precision: 0.680, acc: 0.772
[9,604], precision: 0.682, acc: 0.77

In [8]:
print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%
      (sta.mean(precision_LI), sta.mean(recall_LI), sta.mean(f1score_LI), sta.mean(accuracy_LI)))

prec=0.694, recall=0.545, F1=0.610, acc=0.786


### Less Attention GIT Model

In [9]:
precision_LA, recall_LA, f1score_LA, accuracy_LA = [], [], [], []

# GIT variants:
# args.initializtion=False -> GIT-init
args.initializtion=True
# args.attention=False -> GIT-attn
args.attention=False
# args.cancer_type=False -> GIT-can
args.cancer_type=True

args.max_iter = 3072*40
args.learning_rate=0.0003

for i in range(1):
    
    # Init model with single hidden layer
    torch.cuda.empty_cache()
    model = GIT(args).to(device)

    # Train MLP model
    model.train(train_set, test_set,
          batch_size=16,
          test_batch_size=512,
          max_iter=3072*40,
          max_fscore=0.7,
          test_inc_size=256)

    print("Evaluating...")
    labels, preds, _, _, _, _, _ = model.test(test_set, test_batch_size=512)
    precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
    print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%(precision, recall, f1score, accuracy))
    
    # Release memory
    del model
    torch.cuda.empty_cache()
    
    precision_LA.append(precision)
    recall_LA.append(recall)
    f1score_LA.append(f1score)
    accuracy_LA.append(accuracy)



[0,0], precision: 0.317, acc: 0.519
[0,256], precision: 0.382, acc: 0.620
[0,512], precision: 0.360, acc: 0.595
[0,768], precision: 0.347, acc: 0.577
[0,1024], precision: 0.346, acc: 0.576
[0,1280], precision: 0.349, acc: 0.581
[0,1536], precision: 0.358, acc: 0.594
[0,1792], precision: 0.386, acc: 0.628
[0,2048], precision: 0.402, acc: 0.641
[0,2304], precision: 0.431, acc: 0.662
[0,2560], precision: 0.434, acc: 0.663
[0,2816], precision: 0.454, acc: 0.673
[1,124], precision: 0.471, acc: 0.681
[1,380], precision: 0.465, acc: 0.678
[1,636], precision: 0.484, acc: 0.686
[1,892], precision: 0.492, acc: 0.689
[1,1148], precision: 0.501, acc: 0.692
[1,1404], precision: 0.500, acc: 0.692
[1,1660], precision: 0.512, acc: 0.696
[1,1916], precision: 0.513, acc: 0.696
[1,2172], precision: 0.516, acc: 0.697
[1,2428], precision: 0.525, acc: 0.699
[1,2684], precision: 0.521, acc: 0.699
[1,2940], precision: 0.525, acc: 0.700
[2,248], precision: 0.535, acc: 0.703
[2,504], precision: 0.532, acc: 0.70

[18,696], precision: 0.669, acc: 0.770
[18,952], precision: 0.674, acc: 0.770
[18,1208], precision: 0.677, acc: 0.771
[18,1464], precision: 0.686, acc: 0.771
[18,1720], precision: 0.683, acc: 0.772
[18,1976], precision: 0.676, acc: 0.772
[18,2232], precision: 0.679, acc: 0.771
[18,2488], precision: 0.667, acc: 0.770
[18,2744], precision: 0.673, acc: 0.773
[19,52], precision: 0.678, acc: 0.772
[19,308], precision: 0.672, acc: 0.772
[19,564], precision: 0.679, acc: 0.772
[19,820], precision: 0.676, acc: 0.772
[19,1076], precision: 0.675, acc: 0.772
[19,1332], precision: 0.676, acc: 0.772
[19,1588], precision: 0.681, acc: 0.773
[19,1844], precision: 0.679, acc: 0.773
[19,2100], precision: 0.681, acc: 0.773
[19,2356], precision: 0.671, acc: 0.772
[19,2612], precision: 0.669, acc: 0.773
[19,2868], precision: 0.678, acc: 0.773
[20,176], precision: 0.680, acc: 0.773
[20,432], precision: 0.675, acc: 0.774
[20,688], precision: 0.678, acc: 0.774
[20,944], precision: 0.677, acc: 0.774
[20,1200], 

[36,624], precision: 0.668, acc: 0.778
[36,880], precision: 0.672, acc: 0.780
[36,1136], precision: 0.669, acc: 0.778
[36,1392], precision: 0.678, acc: 0.779
[36,1648], precision: 0.682, acc: 0.779
[36,1904], precision: 0.668, acc: 0.780
[36,2160], precision: 0.671, acc: 0.779
[36,2416], precision: 0.665, acc: 0.778
[36,2672], precision: 0.668, acc: 0.779
[36,2928], precision: 0.677, acc: 0.781
[37,236], precision: 0.669, acc: 0.780
[37,492], precision: 0.673, acc: 0.780
[37,748], precision: 0.672, acc: 0.779
[37,1004], precision: 0.669, acc: 0.779
[37,1260], precision: 0.673, acc: 0.779
[37,1516], precision: 0.675, acc: 0.778
[37,1772], precision: 0.682, acc: 0.781
[37,2028], precision: 0.670, acc: 0.779
[37,2284], precision: 0.669, acc: 0.779
[37,2540], precision: 0.669, acc: 0.780
[37,2796], precision: 0.671, acc: 0.780
[38,104], precision: 0.672, acc: 0.778
[38,360], precision: 0.673, acc: 0.779
[38,616], precision: 0.677, acc: 0.779
[38,872], precision: 0.669, acc: 0.779
[38,1128]

In [10]:
print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%
      (sta.mean(precision_LA), sta.mean(recall_LA), sta.mean(f1score_LA), sta.mean(accuracy_LA)))

prec=0.674, recall=0.552, F1=0.607, acc=0.779


### Less Cancer Embedding GIT Model

In [11]:
precision_LC, recall_LC, f1score_LC, accuracy_LC = [], [], [], []

# GIT variants:
# args.initializtion=False -> GIT-init
args.initializtion=True
# args.attention=False -> GIT-attn
args.attention=True
# args.cancer_type=False -> GIT-can
args.cancer_type=False

args.max_iter = 3072*40
args.learning_rate= 1e-4

for i in range(1):
    
    # Init model with single hidden layer
    torch.cuda.empty_cache()
    model = GIT(args).to(device)

    # Train MLP model
    model.train(train_set, test_set,
          batch_size=16,
          test_batch_size=512,
          max_iter=3072*20,
          max_fscore=0.7,
          test_inc_size=256)

    print("Evaluating...")
    labels, preds, _, _, _, _, _ = model.test(test_set, test_batch_size=512)
    precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
    print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%(precision, recall, f1score, accuracy))
    
    # Release memory
    del model
    torch.cuda.empty_cache()
    
    precision_LC.append(precision)
    recall_LC.append(recall)
    f1score_LC.append(f1score)
    accuracy_LC.append(accuracy)



[0,0], precision: 0.311, acc: 0.506
[0,256], precision: 0.357, acc: 0.590
[0,512], precision: 0.391, acc: 0.629
[0,768], precision: 0.415, acc: 0.648
[0,1024], precision: 0.431, acc: 0.658
[0,1280], precision: 0.444, acc: 0.666
[0,1536], precision: 0.455, acc: 0.672
[0,1792], precision: 0.467, acc: 0.679
[0,2048], precision: 0.472, acc: 0.681
[0,2304], precision: 0.477, acc: 0.683
[0,2560], precision: 0.483, acc: 0.685
[0,2816], precision: 0.486, acc: 0.686
[1,124], precision: 0.494, acc: 0.689
[1,380], precision: 0.500, acc: 0.692
[1,636], precision: 0.503, acc: 0.693
[1,892], precision: 0.503, acc: 0.693
[1,1148], precision: 0.504, acc: 0.693
[1,1404], precision: 0.510, acc: 0.695
[1,1660], precision: 0.512, acc: 0.696
[1,1916], precision: 0.517, acc: 0.697
[1,2172], precision: 0.518, acc: 0.698
[1,2428], precision: 0.520, acc: 0.698
[1,2684], precision: 0.520, acc: 0.698
[1,2940], precision: 0.524, acc: 0.700
[2,248], precision: 0.529, acc: 0.701
[2,504], precision: 0.529, acc: 0.70

[18,696], precision: 0.626, acc: 0.742
[18,952], precision: 0.631, acc: 0.742
[18,1208], precision: 0.635, acc: 0.743
[18,1464], precision: 0.641, acc: 0.743
[18,1720], precision: 0.638, acc: 0.743
[18,1976], precision: 0.636, acc: 0.743
[18,2232], precision: 0.632, acc: 0.742
[18,2488], precision: 0.633, acc: 0.743
[18,2744], precision: 0.634, acc: 0.743
[19,52], precision: 0.643, acc: 0.744
[19,308], precision: 0.633, acc: 0.743
[19,564], precision: 0.624, acc: 0.742
[19,820], precision: 0.633, acc: 0.743
[19,1076], precision: 0.632, acc: 0.742
[19,1332], precision: 0.640, acc: 0.743
[19,1588], precision: 0.640, acc: 0.743
[19,1844], precision: 0.642, acc: 0.744
[19,2100], precision: 0.634, acc: 0.743
[19,2356], precision: 0.632, acc: 0.743
[19,2612], precision: 0.635, acc: 0.744
[19,2868], precision: 0.639, acc: 0.744
[20,176], precision: 0.638, acc: 0.744
[20,432], precision: 0.627, acc: 0.744
[20,688], precision: 0.630, acc: 0.744
[20,944], precision: 0.634, acc: 0.744
[20,1200], 

In [12]:
print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%
      (sta.mean(precision_LC), sta.mean(recall_LC), sta.mean(f1score_LC), sta.mean(accuracy_LC)))

prec=0.631, recall=0.405, F1=0.494, acc=0.744


In [13]:
# Check if CUDA memory exhausted
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
r - a

0

### Modified input test

To test the GIT variants,
change either of the boolean varible below as false:
```
args_dg.initializtion = True
args_dg.attention = True
args_dg.cancer_type = True
```

In [14]:
# Parse arguments
args_dg = SimpleNamespace()

args_dg.train_model=True

args_dg.input_dir="data_degraded"
args_dg.output_dir="data_degraded"

args_dg.embedding_size=512
args_dg.hidden_size=1024
args_dg.attention_size=400
args_dg.attention_head=128

args_dg.max_fscore=0.7
args_dg.batch_size=16
args_dg.test_batch_size=512
args_dg.test_inc_size=256
args_dg.dropout_rate=0.5
args_dg.weight_decay=1e-5

args_dg.deg_shuffle=False

# Load data
dataset_dg = load_dataset(input_dir=args_dg.input_dir, deg_shuffle=args_dg.deg_shuffle)
train_set_dg, test_set_dg = split_dataset(dataset_dg, ratio=0.66)

args_dg.can_size = dataset_dg["can"].max()        # cancer type dimension
args_dg.sga_size = max(dataset_dg["sga"].max(), 19781)        # SGA dimension
args_dg.deg_size = dataset_dg["deg"].shape[1]     # DEG output dimension
args_dg.num_max_sga = dataset_dg["sga"].shape[1]  # maximum number of SGAs in a tumor

In [21]:
precision_dg_GIT, recall_dg_GIT, f1score_dg_GIT, accuracy_dg_GIT = [], [], [], []

# GIT variants:
# args.initializtion=False -> GIT-init
args_dg.initializtion=True
# args.attention=False -> GIT-attn
args_dg.attention=True
# args.cancer_type=False -> GIT-can
args_dg.cancer_type=False

args_dg.max_iter=3072*20
args_dg.learning_rate=1e-4

if args_dg.cancer_type == False:
    args_dg.max_iter = 3072*40
elif args_dg.attention == False:
    args_dg.max_iter = 3072*40
    args_dg.learning_rate = 0.0003

for i in range(5):
    
    # Init model with single hidden layer
    model = GIT(args_dg).to(device)

    # Train MLP model
    model.train(train_set_dg, test_set_dg,
          batch_size=args_dg.batch_size,
          test_batch_size=args_dg.test_batch_size,
          max_iter=args_dg.max_iter,
          max_fscore=args_dg.max_fscore,
          test_inc_size=args_dg.test_inc_size)

    print("Evaluating...")
    labels, preds, _, _, _, _, _ = model.test(test_set_dg, test_batch_size=512)
    precision, recall, f1score, accuracy = evaluate(labels, preds, epsilon=1e-4)
    print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%(precision, recall, f1score, accuracy))
    
    # Release memory
    del model
    torch.cuda.empty_cache()
    
    precision_dg_GIT.append(precision)
    recall_dg_GIT.append(recall)
    f1score_dg_GIT.append(f1score)
    accuracy_dg_GIT.append(accuracy)



[0,0], precision: 0.310, acc: 0.505
[0,256], precision: 0.359, acc: 0.594
[0,512], precision: 0.399, acc: 0.636
[0,768], precision: 0.430, acc: 0.658
[0,1024], precision: 0.448, acc: 0.668
[0,1280], precision: 0.459, acc: 0.674
[0,1536], precision: 0.471, acc: 0.680
[0,1792], precision: 0.480, acc: 0.684
[0,2048], precision: 0.483, acc: 0.685
[0,2304], precision: 0.488, acc: 0.687
[0,2560], precision: 0.493, acc: 0.689
[0,2816], precision: 0.496, acc: 0.690
[1,124], precision: 0.502, acc: 0.692
[1,380], precision: 0.506, acc: 0.694
[1,636], precision: 0.510, acc: 0.695
[1,892], precision: 0.511, acc: 0.696
[1,1148], precision: 0.513, acc: 0.696
[1,1404], precision: 0.517, acc: 0.697
[1,1660], precision: 0.521, acc: 0.699
[1,1916], precision: 0.524, acc: 0.699
[1,2172], precision: 0.523, acc: 0.699
[1,2428], precision: 0.528, acc: 0.701
[1,2684], precision: 0.526, acc: 0.700
[1,2940], precision: 0.530, acc: 0.701
[2,248], precision: 0.535, acc: 0.702
[2,504], precision: 0.534, acc: 0.70

[18,696], precision: 0.607, acc: 0.724
[18,952], precision: 0.605, acc: 0.724
[18,1208], precision: 0.603, acc: 0.724
[18,1464], precision: 0.607, acc: 0.724
[18,1720], precision: 0.606, acc: 0.725
[18,1976], precision: 0.605, acc: 0.725
[18,2232], precision: 0.597, acc: 0.723
[18,2488], precision: 0.598, acc: 0.724
[18,2744], precision: 0.595, acc: 0.724
[19,52], precision: 0.609, acc: 0.724
[19,308], precision: 0.608, acc: 0.724
[19,564], precision: 0.606, acc: 0.724
[19,820], precision: 0.604, acc: 0.724
[19,1076], precision: 0.600, acc: 0.724
[19,1332], precision: 0.603, acc: 0.724
[19,1588], precision: 0.602, acc: 0.724
[19,1844], precision: 0.604, acc: 0.724
[19,2100], precision: 0.595, acc: 0.724
[19,2356], precision: 0.594, acc: 0.723
[19,2612], precision: 0.592, acc: 0.723
[19,2868], precision: 0.604, acc: 0.724
[20,176], precision: 0.602, acc: 0.724
[20,432], precision: 0.600, acc: 0.723
[20,688], precision: 0.601, acc: 0.723
[20,944], precision: 0.596, acc: 0.723
[20,1200], 

[36,624], precision: 0.553, acc: 0.712
[36,880], precision: 0.550, acc: 0.711
[36,1136], precision: 0.550, acc: 0.712
[36,1392], precision: 0.553, acc: 0.713
[36,1648], precision: 0.559, acc: 0.714
[36,1904], precision: 0.554, acc: 0.712
[36,2160], precision: 0.550, acc: 0.712
[36,2416], precision: 0.548, acc: 0.711
[36,2672], precision: 0.545, acc: 0.711
[36,2928], precision: 0.554, acc: 0.713
[37,236], precision: 0.556, acc: 0.714
[37,492], precision: 0.553, acc: 0.712
[37,748], precision: 0.550, acc: 0.712
[37,1004], precision: 0.547, acc: 0.710
[37,1260], precision: 0.550, acc: 0.712
[37,1516], precision: 0.555, acc: 0.713
[37,1772], precision: 0.555, acc: 0.713
[37,2028], precision: 0.553, acc: 0.712
[37,2284], precision: 0.546, acc: 0.711
[37,2540], precision: 0.545, acc: 0.710
[37,2796], precision: 0.547, acc: 0.711
[38,104], precision: 0.550, acc: 0.711
[38,360], precision: 0.553, acc: 0.712
[38,616], precision: 0.554, acc: 0.712
[38,872], precision: 0.548, acc: 0.711
[38,1128]

[12,1744], precision: 0.612, acc: 0.723
[12,2000], precision: 0.608, acc: 0.722
[12,2256], precision: 0.611, acc: 0.723
[12,2512], precision: 0.608, acc: 0.723
[12,2768], precision: 0.607, acc: 0.723
[13,76], precision: 0.617, acc: 0.723
[13,332], precision: 0.614, acc: 0.723
[13,588], precision: 0.612, acc: 0.723
[13,844], precision: 0.611, acc: 0.723
[13,1100], precision: 0.608, acc: 0.723
[13,1356], precision: 0.610, acc: 0.724
[13,1612], precision: 0.613, acc: 0.724
[13,1868], precision: 0.614, acc: 0.723
[13,2124], precision: 0.610, acc: 0.724
[13,2380], precision: 0.610, acc: 0.724
[13,2636], precision: 0.608, acc: 0.724
[13,2892], precision: 0.615, acc: 0.724
[14,200], precision: 0.618, acc: 0.724
[14,456], precision: 0.613, acc: 0.724
[14,712], precision: 0.615, acc: 0.724
[14,968], precision: 0.610, acc: 0.724
[14,1224], precision: 0.609, acc: 0.724
[14,1480], precision: 0.613, acc: 0.724
[14,1736], precision: 0.618, acc: 0.725
[14,1992], precision: 0.613, acc: 0.724
[14,2248]

[30,1672], precision: 0.570, acc: 0.718
[30,1928], precision: 0.571, acc: 0.717
[30,2184], precision: 0.571, acc: 0.718
[30,2440], precision: 0.565, acc: 0.717
[30,2696], precision: 0.559, acc: 0.715
[31,4], precision: 0.571, acc: 0.718
[31,260], precision: 0.571, acc: 0.718
[31,516], precision: 0.569, acc: 0.716
[31,772], precision: 0.568, acc: 0.717
[31,1028], precision: 0.561, acc: 0.715
[31,1284], precision: 0.563, acc: 0.716
[31,1540], precision: 0.569, acc: 0.717
[31,1796], precision: 0.568, acc: 0.717
[31,2052], precision: 0.570, acc: 0.718
[31,2308], precision: 0.565, acc: 0.716
[31,2564], precision: 0.558, acc: 0.715
[31,2820], precision: 0.558, acc: 0.715
[32,128], precision: 0.561, acc: 0.715
[32,384], precision: 0.562, acc: 0.715
[32,640], precision: 0.571, acc: 0.717
[32,896], precision: 0.558, acc: 0.715
[32,1152], precision: 0.559, acc: 0.716
[32,1408], precision: 0.566, acc: 0.716
[32,1664], precision: 0.570, acc: 0.717
[32,1920], precision: 0.568, acc: 0.717
[32,2176],

[6,2536], precision: 0.583, acc: 0.716
[6,2792], precision: 0.583, acc: 0.716
[7,100], precision: 0.588, acc: 0.716
[7,356], precision: 0.586, acc: 0.716
[7,612], precision: 0.586, acc: 0.717
[7,868], precision: 0.587, acc: 0.717
[7,1124], precision: 0.586, acc: 0.717
[7,1380], precision: 0.587, acc: 0.717
[7,1636], precision: 0.586, acc: 0.717
[7,1892], precision: 0.589, acc: 0.717
[7,2148], precision: 0.588, acc: 0.717
[7,2404], precision: 0.592, acc: 0.717
[7,2660], precision: 0.586, acc: 0.717
[7,2916], precision: 0.592, acc: 0.718
[8,224], precision: 0.592, acc: 0.718
[8,480], precision: 0.588, acc: 0.717
[8,736], precision: 0.589, acc: 0.717
[8,992], precision: 0.589, acc: 0.718
[8,1248], precision: 0.590, acc: 0.718
[8,1504], precision: 0.594, acc: 0.719
[8,1760], precision: 0.595, acc: 0.718
[8,2016], precision: 0.593, acc: 0.718
[8,2272], precision: 0.595, acc: 0.719
[8,2528], precision: 0.594, acc: 0.719
[8,2784], precision: 0.592, acc: 0.719
[9,92], precision: 0.598, acc: 0.

[24,2720], precision: 0.576, acc: 0.720
[25,28], precision: 0.590, acc: 0.721
[25,284], precision: 0.583, acc: 0.720
[25,540], precision: 0.582, acc: 0.720
[25,796], precision: 0.581, acc: 0.720
[25,1052], precision: 0.581, acc: 0.721
[25,1308], precision: 0.581, acc: 0.720
[25,1564], precision: 0.591, acc: 0.721
[25,1820], precision: 0.589, acc: 0.721
[25,2076], precision: 0.581, acc: 0.721
[25,2332], precision: 0.578, acc: 0.720
[25,2588], precision: 0.583, acc: 0.722
[25,2844], precision: 0.583, acc: 0.721
[26,152], precision: 0.587, acc: 0.721
[26,408], precision: 0.586, acc: 0.721
[26,664], precision: 0.585, acc: 0.720
[26,920], precision: 0.579, acc: 0.720
[26,1176], precision: 0.572, acc: 0.719
[26,1432], precision: 0.580, acc: 0.720
[26,1688], precision: 0.583, acc: 0.720
[26,1944], precision: 0.582, acc: 0.720
[26,2200], precision: 0.577, acc: 0.720
[26,2456], precision: 0.579, acc: 0.720
[26,2712], precision: 0.571, acc: 0.719
[27,20], precision: 0.582, acc: 0.720
[27,276], p

[1,124], precision: 0.503, acc: 0.693
[1,380], precision: 0.508, acc: 0.695
[1,636], precision: 0.511, acc: 0.695
[1,892], precision: 0.512, acc: 0.696
[1,1148], precision: 0.514, acc: 0.696
[1,1404], precision: 0.518, acc: 0.698
[1,1660], precision: 0.520, acc: 0.698
[1,1916], precision: 0.524, acc: 0.700
[1,2172], precision: 0.524, acc: 0.699
[1,2428], precision: 0.528, acc: 0.701
[1,2684], precision: 0.528, acc: 0.701
[1,2940], precision: 0.530, acc: 0.701
[2,248], precision: 0.534, acc: 0.702
[2,504], precision: 0.533, acc: 0.702
[2,760], precision: 0.534, acc: 0.703
[2,1016], precision: 0.534, acc: 0.703
[2,1272], precision: 0.536, acc: 0.703
[2,1528], precision: 0.539, acc: 0.704
[2,1784], precision: 0.542, acc: 0.705
[2,2040], precision: 0.543, acc: 0.705
[2,2296], precision: 0.547, acc: 0.706
[2,2552], precision: 0.546, acc: 0.706
[2,2808], precision: 0.545, acc: 0.706
[3,116], precision: 0.550, acc: 0.707
[3,372], precision: 0.549, acc: 0.707
[3,628], precision: 0.548, acc: 0.

[19,820], precision: 0.608, acc: 0.725
[19,1076], precision: 0.600, acc: 0.724
[19,1332], precision: 0.606, acc: 0.725
[19,1588], precision: 0.603, acc: 0.724
[19,1844], precision: 0.606, acc: 0.725
[19,2100], precision: 0.600, acc: 0.724
[19,2356], precision: 0.599, acc: 0.724
[19,2612], precision: 0.591, acc: 0.723
[19,2868], precision: 0.599, acc: 0.724
[20,176], precision: 0.609, acc: 0.725
[20,432], precision: 0.607, acc: 0.725
[20,688], precision: 0.606, acc: 0.725
[20,944], precision: 0.598, acc: 0.724
[20,1200], precision: 0.597, acc: 0.724
[20,1456], precision: 0.604, acc: 0.725
[20,1712], precision: 0.600, acc: 0.724
[20,1968], precision: 0.596, acc: 0.724
[20,2224], precision: 0.597, acc: 0.724
[20,2480], precision: 0.598, acc: 0.724
[20,2736], precision: 0.589, acc: 0.723
[21,44], precision: 0.604, acc: 0.724
[21,300], precision: 0.602, acc: 0.724
[21,556], precision: 0.601, acc: 0.724
[21,812], precision: 0.600, acc: 0.724
[21,1068], precision: 0.593, acc: 0.723
[21,1324],

[37,748], precision: 0.554, acc: 0.712
[37,1004], precision: 0.547, acc: 0.711
[37,1260], precision: 0.549, acc: 0.712
[37,1516], precision: 0.556, acc: 0.713
[37,1772], precision: 0.556, acc: 0.713
[37,2028], precision: 0.549, acc: 0.712
[37,2284], precision: 0.550, acc: 0.712
[37,2540], precision: 0.550, acc: 0.712
[37,2796], precision: 0.552, acc: 0.713
[38,104], precision: 0.556, acc: 0.713
[38,360], precision: 0.558, acc: 0.714
[38,616], precision: 0.554, acc: 0.713
[38,872], precision: 0.552, acc: 0.712
[38,1128], precision: 0.545, acc: 0.711
[38,1384], precision: 0.550, acc: 0.712
[38,1640], precision: 0.564, acc: 0.715
[38,1896], precision: 0.553, acc: 0.712
[38,2152], precision: 0.550, acc: 0.712
[38,2408], precision: 0.546, acc: 0.711
[38,2664], precision: 0.549, acc: 0.712
[38,2920], precision: 0.554, acc: 0.713
[39,228], precision: 0.549, acc: 0.711
[39,484], precision: 0.557, acc: 0.713
[39,740], precision: 0.551, acc: 0.712
[39,996], precision: 0.549, acc: 0.712
[39,1252]

[13,1868], precision: 0.615, acc: 0.724
[13,2124], precision: 0.606, acc: 0.723
[13,2380], precision: 0.608, acc: 0.724
[13,2636], precision: 0.605, acc: 0.724
[13,2892], precision: 0.613, acc: 0.724
[14,200], precision: 0.616, acc: 0.724
[14,456], precision: 0.614, acc: 0.724
[14,712], precision: 0.611, acc: 0.723
[14,968], precision: 0.608, acc: 0.724
[14,1224], precision: 0.606, acc: 0.724
[14,1480], precision: 0.610, acc: 0.723
[14,1736], precision: 0.616, acc: 0.724
[14,1992], precision: 0.611, acc: 0.724
[14,2248], precision: 0.614, acc: 0.725
[14,2504], precision: 0.606, acc: 0.724
[14,2760], precision: 0.608, acc: 0.725
[15,68], precision: 0.617, acc: 0.725
[15,324], precision: 0.612, acc: 0.725
[15,580], precision: 0.615, acc: 0.725
[15,836], precision: 0.610, acc: 0.724
[15,1092], precision: 0.605, acc: 0.724
[15,1348], precision: 0.610, acc: 0.724
[15,1604], precision: 0.613, acc: 0.725
[15,1860], precision: 0.615, acc: 0.725
[15,2116], precision: 0.607, acc: 0.725
[15,2372]

[31,1796], precision: 0.573, acc: 0.717
[31,2052], precision: 0.564, acc: 0.716
[31,2308], precision: 0.563, acc: 0.716
[31,2564], precision: 0.559, acc: 0.714
[31,2820], precision: 0.563, acc: 0.716
[32,128], precision: 0.568, acc: 0.717
[32,384], precision: 0.568, acc: 0.716
[32,640], precision: 0.566, acc: 0.715
[32,896], precision: 0.560, acc: 0.715
[32,1152], precision: 0.560, acc: 0.715
[32,1408], precision: 0.565, acc: 0.716
[32,1664], precision: 0.568, acc: 0.717
[32,1920], precision: 0.564, acc: 0.716
[32,2176], precision: 0.561, acc: 0.715
[32,2432], precision: 0.559, acc: 0.714
[32,2688], precision: 0.557, acc: 0.714
[32,2944], precision: 0.562, acc: 0.715
[33,252], precision: 0.561, acc: 0.715
[33,508], precision: 0.561, acc: 0.715
[33,764], precision: 0.561, acc: 0.715
[33,1020], precision: 0.559, acc: 0.715
[33,1276], precision: 0.563, acc: 0.716
[33,1532], precision: 0.568, acc: 0.717
[33,1788], precision: 0.572, acc: 0.717
[33,2044], precision: 0.560, acc: 0.715
[33,230

In [22]:
print("prec=%.3f, recall=%.3f, F1=%.3f, acc=%.3f"%
      (sta.mean(precision_dg_GIT), sta.mean(recall_dg_GIT), sta.mean(f1score_dg_GIT), sta.mean(accuracy_dg_GIT)))

prec=0.547, recall=0.355, F1=0.430, acc=0.711
