In [1]:
import warnings
warnings.filterwarnings('ignore')

from TGA.utils import Dataset

from tqdm.notebook import tqdm
from TGA.utils import preprocessor
import copy

from time import time
import numpy as np
from itertools import repeat
from collections import Counter
from segtok import tokenizer as tk

from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction import stop_words
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.base import BaseEstimator, TransformerMixin

In [2]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

In [4]:
torch.__version__

'1.7.1'

In [5]:
dataset = Dataset('/home/Documentos/datasets/classification/datasets/reut/')
fold = next(dataset.get_fold_instances(10, with_val=True))
fold._fields, len(fold.X_train)

(('X_train', 'y_train', 'X_test', 'y_test', 'X_val', 'y_val'), 10627)

In [6]:
class Tokenizer(BaseEstimator, TransformerMixin):
    def __init__(self, mindf=2, stopwords='remove', model='list', lan='english', verbose=False):
        super(Tokenizer, self).__init__()
        self.mindf = mindf
        self.le = LabelEncoder()
        self.verbose = verbose
        self.stopwords = stopwords
        self.stopwordsSet = set(stop_words.ENGLISH_STOP_WORDS)
        self.lan = lan
        self.model = model
        self.analyzer = TfidfVectorizer(preprocessor=preprocessor).build_analyzer()
        #self.analyzer = tk.web_tokenizer
    
    def fit(self, X, y):
        self.N = len(X)
        self.le.fit( y )
        self.n_class = len(self.le.classes_)

        self.term_freqs = Counter()
        docs = map(self.analyzer, X)
        for doc_in_terms in tqdm(docs, total=self.N, disable=not self.verbose):
            doc_in_terms = list(map( self._filter_fit_, doc_in_terms ))
            self.term_freqs.update(list(set(doc_in_terms)))
        self.node_mapper      = {'<BLANK>': 0}
        self.term_freqs       = { term:v for (term,v) in self.term_freqs.items() if v >= self.mindf }    
        self.node_mapper      = { term:self._get_idx_(term) for term in self.term_freqs.keys() if self._isrel_(term) }
        self.node_mapper['<UNK>'] = len(self.node_mapper)
        self.node_mapper['<BLANK>'] = 0
        self.vocab_size = len(self.node_mapper)
        
        return self
    def _isrel_(self, term):
        if self.stopwords == 'remove' and term in self.stopwordsSet:
            return False
        # put here your filter_functions
        return True
    def _get_idx_(self, term):
        # put here your idx_set_functions
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            print('is stop', term)
            return self.node_mapper.setdefault('<STPW>', len(self.node_mapper))
        return self.node_mapper.setdefault(term, len(self.node_mapper))
    def _filter_transform_(self, term):
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return '<STPW>'
        if term not in self.node_mapper:
            return '<UNK>'
        return term
    def _filter_fit_(self, term):
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return '<STPW>'
        return term
    def _model_(self, doc):
        if self.model == 'set':
            return set(doc)
        return list(doc)
    def transform(self, X, verbose=None):
        verbose = verbose if verbose is not None else self.verbose
        n = len(X)
        doc_off = [0]
        terms_idx = []
        for i,doc_in_terms in tqdm(enumerate(map(self.analyzer, X)), total=n, disable=not verbose):
            doc_in_terms = filter( self._isrel_, doc_in_terms )
            doc_in_terms = map( self._filter_transform_, doc_in_terms )
            doc_in_terms = self._model_(doc_in_terms)
            doc_in_terms = [ self.node_mapper[tid] for tid in doc_in_terms ]
            if self.model == 'sorted':
                doc_in_terms = sorted(doc_in_terms)
            doc_off.append( len(doc_in_terms) )
            terms_idx.extend( doc_in_terms )
        return np.array( terms_idx ), np.array(doc_off)[:-1].cumsum()

In [7]:
tokenizer = Tokenizer(mindf=1, stopwords='keep', model='set', verbose=True)
tokenizer.fit(fold.X_train, fold.y_train)

HBox(children=(FloatProgress(value=0.0, max=10627.0), HTML(value='')))




Tokenizer(mindf=1, model='set', stopwords='keep', verbose=True)

In [8]:
y_train = tokenizer.le.transform( fold.y_train )
y_val   = tokenizer.le.transform( fold.y_val )
y_test  = tokenizer.le.transform( fold.y_test )

In [9]:
tokenizer.transform(fold.X_val[:2])

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




(array([  327,   294,  3744,  1418,  2039,     5,   783,     6,   691,
        13240,   910,  7850,    17,    83,   819,  5622,   574,    89,
         1916,  3362,  6770,  4594,    97,  3556,  4366,  2213,   349,
          108,   318,  2297,  1972,   828,  2888,   322,   113,  2216,
          116,   118,   120,   325,  3984,   975,  1215,  2301,     2,
          125,   594,  2868,  1807,  1274,  4811,  3347,   135,    16,
         2304, 17509, 10219,  1563,   842,   139,   539,  3361,   143,
          152,   378,  2059, 18130,   154,  5583,  2791,  1470,  1220,
         1998,   165,   170,  5643,  6547,   177, 16547,   178,   859,
          180, 25858,   392,  1572,  5769,    10,  1575,  1750,  1355,
         3165,  1578,  1185,  1132,  1447,   198,  4491,  6664,  1231,
           23,  1359,    24,   207,  5092,  2244,   757,  1583,   871,
         3387,  2248,   643,  5175,  1817,    39,   419,    41,   650,
          226,  2183,  1591,   884,   233,   664,  3358,    21,   239,
      

In [10]:
tokenizer.node_mapper['<BLANK>'], fold.X_val[0]

(0,
 '\r cocoa consumers narrow gap on buffer stock issue\r london march 16 representatives of cocoa consuming\r countries at an international cocoa organization icco council\r meeting here have edged closer to a unified stance on buffer\r stock rules delegates said \r while consumers do not yet have a common position an\r observer said after a consumer meeting they are much more\r fluid and the tone is positive \r european community consumers were split on the question of\r how the cocoa buffer stock should be operated when the icco met\r in january to put the new international cocoa agreement into\r effect delegates said \r at the january meeting france sided with producers on how\r the buffer stock should operate delegates said that meeting\r ended without agreement on new buffer stock rules \r the ec commission met in brussels on friday to see whether\r the 12 ec cocoa consuming nations could narrow their\r differences at this month s meeting \r the commissioners came away from the

In [11]:
def collate_train(param):
    X, y = zip(*param)
    terms_ids, docs_offsets = tokenizer.transform(X, verbose=False)
    return torch.LongTensor(terms_ids), torch.LongTensor(docs_offsets), torch.LongTensor(y)

In [12]:
class Mask(nn.Module):
    def __init__(self, negative_slope=1000, kappa=2.):
        super(Mask, self).__init__()
        self.negative_slope = negative_slope
        self.kappa = kappa
        self.sig = nn.Sigmoid()
    def forward(self, h):
        w = F.leaky_relu( h, negative_slope=self.negative_slope)
        w = self.sig(w-self.kappa)
        return w

In [13]:
help(torch.embedding_bag)

Help on built-in function embedding_bag:

embedding_bag(...)



In [14]:
class SimpleAttentionBag(nn.Module):
    def __init__(self, vocab_size, hiddens, nclass, drop=.5, initrange=.5, negative_slope=99.):
        super(SimpleAttentionBag, self).__init__()
        self.hiddens    = hiddens
        self.dt_emb     = nn.Embedding(vocab_size, hiddens)
        self.tt_emb     = nn.Embedding(vocab_size, hiddens)
        self.tt_dir_map = nn.Linear(hiddens, hiddens)
        self.fc         = nn.Linear(hiddens, nclass)
        self.initrange  = initrange 
        self.negative_slope = negative_slope
        self.drop       = nn.Dropout(drop)
        self.drop_      = drop
        self.sig        = nn.Sigmoid()
        self.init_weights()
    def forward(self, terms_idx, docs_offsets, return_mask=False):
        n = terms_idx.shape[0]
        batch_size = docs_offsets.shape[0]
        
        k         = [ terms_idx[ docs_offsets[i-1]:docs_offsets[i] ] for i in range(1, batch_size) ]
        k.append( terms_idx[ docs_offsets[-1]: ] )
        x_packed  = pad_sequence(k, batch_first=True, padding_value=0)

        bx_packed = x_packed == 0
        doc_sizes = bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
        pad_mask  = bx_packed.logical_not()
        pad_mask  = pad_mask.view(*bx_packed.shape, 1)
        pad_mask  = pad_mask.logical_and(pad_mask.transpose(1, 2))
        
        dt_h     = self.dt_emb( x_packed )
        dt_h     = F.dropout( dt_h, p=self.drop_, training=self.training )
        #dt_h     = self.drop(dt_h)
        
        tt_h     = self.tt_emb( x_packed )
        tt_h     = F.dropout( tt_h, p=self.drop_, training=self.training )
        #tt_h     = self.drop(tt_h)
        dir_tt_h = self.tt_dir_map( tt_h )

        weights = torch.bmm( tt_h, dir_tt_h.transpose( 1, 2 ) )
        weights = F.leaky_relu( weights, negative_slope=self.negative_slope)
        
        weights[pad_mask.logical_not()] = float('-inf') # Set the 3D-pad mask values to -inf (=0 in softmax)
        weights = F.sigmoid(weights)
        #weights = F.softmax(weights, dim=2) # Normalize the neighbors weights
        #weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights) #replace nan to zero
        weights = weights.sum(axis=2) / doc_sizes
        weights[bx_packed] = float('-inf') # Set the 2D-pad mask values to -inf  (=0 in softmax)
        weights = F.softmax(weights, dim=1)
        #weights = F.sigmoid(weights)
        weights = weights.view( *weights.shape, 1 )
        
        docs_h = dt_h * weights
        docs_h = docs_h.sum(axis=1)
        docs_h = F.dropout( docs_h, p=self.drop_, training=self.training )
        #docs_h = self.drop(docs_h)
        return self.fc(docs_h), weights
    
    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_dir_map.weight.data.uniform_(-self.initrange, self.initrange)
        self.fc.weight.data.uniform_(-self.initrange, self.initrange)

In [30]:
nepochs = 1000
max_epochs = 50
drop=0.85
device = torch.device('cuda:0')
batch_size = 32

In [31]:
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f52ec0c53b0>

In [32]:
#sc = SimpleClassifier(tokenizer.vocab_size, 300, tokenizer.n_class, dropout=drop).to( device )
ab = SimpleAttentionBag(tokenizer.vocab_size, 300, tokenizer.n_class, drop=drop).to( device )

optimizer = optim.AdamW( ab.parameters(), lr=6e-3, weight_decay=5e-3)
loss_func_cel = nn.CrossEntropyLoss().to( device )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=.95,
                                                       patience=10, verbose=True)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=.98, verbose=True)

In [33]:


np.array([1])[1:]

array([], dtype=int64)

In [None]:
best = 99999.
counter = 1
eps=0.6
old_loss_train = 1.
dl_val = DataLoader(list(zip(fold.X_val, y_val)), batch_size=batch_size,
                         shuffle=False, collate_fn=collate_train, num_workers=4)
for e in tqdm(range(nepochs), total=nepochs):
    dl_train = DataLoader(list(zip(fold.X_train, y_train)), batch_size=batch_size,
                             shuffle=True, collate_fn=collate_train, num_workers=4)
    loss_train  = 0.
    with tqdm(total=len(y_train)+len(y_val), smoothing=0., desc=f"Epoch {e+1}") as pbar:
        total = 0
        correct  = 0
        ab.train()
        for i, (terms_idx, docs_offsets, y) in enumerate(dl_train):
            terms_idx    = terms_idx.to( device )
            docs_offsets = docs_offsets.to( device )
            y            = y.to( device )
            
            pred_docs,_ = ab( terms_idx, docs_offsets)
            pred_docs = F.softmax(pred_docs)
            loss = loss_func_cel(pred_docs, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            loss_train += loss.item()
            total      += len(y)
            y_pred      = pred_docs.argmax(axis=1)
            correct    += (y_pred == y).sum().item()
            
            toprint  = f"Train loss: {loss_train/(i+1):.5}/{loss.item():.5} "
            toprint += f'ACC: {correct/total:.5}'
            
            print(toprint, end=f"{' '*100}\r")
            
            pbar.update( len(y) )
            del pred_docs, loss
            del terms_idx, docs_offsets, y
            del y_pred
        loss_train = loss_train/(i+1)
        total = 0
        correct  = 0
        ab.eval()
        with torch.no_grad():
            loss_val = 0.
            print()
            for i, (terms_idx, docs_offsets, y) in enumerate(dl_val):
                terms_idx    = terms_idx.to( device )
                docs_offsets = docs_offsets.to( device )
                y            = y.to( device )

                pred_docs,weights = ab( terms_idx, docs_offsets )
                pred_docs   = F.softmax(pred_docs)

                y_pred      = pred_docs.argmax(axis=1)
                correct    += (y_pred == y).sum().item()
                total      += len(y)
                loss2       = loss_func_cel(pred_docs, y)
                loss_val   += loss2

                print(f'Val loss: {loss_val.item()/(i+1):.5} ACC: {correct/total:.5}', end=f"{' '*100}\r")
   
                pbar.update( len(y) )

            del terms_idx, docs_offsets, y
            del y_pred
            
            loss_val   = loss_val/(i+1)
            
            old_loss_train = (1.-eps)*old_loss_train + eps*loss_train
            
            #prop_val   = loss_val.cpu() / (old_loss_train+optimizer.param_groups[0]['lr'])
            prop_val   = loss_val.cpu() / old_loss_train
            ## if loss_train < loss_val, prop_val > 1 --  Overfitting
            ## if loss_train > loss_val, prop_val < 1 -- Underfitting
            
            prop_learn = old_loss_train / loss_train
            ## if loss_train < old_loss_train, prop_learn > 1 -- Aprendendo
            ## if loss_train > old_loss_train, prop_learn < 1 -- Desaprendendo
            
            #prop = (2.*prop_val*prop_learn)/(prop_val+prop_learn)
            prop = np.sqrt(prop_val*prop_learn)
            ab.drop_ *= prop
            ab.drop_ = max(0.05, min(ab.drop_, 0.95))
            
            scheduler.step(loss_val)
            print()
            print(f"Set dropout to {ab.drop_:.4} (prop={prop_val:.4},{prop_learn:.4}={prop:.4})")
            #scheduler.step()

            if best-loss_val > 0.0001 :
                best = loss_val.item()
                counter = 1
                print(f'New Best Val loss: {best:.5}', end=f"{' '*100}\n")
                best_model = copy.deepcopy(ab).to('cpu')
            elif counter > max_epochs:
                print()
                print(f'Best Val loss: {best:.5}', end=f"{' '*100}\n")
                break
            else:
                counter += 1
            del pred_docs, loss2

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 4.0476/4.0614 ACC: 0.48819                                                                                                    
Val loss: 3.9793 ACC: 0.55852                                                                                                    
Set dropout to 0.8428 (prop=1.407,0.6988=0.9915)
New Best Val loss: 3.9793                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 3.9925/3.8501 ACC: 0.54173                                                                                                    
Val loss: 3.9295 ACC: 0.61333                                                                                                    
Set dropout to 0.8361 (prop=1.114,0.8834=0.9921)
New Best Val loss: 3.9295                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 3.9526/3.8507 ACC: 0.58182                                                                                                    
Val loss: 3.9101 ACC: 0.62296                                                                                                    
Set dropout to 0.8316 (prop=1.034,0.9569=0.9946)
New Best Val loss: 3.9101                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 3.9304/3.9127 ACC: 0.59904                                                                                                    
Val loss: 3.9034 ACC: 0.62519                                                                                                    
Set dropout to 0.8288 (prop=1.008,0.9849=0.9966)
New Best Val loss: 3.9034                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 3.9193/3.9386 ACC: 0.60826                                                                                                    
Val loss: 3.8995 ACC: 0.62815                                                                                                    
Set dropout to 0.8267 (prop=0.9999,0.9951=0.9975)
New Best Val loss: 3.8995                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 3.9139/3.6049 ACC: 0.61043                                                                                                    
Val loss: 3.8986 ACC: 0.62519                                                                                                    
Set dropout to 0.825 (prop=0.9975,0.9986=0.998)
New Best Val loss: 3.8986                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 3.9062/3.872 ACC: 0.6189                                                                                                      
Val loss: 3.8982 ACC: 0.6237                                                                                                     
Set dropout to 0.8242 (prop=0.9977,1.0=0.999)
New Best Val loss: 3.8982                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 3.9015/3.5215 ACC: 0.62153                                                                                                    
Val loss: 3.8968 ACC: 0.62667                                                                                                    
Set dropout to 0.8237 (prop=0.9982,1.001=0.9994)
New Best Val loss: 3.8968                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 3.901/3.852 ACC: 0.62125                                                                                                      
Val loss: 3.8954 ACC: 0.62593                                                                                                    
Set dropout to 0.8231 (prop=0.9983,1.0=0.9993)
New Best Val loss: 3.8954                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.897/3.852 ACC: 0.62595                                                                                                      
Val loss: 3.8951 ACC: 0.62593                                                                                                    
Set dropout to 0.8229 (prop=0.999,1.001=0.9998)
New Best Val loss: 3.8951                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 11', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.895/4.1845 ACC: 0.62821                                                                                                     
Val loss: 3.8909 ACC: 0.63037                                                                                                    
Set dropout to 0.8225 (prop=0.9985,1.0=0.9995)
New Best Val loss: 3.8909                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 12', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8887/3.8501 ACC: 0.63367                                                                                                    
Val loss: 3.8882 ACC: 0.63481                                                                                                    
Set dropout to 0.8224 (prop=0.9991,1.001=0.9999)
New Best Val loss: 3.8882                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 13', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8844/3.8496 ACC: 0.63819                                                                                                    
Val loss: 3.8861 ACC: 0.63704                                                                                                    
Set dropout to 0.8226 (prop=0.9997,1.001=1.0)
New Best Val loss: 3.8861                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 14', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8856/3.8521 ACC: 0.63677                                                                                                    
Val loss: 3.8847 ACC: 0.63778                                                                                                    
Set dropout to 0.8225 (prop=0.9996,1.0=0.9999)
New Best Val loss: 3.8847                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 15', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8782/3.852 ACC: 0.64402                                                                                                     
Val loss: 3.8834 ACC: 0.6363                                                                                                     
Set dropout to 0.823 (prop=1.0,1.001=1.001)
New Best Val loss: 3.8834                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 16', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8758/4.5176 ACC: 0.64872                                                                                                    
Val loss: 3.8756 ACC: 0.64593                                                                                                    
Set dropout to 0.823 (prop=0.9994,1.001=1.0)
New Best Val loss: 3.8756                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 17', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8673/3.5187 ACC: 0.65672                                                                                                    
Val loss: 3.8735 ACC: 0.64741                                                                                                    
Set dropout to 0.8237 (prop=1.001,1.001=1.001)
New Best Val loss: 3.8735                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 18', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8637/3.5848 ACC: 0.65983                                                                                                    
Val loss: 3.8711 ACC: 0.64963                                                                                                    
Set dropout to 0.8245 (prop=1.001,1.001=1.001)
New Best Val loss: 3.8711                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 19', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8608/3.7652 ACC: 0.6619                                                                                                     
Val loss: 3.8687 ACC: 0.65407                                                                                                    
Set dropout to 0.8253 (prop=1.001,1.001=1.001)
New Best Val loss: 3.8687                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 20', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8574/3.5242 ACC: 0.66359                                                                                                    
Val loss: 3.8671 ACC: 0.65407                                                                                                    
Set dropout to 0.8264 (prop=1.002,1.001=1.001)
New Best Val loss: 3.8671                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 21', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8543/3.851 ACC: 0.66905                                                                                                     
Val loss: 3.8644 ACC: 0.65778                                                                                                    
Set dropout to 0.8274 (prop=1.002,1.001=1.001)
New Best Val loss: 3.8644                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 22', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8542/3.5187 ACC: 0.66736                                                                                                    
Val loss: 3.8627 ACC: 0.66                                                                                                       
Set dropout to 0.8283 (prop=1.002,1.0=1.001)
New Best Val loss: 3.8627                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 23', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8522/4.1871 ACC: 0.67281                                                                                                    
Val loss: 3.861 ACC: 0.66296                                                                                                     
Set dropout to 0.8293 (prop=1.002,1.0=1.001)
New Best Val loss: 3.861                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 24', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8499/3.5191 ACC: 0.67263                                                                                                    
Val loss: 3.8606 ACC: 0.66222                                                                                                    
Set dropout to 0.8304 (prop=1.002,1.0=1.001)
New Best Val loss: 3.8606                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 25', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8494/3.8518 ACC: 0.6731                                                                                                     
Val loss: 3.8588 ACC: 0.66667                                                                                                    
Set dropout to 0.8315 (prop=1.002,1.0=1.001)
New Best Val loss: 3.8588                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 26', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8513/4.1848 ACC: 0.67093                                                                                                    
Val loss: 3.8581 ACC: 0.66519                                                                                                    
Set dropout to 0.8322 (prop=1.002,0.9999=1.001)
New Best Val loss: 3.8581                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 27', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8481/3.5187 ACC: 0.673                                                                                                      
Val loss: 3.8559 ACC: 0.66889                                                                                                    
Set dropout to 0.833 (prop=1.002,1.0=1.001)
New Best Val loss: 3.8559                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 28', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8484/3.5192 ACC: 0.67347                                                                                                    
Val loss: 3.856 ACC: 0.66667                                                                                                     
Set dropout to 0.8339 (prop=1.002,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 29', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8461/3.9957 ACC: 0.67658                                                                                                    
Val loss: 3.8551 ACC: 0.66519                                                                                                    
Set dropout to 0.8348 (prop=1.002,1.0=1.001)
New Best Val loss: 3.8551                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 30', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8451/3.8508 ACC: 0.67733                                                                                                    
Val loss: 3.8544 ACC: 0.66815                                                                                                    
Set dropout to 0.8359 (prop=1.002,1.0=1.001)
New Best Val loss: 3.8544                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 31', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8436/3.8529 ACC: 0.6778                                                                                                     
Val loss: 3.855 ACC: 0.66444                                                                                                     
Set dropout to 0.8371 (prop=1.003,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 32', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.843/3.8521 ACC: 0.67978                                                                                                     
Val loss: 3.8543 ACC: 0.66519                                                                                                    
Set dropout to 0.8383 (prop=1.003,1.0=1.001)
New Best Val loss: 3.8543                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 33', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8443/3.8259 ACC: 0.67686                                                                                                    
Val loss: 3.8547 ACC: 0.66889                                                                                                    
Set dropout to 0.8395 (prop=1.003,0.9999=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 34', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8439/3.8505 ACC: 0.6779                                                                                                     
Val loss: 3.8543 ACC: 0.66667                                                                                                    
Set dropout to 0.8406 (prop=1.003,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 35', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8428/3.8501 ACC: 0.67912                                                                                                    
Val loss: 3.8538 ACC: 0.66667                                                                                                    
Set dropout to 0.8418 (prop=1.003,1.0=1.001)
New Best Val loss: 3.8538                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 36', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8431/3.8508 ACC: 0.67997                                                                                                    
Val loss: 3.8546 ACC: 0.66519                                                                                                    
Set dropout to 0.8431 (prop=1.003,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 37', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8415/3.5187 ACC: 0.67931                                                                                                    
Val loss: 3.8537 ACC: 0.66889                                                                                                    
Set dropout to 0.8444 (prop=1.003,1.0=1.002)



HBox(children=(FloatProgress(value=0.0, description='Epoch 38', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8397/4.1846 ACC: 0.6842                                                                                                     
Val loss: 3.8536 ACC: 0.66741                                                                                                    
Set dropout to 0.8459 (prop=1.003,1.0=1.002)
New Best Val loss: 3.8536                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 39', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8371/3.5187 ACC: 0.68401                                                                                                    
Val loss: 3.8518 ACC: 0.66741                                                                                                    
Set dropout to 0.8476 (prop=1.003,1.0=1.002)
New Best Val loss: 3.8518                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 40', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8451/4.1849 ACC: 0.67714                                                                                                    
Val loss: 3.8525 ACC: 0.66889                                                                                                    
Set dropout to 0.8484 (prop=1.003,0.9993=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 41', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8447/4.1847 ACC: 0.67855                                                                                                    
Val loss: 3.8522 ACC: 0.66741                                                                                                    
Set dropout to 0.8492 (prop=1.002,0.9998=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 42', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8431/4.1828 ACC: 0.67987                                                                                                    
Val loss: 3.8507 ACC: 0.66889                                                                                                    
Set dropout to 0.85 (prop=1.002,1.0=1.001)
New Best Val loss: 3.8507                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 43', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8425/3.8307 ACC: 0.68025                                                                                                    
Val loss: 3.8512 ACC: 0.66815                                                                                                    
Set dropout to 0.851 (prop=1.002,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 44', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8421/3.5188 ACC: 0.67903                                                                                                    
Val loss: 3.8516 ACC: 0.66815                                                                                                    
Set dropout to 0.852 (prop=1.002,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 45', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8419/4.0972 ACC: 0.68175                                                                                                    
Val loss: 3.8528 ACC: 0.66815                                                                                                    
Set dropout to 0.8532 (prop=1.003,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 46', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8446/4.1854 ACC: 0.67695                                                                                                    
Val loss: 3.8518 ACC: 0.66815                                                                                                    
Set dropout to 0.854 (prop=1.002,0.9997=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 47', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8412/3.5187 ACC: 0.6795                                                                                                     
Val loss: 3.8515 ACC: 0.66667                                                                                                    
Set dropout to 0.8552 (prop=1.002,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 48', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8416/4.0897 ACC: 0.68091                                                                                                    
Val loss: 3.8511 ACC: 0.66889                                                                                                    
Set dropout to 0.8563 (prop=1.002,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 49', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8481/4.1848 ACC: 0.6747                                                                                                     
Val loss: 3.851 ACC: 0.66889                                                                                                     
Set dropout to 0.8566 (prop=1.001,0.9994=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 50', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8425/3.5191 ACC: 0.67855                                                                                                    
Val loss: 3.8513 ACC: 0.66963                                                                                                    
Set dropout to 0.8576 (prop=1.002,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 51', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8444/4.183 ACC: 0.67855                                                                                                     
Val loss: 3.8519 ACC: 0.66667                                                                                                    
Set dropout to 0.8584 (prop=1.002,0.9999=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 52', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8434/3.852 ACC: 0.67865                                                                                                     
Val loss: 3.85 ACC: 0.67037                                                                                                      
Set dropout to 0.8591 (prop=1.002,1.0=1.001)
New Best Val loss: 3.85                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 53', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8442/4.1853 ACC: 0.67855                                                                                                    
Val loss: 3.8507 ACC: 0.66815                                                                                                    
Set dropout to 0.8599 (prop=1.002,0.9999=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 54', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8423/3.5187 ACC: 0.67771                                                                                                    
Val loss: 3.8511 ACC: 0.66815                                                                                                    
Set dropout to 0.8608 (prop=1.002,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 55', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8501/4.1836 ACC: 0.6714                                                                                                     
Val loss: 3.8507 ACC: 0.66963                                                                                                    
Set dropout to 0.8609 (prop=1.001,0.9993=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 56', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8476/4.1851 ACC: 0.67479                                                                                                    
Val loss: 3.851 ACC: 0.66815                                                                                                     
Set dropout to 0.8613 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 57', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8459/3.5187 ACC: 0.67441                                                                                                    
Val loss: 3.8503 ACC: 0.66815                                                                                                    
Set dropout to 0.8618 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 58', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8436/4.1427 ACC: 0.6794                                                                                                     
Val loss: 3.85 ACC: 0.67111                                                                                                      
Set dropout to 0.8625 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 59', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8441/3.5187 ACC: 0.67564                                                                                                    
Val loss: 3.8496 ACC: 0.67037                                                                                                    
Set dropout to 0.8631 (prop=1.001,1.0=1.001)
New Best Val loss: 3.8496                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 60', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8435/4.1826 ACC: 0.67884                                                                                                    
Val loss: 3.8501 ACC: 0.66963                                                                                                    
Set dropout to 0.8638 (prop=1.002,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 61', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8443/3.8522 ACC: 0.6779                                                                                                     
Val loss: 3.8498 ACC: 0.66963                                                                                                    
Set dropout to 0.8644 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 62', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8503/4.1853 ACC: 0.67234                                                                                                    
Val loss: 3.8499 ACC: 0.66963                                                                                                    
Set dropout to 0.8644 (prop=1.001,0.9994=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 63', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8446/3.5187 ACC: 0.67554                                                                                                    
Val loss: 3.8492 ACC: 0.67185                                                                                                    
Set dropout to 0.8649 (prop=1.001,1.0=1.001)
New Best Val loss: 3.8492                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 64', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8472/3.5187 ACC: 0.673                                                                                                      
Val loss: 3.8484 ACC: 0.67333                                                                                                    
Set dropout to 0.8651 (prop=1.0,0.9999=1.0)
New Best Val loss: 3.8484                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 65', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8434/3.8512 ACC: 0.67837                                                                                                    
Val loss: 3.8494 ACC: 0.67185                                                                                                    
Set dropout to 0.8657 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 66', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8473/4.1855 ACC: 0.67611                                                                                                    
Val loss: 3.8485 ACC: 0.67185                                                                                                    
Set dropout to 0.8659 (prop=1.001,0.9997=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 67', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8455/3.5187 ACC: 0.67498                                                                                                    
Val loss: 3.8489 ACC: 0.67111                                                                                                    
Set dropout to 0.8662 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 68', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8448/3.8521 ACC: 0.67705                                                                                                    
Val loss: 3.849 ACC: 0.67185                                                                                                     
Set dropout to 0.8667 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 69', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8446/3.87 ACC: 0.67724                                                                                                      
Val loss: 3.8472 ACC: 0.67259                                                                                                    
Set dropout to 0.867 (prop=1.001,1.0=1.0)
New Best Val loss: 3.8472                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 70', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8468/4.1673 ACC: 0.67677                                                                                                    
Val loss: 3.8467 ACC: 0.67259                                                                                                    
Set dropout to 0.867 (prop=1.0,0.9998=1.0)
New Best Val loss: 3.8467                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 71', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8459/4.1836 ACC: 0.67536                                                                                                    
Val loss: 3.8461 ACC: 0.67259                                                                                                    
Set dropout to 0.867 (prop=1.0,1.0=1.0)
New Best Val loss: 3.8461                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 72', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8414/4.1929 ACC: 0.6811                                                                                                     
Val loss: 3.8469 ACC: 0.67185                                                                                                    
Set dropout to 0.8677 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 73', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8443/3.8506 ACC: 0.67695                                                                                                    
Val loss: 3.846 ACC: 0.67333                                                                                                     
Set dropout to 0.8678 (prop=1.001,0.9999=1.0)
New Best Val loss: 3.846                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 74', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8439/3.8521 ACC: 0.6778                                                                                                     
Val loss: 3.847 ACC: 0.67037                                                                                                     
Set dropout to 0.8682 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 75', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8426/3.5187 ACC: 0.6778                                                                                                     
Val loss: 3.8476 ACC: 0.66963                                                                                                    
Set dropout to 0.8687 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 76', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8454/3.5187 ACC: 0.67451                                                                                                    
Val loss: 3.8468 ACC: 0.67259                                                                                                    
Set dropout to 0.8689 (prop=1.001,0.9998=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 77', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8468/3.8516 ACC: 0.6747                                                                                                     
Val loss: 3.8475 ACC: 0.67185                                                                                                    
Set dropout to 0.869 (prop=1.0,0.9998=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 78', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8439/3.8514 ACC: 0.67714                                                                                                    
Val loss: 3.8464 ACC: 0.67481                                                                                                    
Set dropout to 0.8693 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 79', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8433/3.8516 ACC: 0.67799                                                                                                    
Val loss: 3.8462 ACC: 0.67407                                                                                                    
Set dropout to 0.8696 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 80', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.842/3.5189 ACC: 0.67903                                                                                                     
Val loss: 3.847 ACC: 0.67111                                                                                                     
Set dropout to 0.8701 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 81', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8475/3.851 ACC: 0.67423                                                                                                     
Val loss: 3.8468 ACC: 0.67111                                                                                                    
Set dropout to 0.8701 (prop=1.0,0.9995=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 82', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8438/3.8524 ACC: 0.67761                                                                                                    
Epoch    82: reducing learning rate of group 0 to 5.7000e-03.                                                                    

Set dropout to 0.8703 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 83', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8449/3.5241 ACC: 0.6762                                                                                                     
Val loss: 3.8464 ACC: 0.67407                                                                                                    
Set dropout to 0.8705 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 84', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.848/3.5187 ACC: 0.6715                                                                                                      
Val loss: 3.8473 ACC: 0.67111                                                                                                    
Set dropout to 0.8704 (prop=1.0,0.9997=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 85', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.846/3.5187 ACC: 0.67611                                                                                                     
Val loss: 3.8465 ACC: 0.67259                                                                                                    
Set dropout to 0.8705 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 86', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8446/3.8521 ACC: 0.67648                                                                                                    
Val loss: 3.8465 ACC: 0.67259                                                                                                    
Set dropout to 0.8707 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 87', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8442/3.85 ACC: 0.67695                                                                                                      
Val loss: 3.8474 ACC: 0.67111                                                                                                    
Set dropout to 0.871 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 88', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8444/3.5187 ACC: 0.67573                                                                                                    
Val loss: 3.8473 ACC: 0.67556                                                                                                    
Set dropout to 0.8713 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 89', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8448/4.1833 ACC: 0.67761                                                                                                    
Val loss: 3.8465 ACC: 0.67333                                                                                                    
Set dropout to 0.8715 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 90', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8434/3.9056 ACC: 0.67912                                                                                                    
Val loss: 3.8455 ACC: 0.67333                                                                                                    
Set dropout to 0.8718 (prop=1.0,1.0=1.0)
New Best Val loss: 3.8455                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 91', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8442/4.5174 ACC: 0.67884                                                                                                    
Val loss: 3.8457 ACC: 0.67407                                                                                                    
Set dropout to 0.872 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 92', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.843/3.5191 ACC: 0.6778                                                                                                      
Val loss: 3.846 ACC: 0.67037                                                                                                     
Set dropout to 0.8723 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 93', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8471/4.5176 ACC: 0.67677                                                                                                    
Val loss: 3.8455 ACC: 0.67259                                                                                                    
Set dropout to 0.8721 (prop=1.0,0.9996=0.9998)



HBox(children=(FloatProgress(value=0.0, description='Epoch 94', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8471/4.5181 ACC: 0.67611                                                                                                    
Val loss: 3.8457 ACC: 0.67259                                                                                                    
Set dropout to 0.872 (prop=0.9998,0.9998=0.9998)



HBox(children=(FloatProgress(value=0.0, description='Epoch 95', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8451/4.1854 ACC: 0.67771                                                                                                    
Val loss: 3.8461 ACC: 0.67481                                                                                                    
Set dropout to 0.8721 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 96', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.844/3.8476 ACC: 0.67677                                                                                                     
Val loss: 3.8453 ACC: 0.67407                                                                                                    
Set dropout to 0.8722 (prop=1.0,1.0=1.0)
New Best Val loss: 3.8453                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 97', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8416/3.5187 ACC: 0.6794                                                                                                     
Val loss: 3.8462 ACC: 0.67333                                                                                                    
Set dropout to 0.8727 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 98', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8396/3.5187 ACC: 0.68157                                                                                                    
Val loss: 3.8451 ACC: 0.67333                                                                                                    
Set dropout to 0.8734 (prop=1.001,1.0=1.001)
New Best Val loss: 3.8451                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 99', max=11977.0, style=ProgressStyle(description_w…

Train loss: 3.8421/3.85 ACC: 0.68015                                                                                                      
Val loss: 3.8441 ACC: 0.67556                                                                                                    
Set dropout to 0.8736 (prop=1.001,0.9999=1.0)
New Best Val loss: 3.8441                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 100', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8437/3.8516 ACC: 0.67912                                                                                                    
Val loss: 3.8443 ACC: 0.67185                                                                                                    
Set dropout to 0.8737 (prop=1.0,0.9998=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 101', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8473/4.1792 ACC: 0.67385                                                                                                    
Val loss: 3.8428 ACC: 0.67481                                                                                                    
Set dropout to 0.8732 (prop=0.9993,0.9995=0.9994)
New Best Val loss: 3.8428                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 102', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8437/3.7796 ACC: 0.6778                                                                                                     
Val loss: 3.843 ACC: 0.67481                                                                                                     
Set dropout to 0.8731 (prop=0.9996,1.0=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 103', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.841/4.4665 ACC: 0.68204                                                                                                     
Val loss: 3.8428 ACC: 0.67852                                                                                                    
Set dropout to 0.8733 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 104', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8422/3.5187 ACC: 0.67827                                                                                                    
Val loss: 3.8432 ACC: 0.67704                                                                                                    
Set dropout to 0.8734 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 105', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8438/3.9317 ACC: 0.6778                                                                                                     
Val loss: 3.8438 ACC: 0.67259                                                                                                    
Set dropout to 0.8734 (prop=1.0,0.9998=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 106', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8446/3.8522 ACC: 0.6779                                                                                                     
Val loss: 3.8439 ACC: 0.67333                                                                                                    
Set dropout to 0.8733 (prop=1.0,0.9999=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 107', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8418/3.861 ACC: 0.68072                                                                                                     
Val loss: 3.8435 ACC: 0.67481                                                                                                    
Set dropout to 0.8735 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 108', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8429/3.8521 ACC: 0.67855                                                                                                    
Val loss: 3.8432 ACC: 0.67407                                                                                                    
Set dropout to 0.8736 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 109', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8427/4.1853 ACC: 0.68006                                                                                                    
Val loss: 3.8421 ACC: 0.6763                                                                                                     
Set dropout to 0.8735 (prop=0.9998,1.0=0.9999)
New Best Val loss: 3.8421                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 110', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8397/3.5187 ACC: 0.68081                                                                                                    
Val loss: 3.8422 ACC: 0.67556                                                                                                    
Set dropout to 0.8738 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 111', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8398/3.5187 ACC: 0.68185                                                                                                    
Val loss: 3.8424 ACC: 0.67407                                                                                                    
Set dropout to 0.8741 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 112', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8414/3.5192 ACC: 0.67921                                                                                                    
Val loss: 3.8425 ACC: 0.67481                                                                                                    
Set dropout to 0.8742 (prop=1.0,0.9999=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 113', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8392/3.8521 ACC: 0.68354                                                                                                    
Val loss: 3.8415 ACC: 0.67778                                                                                                    
Set dropout to 0.8745 (prop=1.0,1.0=1.0)
New Best Val loss: 3.8415                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 114', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8426/3.8521 ACC: 0.67874                                                                                                    
Val loss: 3.8407 ACC: 0.67926                                                                                                    
Set dropout to 0.8743 (prop=0.9998,0.9997=0.9998)
New Best Val loss: 3.8407                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 115', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8445/4.0023 ACC: 0.6779                                                                                                     
Val loss: 3.8417 ACC: 0.67407                                                                                                    
Set dropout to 0.874 (prop=0.9996,0.9997=0.9996)



HBox(children=(FloatProgress(value=0.0, description='Epoch 116', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8431/3.8521 ACC: 0.6778                                                                                                     
Val loss: 3.8414 ACC: 0.67556                                                                                                    
Set dropout to 0.8738 (prop=0.9995,1.0=0.9998)



HBox(children=(FloatProgress(value=0.0, description='Epoch 117', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8389/3.852 ACC: 0.68288                                                                                                     
Val loss: 3.8419 ACC: 0.67407                                                                                                    
Set dropout to 0.8741 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 118', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8391/3.6201 ACC: 0.68213                                                                                                    
Val loss: 3.8416 ACC: 0.67407                                                                                                    
Set dropout to 0.8744 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 119', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8433/4.1853 ACC: 0.67921                                                                                                    
Val loss: 3.841 ACC: 0.67333                                                                                                     
Set dropout to 0.8741 (prop=0.9998,0.9996=0.9997)



HBox(children=(FloatProgress(value=0.0, description='Epoch 120', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8407/3.7958 ACC: 0.6811                                                                                                     
Val loss: 3.8403 ACC: 0.67778                                                                                                    
Set dropout to 0.8741 (prop=0.9998,1.0=0.9999)
New Best Val loss: 3.8403                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 121', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8391/4.1696 ACC: 0.68317                                                                                                    
Val loss: 3.8398 ACC: 0.67704                                                                                                    
Set dropout to 0.8741 (prop=0.9999,1.0=1.0)
New Best Val loss: 3.8398                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 122', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8396/3.852 ACC: 0.68194                                                                                                     
Val loss: 3.8383 ACC: 0.68                                                                                                       
Set dropout to 0.874 (prop=0.9996,1.0=0.9998)
New Best Val loss: 3.8383                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 123', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.836/3.6385 ACC: 0.6858                                                                                                      
Val loss: 3.8389 ACC: 0.67778                                                                                                    
Set dropout to 0.8743 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 124', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8446/3.8553 ACC: 0.67695                                                                                                    
Val loss: 3.839 ACC: 0.67926                                                                                                     
Set dropout to 0.8737 (prop=0.9993,0.9993=0.9993)



HBox(children=(FloatProgress(value=0.0, description='Epoch 125', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8412/3.9661 ACC: 0.67987                                                                                                    
Val loss: 3.8393 ACC: 0.67926                                                                                                    
Set dropout to 0.8735 (prop=0.9994,1.0=0.9998)



HBox(children=(FloatProgress(value=0.0, description='Epoch 126', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8417/3.8521 ACC: 0.67959                                                                                                    
Val loss: 3.8409 ACC: 0.6763                                                                                                     
Set dropout to 0.8734 (prop=0.9998,1.0=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 127', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8359/3.8521 ACC: 0.68533                                                                                                    
Val loss: 3.8402 ACC: 0.67778                                                                                                    
Set dropout to 0.8739 (prop=1.001,1.001=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 128', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8366/3.5191 ACC: 0.68364                                                                                                    
Val loss: 3.8401 ACC: 0.67481                                                                                                    
Set dropout to 0.8743 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 129', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8373/3.8521 ACC: 0.68477                                                                                                    
Val loss: 3.8402 ACC: 0.67556                                                                                                    
Set dropout to 0.8746 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 130', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8393/3.8508 ACC: 0.68326                                                                                                    
Val loss: 3.8402 ACC: 0.6763                                                                                                     
Set dropout to 0.8747 (prop=1.0,0.9998=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 131', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8369/3.5187 ACC: 0.68439                                                                                                    
Val loss: 3.8398 ACC: 0.67556                                                                                                    
Set dropout to 0.875 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 132', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8418/3.8515 ACC: 0.6811                                                                                                     
Val loss: 3.8401 ACC: 0.67556                                                                                                    
Set dropout to 0.8748 (prop=1.0,0.9996=0.9998)



HBox(children=(FloatProgress(value=0.0, description='Epoch 133', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8389/3.5189 ACC: 0.68279                                                                                                    
Epoch   133: reducing learning rate of group 0 to 5.4150e-03.                                                                    

Set dropout to 0.8749 (prop=0.9999,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 134', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8373/3.8508 ACC: 0.68486                                                                                                    
Val loss: 3.84 ACC: 0.67704                                                                                                      
Set dropout to 0.8752 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 135', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8386/3.5187 ACC: 0.6826                                                                                                     
Val loss: 3.8393 ACC: 0.67926                                                                                                    
Set dropout to 0.8753 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 136', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8379/3.8521 ACC: 0.68335                                                                                                    
Val loss: 3.8397 ACC: 0.67926                                                                                                    
Set dropout to 0.8755 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 137', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8411/3.8521 ACC: 0.68053                                                                                                    
Val loss: 3.8392 ACC: 0.68                                                                                                       
Set dropout to 0.8752 (prop=0.9998,0.9997=0.9998)



HBox(children=(FloatProgress(value=0.0, description='Epoch 138', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8379/3.8517 ACC: 0.68467                                                                                                    
Val loss: 3.8392 ACC: 0.67852                                                                                                    
Set dropout to 0.8754 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 139', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8361/3.8516 ACC: 0.68627                                                                                                    
Val loss: 3.8393 ACC: 0.67704                                                                                                    
Set dropout to 0.8758 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 140', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8423/3.8501 ACC: 0.67799                                                                                                    
Val loss: 3.8394 ACC: 0.67926                                                                                                    
Set dropout to 0.8754 (prop=0.9998,0.9995=0.9996)



HBox(children=(FloatProgress(value=0.0, description='Epoch 141', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8373/3.5187 ACC: 0.6827                                                                                                     
Val loss: 3.8391 ACC: 0.67852                                                                                                    
Set dropout to 0.8756 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 142', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8364/3.5187 ACC: 0.68373                                                                                                    
Val loss: 3.8395 ACC: 0.67852                                                                                                    
Set dropout to 0.876 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 143', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8395/3.8522 ACC: 0.68222                                                                                                    
Val loss: 3.8391 ACC: 0.68                                                                                                       
Set dropout to 0.876 (prop=1.0,0.9998=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 144', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8375/3.5187 ACC: 0.68335                                                                                                    
Epoch   144: reducing learning rate of group 0 to 5.1442e-03.                                                                    

Set dropout to 0.8761 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 145', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.839/3.5194 ACC: 0.68213                                                                                                     
Val loss: 3.8391 ACC: 0.68074                                                                                                                                                                                                        
Set dropout to 0.8762 (prop=1.0,0.9999=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 146', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8375/4.1854 ACC: 0.68524                                                                                                    
Val loss: 3.8384 ACC: 0.68148                                                                                                    
Set dropout to 0.8763 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 147', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.84/4.1854 ACC: 0.68241                                                                                                      
Val loss: 3.8367 ACC: 0.6837                                                                                                     
Set dropout to 0.8759 (prop=0.9994,0.9998=0.9996)
New Best Val loss: 3.8367                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 148', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8367/3.8518 ACC: 0.68401                                                                                                    
Val loss: 3.8372 ACC: 0.68148                                                                                                    
Set dropout to 0.8759 (prop=0.9999,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 149', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8384/3.8519 ACC: 0.68429                                                                                                    
Val loss: 3.8361 ACC: 0.68296                                                                                                    
Set dropout to 0.8757 (prop=0.9995,0.9999=0.9997)
New Best Val loss: 3.8361                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 150', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8366/3.5187 ACC: 0.68392                                                                                                    
Val loss: 3.8376 ACC: 0.68074                                                                                                    
Set dropout to 0.8758 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 151', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.835/3.5187 ACC: 0.6858                                                                                                      
Val loss: 3.8371 ACC: 0.68                                                                                                       
Set dropout to 0.876 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 152', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8378/4.1846 ACC: 0.68458                                                                                                    
Val loss: 3.8363 ACC: 0.68296                                                                                                    
Set dropout to 0.8759 (prop=0.9998,0.9998=0.9998)



HBox(children=(FloatProgress(value=0.0, description='Epoch 153', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8347/3.8652 ACC: 0.68825                                                                                                    
Val loss: 3.8344 ACC: 0.68593                                                                                                    
Set dropout to 0.8758 (prop=0.9997,1.0=1.0)
New Best Val loss: 3.8344                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 154', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8337/3.5187 ACC: 0.68806                                                                                                    
Val loss: 3.8363 ACC: 0.68296                                                                                                    
Set dropout to 0.8761 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 155', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8361/3.5233 ACC: 0.68505                                                                                                    
Val loss: 3.8373 ACC: 0.68                                                                                                       
Set dropout to 0.8762 (prop=1.0,0.9998=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 156', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8364/3.8503 ACC: 0.68542                                                                                                    
Val loss: 3.8369 ACC: 0.67852                                                                                                    
Set dropout to 0.8763 (prop=1.0,0.9999=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 157', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8406/3.8518 ACC: 0.68119                                                                                                    
Val loss: 3.8377 ACC: 0.67852                                                                                                    
Set dropout to 0.876 (prop=0.9997,0.9995=0.9996)



HBox(children=(FloatProgress(value=0.0, description='Epoch 158', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8343/3.8503 ACC: 0.68806                                                                                                    
Val loss: 3.8384 ACC: 0.67704                                                                                                    
Set dropout to 0.8764 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 159', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8325/3.8524 ACC: 0.68947                                                                                                    
Val loss: 3.8377 ACC: 0.67778                                                                                                    
Set dropout to 0.877 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 160', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.835/3.8521 ACC: 0.68731                                                                                                     
Val loss: 3.837 ACC: 0.68                                                                                                        
Set dropout to 0.8772 (prop=1.001,0.9999=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 161', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.838/3.8762 ACC: 0.68288                                                                                                     
Val loss: 3.8381 ACC: 0.67852                                                                                                    
Set dropout to 0.8773 (prop=1.0,0.9996=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 162', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.837/3.8467 ACC: 0.68561                                                                                                     
Val loss: 3.8382 ACC: 0.6763                                                                                                     
Set dropout to 0.8774 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 163', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8357/3.852 ACC: 0.68599                                                                                                     
Val loss: 3.8383 ACC: 0.67852                                                                                                    
Set dropout to 0.8777 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 164', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.837/4.1727 ACC: 0.68618                                                                                                     
Epoch   164: reducing learning rate of group 0 to 4.8870e-03.                                                                    

Set dropout to 0.8777 (prop=1.0,0.9999=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 165', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8367/3.852 ACC: 0.68589                                                                                                     
Val loss: 3.8376 ACC: 0.68074                                                                                                    
Set dropout to 0.8778 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 166', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8363/3.8521 ACC: 0.6842                                                                                                     
Val loss: 3.8373 ACC: 0.68                                                                                                       
Set dropout to 0.8779 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 167', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8379/4.5181 ACC: 0.68561                                                                                                    
Val loss: 3.8378 ACC: 0.67852                                                                                                    
Set dropout to 0.8779 (prop=1.0,0.9998=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 168', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8387/3.7586 ACC: 0.68345                                                                                                    
Val loss: 3.8383 ACC: 0.67704                                                                                                    
Set dropout to 0.8779 (prop=1.0,0.9999=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 169', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8365/3.8518 ACC: 0.68561                                                                                                    
Val loss: 3.838 ACC: 0.67778                                                                                                     
Set dropout to 0.878 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 170', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8368/3.5188 ACC: 0.68307                                                                                                    
Val loss: 3.8373 ACC: 0.68                                                                                                       
Set dropout to 0.8781 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 171', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8363/3.5187 ACC: 0.6858                                                                                                     
Val loss: 3.8382 ACC: 0.67778                                                                                                    
Set dropout to 0.8783 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 172', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8352/3.5189 ACC: 0.6858                                                                                                     
Val loss: 3.838 ACC: 0.67926                                                                                                     
Set dropout to 0.8786 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 173', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8354/3.8552 ACC: 0.68702                                                                                                    
Val loss: 3.8372 ACC: 0.68074                                                                                                    
Set dropout to 0.8788 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 174', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8387/3.9106 ACC: 0.6842                                                                                                     
Val loss: 3.8378 ACC: 0.68074                                                                                                    
Set dropout to 0.8787 (prop=1.0,0.9997=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 175', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8345/3.5187 ACC: 0.68731                                                                                                    
Epoch   175: reducing learning rate of group 0 to 4.6427e-03.                                                                    

Set dropout to 0.8791 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 176', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8377/3.5261 ACC: 0.68448                                                                                                    
Val loss: 3.8373 ACC: 0.67926                                                                                                    
Set dropout to 0.8791 (prop=1.0,0.9998=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 177', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8351/3.5195 ACC: 0.68674                                                                                                    
Val loss: 3.8381 ACC: 0.67704                                                                                                    
Set dropout to 0.8794 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 178', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8401/3.852 ACC: 0.68147                                                                                                     
Val loss: 3.8376 ACC: 0.67778                                                                                                    
Set dropout to 0.8791 (prop=0.9998,0.9996=0.9997)



HBox(children=(FloatProgress(value=0.0, description='Epoch 179', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8364/3.5203 ACC: 0.68477                                                                                                    
Val loss: 3.8377 ACC: 0.67778                                                                                                    
Set dropout to 0.8793 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 180', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8353/3.8521 ACC: 0.68674                                                                                                    
Val loss: 3.8377 ACC: 0.68074                                                                                                    
Set dropout to 0.8796 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 181', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8355/3.5201 ACC: 0.68477                                                                                                    
Val loss: 3.8374 ACC: 0.68074                                                                                                    
Set dropout to 0.8798 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 182', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8394/3.8521 ACC: 0.68138                                                                                                    
Val loss: 3.8373 ACC: 0.68148                                                                                                    
Set dropout to 0.8795 (prop=0.9998,0.9996=0.9997)



HBox(children=(FloatProgress(value=0.0, description='Epoch 183', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8318/3.6276 ACC: 0.68881                                                                                                    
Val loss: 3.8384 ACC: 0.67852                                                                                                    
Set dropout to 0.8803 (prop=1.001,1.001=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 184', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8374/3.9685 ACC: 0.68524                                                                                                    
Val loss: 3.8376 ACC: 0.68                                                                                                       
Set dropout to 0.8803 (prop=1.0,0.9997=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 185', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8344/3.5187 ACC: 0.68693                                                                                                    
Val loss: 3.8365 ACC: 0.68222                                                                                                    
Set dropout to 0.8806 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 186', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.839/3.8528 ACC: 0.68345                                                                                                     
Epoch   186: reducing learning rate of group 0 to 4.4106e-03.                                                                    

Set dropout to 0.8804 (prop=1.0,0.9996=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 187', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.836/3.5187 ACC: 0.68458                                                                                                     
Val loss: 3.8386 ACC: 0.67926                                                                                                    
Set dropout to 0.8807 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 188', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8364/3.8521 ACC: 0.68627                                                                                                    
Val loss: 3.8382 ACC: 0.67926                                                                                                    
Set dropout to 0.8809 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 189', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8327/3.5187 ACC: 0.68891                                                                                                    
Val loss: 3.8379 ACC: 0.67778                                                                                                    
Set dropout to 0.8815 (prop=1.001,1.0=1.001)



HBox(children=(FloatProgress(value=0.0, description='Epoch 190', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8386/4.1751 ACC: 0.68561                                                                                                    
Val loss: 3.8377 ACC: 0.68                                                                                                       
Set dropout to 0.8814 (prop=1.0,0.9995=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 191', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8368/3.8516 ACC: 0.68505                                                                                                    
Val loss: 3.8371 ACC: 0.68                                                                                                       
Set dropout to 0.8815 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 192', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.837/3.5187 ACC: 0.68401                                                                                                     
Val loss: 3.8376 ACC: 0.68074                                                                                                    
Set dropout to 0.8816 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 193', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8402/4.1843 ACC: 0.68298                                                                                                    
Val loss: 3.8384 ACC: 0.6763                                                                                                     
Set dropout to 0.8813 (prop=0.9999,0.9997=0.9998)



HBox(children=(FloatProgress(value=0.0, description='Epoch 194', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8361/4.1849 ACC: 0.68721                                                                                                    
Val loss: 3.8377 ACC: 0.67926                                                                                                    
Set dropout to 0.8815 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 195', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8359/3.5187 ACC: 0.68486                                                                                                    
Val loss: 3.8378 ACC: 0.67852                                                                                                    
Set dropout to 0.8818 (prop=1.0,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 196', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8341/3.8501 ACC: 0.68834                                                                                                    
Val loss: 3.8371 ACC: 0.68074                                                                                                    
Set dropout to 0.8821 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 197', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8385/4.2086 ACC: 0.68552                                                                                                    
Epoch   197: reducing learning rate of group 0 to 4.1900e-03.                                                                    

Set dropout to 0.882 (prop=1.0,0.9996=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 198', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8343/3.5187 ACC: 0.68796                                                                                                    
Val loss: 3.8374 ACC: 0.67852                                                                                                    
Set dropout to 0.8824 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 199', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8362/3.5188 ACC: 0.68401                                                                                                    
Val loss: 3.8374 ACC: 0.67926                                                                                                    
Set dropout to 0.8825 (prop=1.0,0.9999=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 200', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8364/3.5187 ACC: 0.68364                                                                                                    
Val loss: 3.838 ACC: 0.67778                                                                                                     
Set dropout to 0.8827 (prop=1.0,0.9999=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 201', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.835/3.8521 ACC: 0.68759                                                                                                     
Val loss: 3.8378 ACC: 0.68                                                                                                       
Set dropout to 0.883 (prop=1.001,1.0=1.0)



HBox(children=(FloatProgress(value=0.0, description='Epoch 202', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8388/4.1849 ACC: 0.68373                                                                                                    
Val loss: 3.8379 ACC: 0.6763                                                                                                     
Set dropout to 0.8829 (prop=1.0,0.9997=0.9999)



HBox(children=(FloatProgress(value=0.0, description='Epoch 203', max=11977.0, style=ProgressStyle(description_…

Train loss: 3.8355/3.9683 ACC: 0.6882                                                                                                     

In [None]:
ab = copy.deepcopy(best_model).to(device)
loss_total = 0
correct_t = 0
total_t = 0
dl_test = DataLoader(list(zip(fold.X_test, y_test)), batch_size=batch_size,
                         shuffle=False, collate_fn=collate_train, num_workers=2)
for i, (terms_idx_t, docs_offsets_t, y_t) in enumerate(dl_test):
    terms_idx_t    = terms_idx_t.to( device )
    docs_offsets_t = docs_offsets_t.to( device )
    y_t            = y_t.to( device )

    pred_docs_t,weigths = ab( terms_idx_t, docs_offsets_t )
    pred_docs_t = F.softmax(pred_docs_t)

    y_pred_t    = pred_docs_t.argmax(axis=1)
    correct_t  += (y_pred_t == y_t).sum().item()
    total_t    += len(y_t)
    loss_total += loss_func_cel(pred_docs_t, y_t)

print(f'Test loss: {loss_total.item()/(i+1):.5} ACC: {correct_t/total_t:.5}', end=f"{' '*100}\r")

In [None]:
"""
acm ####################################################################################
Train loss: 1.6009/1.6089 ACC: 0.94886                                                                                                    
Val loss: 1.7718 ACC: 0.77475                                                                                                    
New Best Val loss: 1.7718                                                                                                    
Test loss: 1.7678 ACC: 0.78557  79.92

Train loss: 1.7562/1.7777 ACC: 0.78902                                                                                                    
Val loss: 1.7574 ACC: 0.78717                                                                                                    
Set dropout to 0.8584 (prop=0.9964,1.004=1.0)
New Best Val loss: 1.7574                                                                                                
Test loss: 1.7665 ACC: 0.77836                                                                                                    


20ng ####################################################################################
Train loss: 2.0907/2.0787 ACC: 0.98845                                                                                                    
Val loss: 2.1869 ACC: 0.90803                                                                                                    
New Best Val loss: 2.1869                                                                                                    
Test loss: 2.178 ACC: 0.91068   92.65

reut ####################################################################################
Train loss: 3.7735/3.5191 ACC: 0.74734                                                                                                    
Val loss: 3.8554 ACC: 0.6763                                                                                                    
New Best Val loss: 3.8554                                                                                                    
Test loss: 3.8493 ACC: 0.6837  72.67

Train loss: 3.7489/3.8717 ACC: 0.77265                                                                                                    
Val loss: 3.8084 ACC: 0.71037                                                                                                    
Adjusting learning rate of group 0 to 5.1046e-03.
New Best Val loss: 3.8084 

webkb ####################################################################################
Train loss: 1.2228/1.2037 ACC: 0.9504                                                                                                     
Val loss: 1.3787 ACC: 0.80316                                                                                                    
New Best Val loss: 1.3787                                                                                                    
Test loss: 1.3857 ACC: 0.78858   81.53

"""

In [None]:
class AttentionBag(nn.Module):
    def __init__(self, vocab_size, hiddens, nclass, drop=.5, initrange=.5):
        super(AttentionBag, self).__init__()
        self.hiddens    = hiddens
        self.mask       = Mask()
        self.dt_emb     = nn.Embedding(vocab_size, hiddens)
        self.dt_dir_map = nn.Linear(hiddens, hiddens)
        self.drop       = nn.Dropout(drop)
        self.ma_term    = nn.MultiheadAttention(hiddens, 1)
        self.fc         = nn.Linear(hiddens, nclass)
        self.initrange  = initrange 
        self.init_weights()
    def forward(self, terms_idx, docs_offsets):
        n = terms_idx.shape[0]
        batch_size = docs_offsets.shape[0]
        
        k         = [ terms_idx[ docs_offsets[i-1]:docs_offsets[i] ] for i in range(1, batch_size) ]
        k.append( terms_idx[ docs_offsets[-1]: ] )
        x_packed  = pad_sequence(k, batch_first=True, padding_value=0)

        bx_packed = x_packed == 0
        pad_mask  = bx_packed.logical_not()
        pad_mask  = pad_mask.view(*bx_packed.shape, 1)
        pad_mask  = pad_mask.logical_and(pad_mask.transpose(1, 2))
        
        dt_h      = self.dt_emb( x_packed )
        dt_h      = self.drop(dt_h)
        dir_dt_h  = self.dt_dir_map( dt_h )

        weights = torch.bmm( dt_h, dir_dt_h.transpose( 1, 2 ) )
        weights = self.mask(weights)
        
        weights_disc = (weights * pad_mask)
        weights_disc = weights_disc.sum(axis=1)
        weights_disc = F.softmax(weights_disc, dim=1)
        weights_disc = weights_disc.view( *weights_disc.shape, 1 )
        
        attn_mask = weights != 0
        attn_mask = attn_mask.logical_and( pad_mask ).logical_not()
        
        dt_h     = dt_h.transpose(0,1)
        dir_dt_h = dir_dt_h.transpose(0,1)
        docs_att, weigths_att = self.ma_term( dt_h, dir_dt_h, dt_h,
                                  key_padding_mask=bx_packed, 
                                  attn_mask=attn_mask )

        weigths_att = torch.where(torch.isnan(weigths_att), torch.zeros_like(weigths_att), weigths_att)
        weigths_att = (weigths_att * pad_mask)
        weigths_att = weigths_att.sum(axis=1)
        weigths_att = F.softmax(weigths_att, dim=1)
        weigths_att = weigths_att.view( *weigths_att.shape, 1 )
        
        weigths = weights_disc + weigths_att

        docs_att = docs_att.transpose(0,1)
        docs_att = torch.where(torch.isnan(docs_att), torch.zeros_like(docs_att), docs_att)
        
        docs_h = docs_att * weigths
        docs_h = docs_h.sum(axis=1)
        docs_h = docs_h / bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
        docs_h = torch.where(torch.isnan(docs_h), torch.zeros_like(docs_h), docs_h)
        docs_h = self.drop(docs_h)
        return self.fc(docs_h)
    
    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.dt_dir_map.weight.data.uniform_(-self.initrange, self.initrange)
        self.fc.weight.data.uniform_(-self.initrange, self.initrange)
        self.ma_term.in_proj_weight.data.uniform_(-self.initrange, self.initrange)

In [None]:
y_pred_t      = pred_docs_t.argmax(axis=1)
correct_t     = (y_pred_t == y_t).sum().item()
total_t       = len(y_t)
correct_t/total_t

In [None]:
docs_offsets_t

In [None]:
terms_idx_t[:docs_offsets_t[batch_size]]

In [None]:
batch_off_test  = docs_offsets_t[batch_size]
batch_tidx_test = terms_idx_t[:batch_off_test]
h_terms_test    = sc.tt_emb( batch_tidx_test )
dirh_terms_test = sc.tt_dir_map( h_terms_test )

W = torch.matmul( h_terms_test, dirh_terms_test.T )
W = F.leaky_relu( W, negative_slope=sc.mask.negative_slope)
W = F.sigmoid(W)
W

In [None]:
k = [ batch_tidx_test[ docs_offsets_t[i-1]:docs_offsets_t[i] ] for i in range(1, batch_size) ]
k.append( batch_tidx_test[ docs_offsets_t[batch_size-1]:docs_offsets_t[batch_size] ] )
x_packed = pad_sequence(k, batch_first=True, padding_value=0)
tt_emb = sc.tt_emb( x_packed )
len(k)

In [None]:
x_packed = pad_sequence(k, batch_first=True, padding_value=0)
tt_emb = sc.tt_emb( x_packed )

In [None]:
tt_emb.transpose(0,1)

In [None]:
a = [terms_idx, terms_idx]
torch.stack(a)

In [None]:
F.softmax(pred_docs_t).argmax(axis=1)

In [None]:
y_val

In [None]:
(y_t == F.softmax(pred_docs_t).argmax(axis=1)).sum().item()/y_t.shape[0]

In [None]:
terms_idx_t, docs_offsets_t

In [None]:
shifts = sc._get_shift_(docs_offsets_t, terms_idx_t.shape[0])

In [None]:
zipado = zip(docs_offsets_t, shifts)
#next(zipado)
start,size = next(zipado)

In [None]:
w = sc.tt_emb( terms_idx_t[start:start+size] )
w1 = sc.tt_dir_map( w )
w = torch.matmul( w, w1.T )
w = F.leaky_relu( w, negative_slope=sc.negative_slope)
w = F.sigmoid(w)
#w = F.tanh(w)
#w = F.relu(w)
#w = w.mean(axis=1)
#w = F.softmax(w)
w,w1

In [None]:
w.round().sum() / (w.shape[0]*w.shape[1])

In [None]:
inv_mapper = { v:k for (k,v) in tokenizer.node_mapper.items() }

In [None]:
terms_idx[start:start+size]

In [None]:
fold.X_val[0]

In [None]:
bla = w.mean(axis=1)
bla = F.softmax(bla)
#bla = bla/torch.clamp(bla.sum(), 0.0001)
bla

In [None]:
[ (i, tid.item(),inv_mapper[tid.item()], wei.item()) for i, (tid, wei) in enumerate(zip(terms_idx[start:start+size], bla)) ]

In [None]:
w

In [None]:
w.mean(axis=1)

In [None]:
w.shape

In [None]:
norm = nn.BatchNorm1d(num_features=1).to(device)

bla2 = norm(bla.view(-1, 1)).squeeze()
bla2 = F.sigmoid(bla2)
bla2

In [None]:
1

In [None]:
ma = nn.MultiheadAttention(300, 300).to(device)
ma

In [None]:
torch.__version__

In [None]:
a,b = w.shape
w_ = w.view(a,1,b)
w1_ = w1.view(a,1,b)

attn_output = ma(w_, w1_, w_, need_weights=False)

attn_output.view(a,b)
attn_output_weights.view(a,a)

In [None]:
attn_output.view(a,b).shape

In [None]:
attn_output_weights.view(a,a)

In [None]:
F.softmax(torch.Tensor([[0.0718, 0.0716, 0.0712, 0.0714, 0.0721, 0.0710, 0.0712, 0.0714, 0.0710,
         0.0719, 0.0711, 0.0712, 0.0718, 0.0714],
        [0.0722, 0.0709, 0.0710, 0.0709, 0.0728, 0.0723, 0.0709, 0.0709, 0.0719,
         0.0712, 0.0709, 0.0707, 0.0725, 0.0707],
        [0.0711, 0.0715, 0.0710, 0.0713, 0.0710, 0.0712, 0.0718, 0.0713, 0.0709,
         0.0725, 0.0716, 0.0719, 0.0710, 0.0719],
        [0.0721, 0.0717, 0.0714, 0.0713, 0.0717, 0.0714, 0.0711, 0.0710, 0.0713,
         0.0717, 0.0710, 0.0712, 0.0720, 0.0711],
        [0.0722, 0.0711, 0.0710, 0.0710, 0.0737, 0.0716, 0.0709, 0.0705, 0.0714,
         0.0719, 0.0707, 0.0707, 0.0731, 0.0702],
        [0.0714, 0.0716, 0.0708, 0.0711, 0.0718, 0.0713, 0.0712, 0.0717, 0.0712,
         0.0724, 0.0713, 0.0712, 0.0716, 0.0714],
        [0.0712, 0.0714, 0.0713, 0.0713, 0.0722, 0.0716, 0.0709, 0.0714, 0.0714,
         0.0717, 0.0713, 0.0713, 0.0716, 0.0713],
        [0.0716, 0.0707, 0.0712, 0.0714, 0.0717, 0.0715, 0.0714, 0.0715, 0.0712,
         0.0721, 0.0711, 0.0714, 0.0717, 0.0715],
        [0.0717, 0.0713, 0.0708, 0.0710, 0.0725, 0.0717, 0.0714, 0.0709, 0.0712,
         0.0712, 0.0712, 0.0717, 0.0720, 0.0714],
        [0.0709, 0.0712, 0.0713, 0.0712, 0.0713, 0.0713, 0.0714, 0.0719, 0.0712,
         0.0724, 0.0715, 0.0715, 0.0717, 0.0713],
        [0.0715, 0.0716, 0.0712, 0.0708, 0.0725, 0.0717, 0.0712, 0.0714, 0.0714,
         0.0717, 0.0713, 0.0706, 0.0718, 0.0712],
        [0.0726, 0.0711, 0.0713, 0.0706, 0.0737, 0.0721, 0.0709, 0.0707, 0.0714,
         0.0708, 0.0706, 0.0705, 0.0730, 0.0707],
        [0.0718, 0.0711, 0.0714, 0.0715, 0.0725, 0.0707, 0.0707, 0.0711, 0.0712,
         0.0728, 0.0711, 0.0713, 0.0719, 0.0709],
        [0.0712, 0.0716, 0.0713, 0.0712, 0.0717, 0.0707, 0.0708, 0.0721, 0.0709,
         0.0728, 0.0713, 0.0716, 0.0711, 0.0715]]).sum(axis=0))

In [None]:
class NotTooSimpleClassifier(nn.Module):
    def __init__(self, vocab_size, hidden_l, nclass, dropout1=0.1, dropout2=0.1, negative_slope=99,
                 initrange = 0.5, scale_grad_by_freq=False, device='cuda:0'):
        super(NotTooSimpleClassifier, self).__init__()
        
        self.dt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        self.tt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        
        self.undirected_map = nn.Linear(hidden_l, hidden_l)
        
        self.fc = nn.Linear(hidden_l, nclass)
        self.drop1 = nn.Dropout(dropout1)
        self.drop2 = nn.Dropout(dropout2)
        
        self.norm = nn.BatchNorm1d(1)
        
        self.initrange = initrange
        self.nclass = nclass
        self.negative_slope = negative_slope
        
        self.init_weights()
        
        #self.labls_emb = nn.Embedding(graph_builder.n_class, 300)
    
    def forward(self, terms_idxs, docs_offsets):
        n = terms_idxs.shape[0]
        weights = []
        shifts = self._get_shift_(docs_offsets, n)
        
        terms_h1 = self.tt_emb(terms_idxs)
        terms_h1 = self.drop1(terms_h1)
        
        terms_h2 = self.undirected_map( terms_h1 )
        #terms_h2 = self.drop1( terms_h2 )
        for start,size in zip(docs_offsets, shifts):
            w  = terms_h1[start:start+size]
            w1 = terms_h2[start:start+size]
            w = torch.matmul( w, w1.T )
            w = F.leaky_relu( w, negative_slope=self.negative_slope)
            w = F.sigmoid(w-5.5)
            w = w.mean(axis=1)
            w = F.softmax(w)
            #w = w / torch.clamp(w.sum(), 0.0001)
            weights.append( w )
        
        weights = torch.cat(weights)
        #weights = self.norm(weights.view(-1, 1)).squeeze()
        #weights = F.sigmoid(weights)
        
        h_docs  = F.embedding_bag(self.dt_emb.weight, terms_idxs, docs_offsets, per_sample_weights=weights, mode='sum')
        h_docs = self.drop2( h_docs )
        pred_docs = self.fc( h_docs )
        return pred_docs

    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        
    def _get_shift_(self, offsets, lenght):
        shifts = offsets[1:] - offsets[:-1]
        last = torch.LongTensor([lenght - offsets[-1]]).to( offsets.device )
        return torch.cat([shifts, last])