In [1]:
import warnings
warnings.filterwarnings('ignore')

from TGA.utils import Dataset

from tqdm.notebook import tqdm
from TGA.utils import preprocessor
import copy

from time import time
import numpy as np
from itertools import repeat
from collections import Counter
from segtok import tokenizer as tk

from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS as stop_words
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.base import BaseEstimator, TransformerMixin

In [2]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

In [4]:
dataset = Dataset('/home/datasets/acm/')
fold = next(dataset.get_fold_instances(10, with_val=True))
fold._fields, len(fold.X_train)

(('X_train', 'y_train', 'X_test', 'y_test', 'X_val', 'y_val'), 19907)

In [5]:
def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

In [6]:
class Tokenizer(BaseEstimator, TransformerMixin):
    def __init__(self, mindf=2, stopwords='remove', model='list', lan='english', k=25, verbose=False):
        super(Tokenizer, self).__init__()
        self.mindf = mindf
        self.le = LabelEncoder()
        self.verbose = verbose
        self.stopwords = stopwords
        self.stopwordsSet = stop_words
        self.lan = lan
        self.k = k
        self.model = model
        self.analyzer = TfidfVectorizer(preprocessor=preprocessor)#.build_analyzer()
        #self.analyzer = tk.web_tokenizer
    
    def fit(self, X, y):
        self.N = len(X)
        y = self.le.fit_transform( y )
        self.n_class = len(self.le.classes_)
        

        self.term_freqs = Counter()
        local_analyzer  = self.analyzer.build_analyzer()
        docs = map(local_analyzer, X)
        for doc_in_terms in tqdm(docs, total=self.N, disable=not self.verbose):
            doc_in_terms = list(map( self._filter_fit_, doc_in_terms ))
            self.term_freqs.update(list(set(doc_in_terms)))
        self.node_mapper      = {'<BLANK>': 0}
        self.term_freqs       = { term:v for (term,v) in self.term_freqs.items() if v >= self.mindf }    
        self.node_mapper      = { term:self._get_idx_(term) for term in self.term_freqs.keys() if self._isrel_(term) }
        self.node_mapper['<UNK>'] = len(self.node_mapper)
        self.node_mapper['<BLANK>'] = 0
        self.vocab_size = len(self.node_mapper)
        
        if self.model == 'topk' or self.model == 'sample':
            from multiprocessing import cpu_count
            from sklearn.ensemble import RandomForestClassifier
            
            self.rf = RandomForestClassifier(n_jobs=cpu_count(), verbose=self.verbose>0)
            self.rf.fit(self.analyzer.fit_transform(X), y)
            imps = tokenizer.rf.feature_importances_
            self.fi_ = { term: imps[idterm] for (term, idterm) in tokenizer.analyzer.vocabulary_.items() }
            self.fi_['<BLANK>'] = 0.
            self.fi_['<UNK>'] = 0.
            
        return self
    def _isrel_(self, term):
        if self.stopwords == 'remove' and term in self.stopwordsSet:
            return False
        # put here your filter_functions
        return True
    def _get_idx_(self, term):
        # put here your idx_set_functions
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return self.node_mapper.setdefault('<STPW>', len(self.node_mapper))
        return self.node_mapper.setdefault(term, len(self.node_mapper))
    def _filter_transform_(self, term):
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return '<STPW>'
        if term not in self.node_mapper:
            return '<UNK>'
        return term
    def _filter_fit_(self, term):
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return '<STPW>'
        return term
    def _model_(self, doc):
        if self.model == 'set':
            return set(doc)
        if self.model == 'topk':
            doc = np.array(list(set(doc)))
            weigths = np.array([ self.fi_[t] for t in doc ])
            doc = doc[(-weigths).argsort()[:self.k]]
            return doc
        if self.model == 'sample':
            doc = np.array(list(set(doc)))
            if len(doc) > self.k:
                weigths = np.array([ self.fi_[t] for t in doc ])
                weigths = softmax(weigths)
                doc = np.random.choice(doc, size=self.k, replace=False, p=weigths)
            return doc
        return list(doc)
    def transform(self, X, verbose=None):
        verbose = verbose if verbose is not None else self.verbose
        n = len(X)
        doc_off = [0]
        terms_idx = []
        local_analyzer  = self.analyzer.build_analyzer()
        for i,doc_in_terms in tqdm(enumerate(map(local_analyzer, X)), total=n, disable=not verbose):
            doc_in_terms = filter( self._isrel_, doc_in_terms )
            doc_in_terms = map( self._filter_transform_, doc_in_terms )
            doc_in_terms = self._model_(doc_in_terms)
            doc_in_terms = [ self.node_mapper[tid] for tid in doc_in_terms ]
            if self.model == 'sorted':
                doc_in_terms = sorted(doc_in_terms)
            doc_off.append( len(doc_in_terms) )
            terms_idx.extend( doc_in_terms )
        return np.array( terms_idx ), np.array(doc_off)[:-1].cumsum()

In [7]:
tokenizer = Tokenizer(mindf=1, stopwords='keep', model='sample', k=128, verbose=True)
tokenizer.fit(fold.X_train, fold.y_train)

  0%|          | 0/19907 [00:00<?, ?it/s]

[Parallel(n_jobs=12)]: Using backend ThreadingBackend with 12 concurrent workers.
[Parallel(n_jobs=12)]: Done  26 tasks      | elapsed:    3.4s
[Parallel(n_jobs=12)]: Done 100 out of 100 | elapsed:    9.8s finished


Tokenizer(k=128, mindf=1, model='sample', stopwords='keep', verbose=True)

In [8]:
doc_in_terms = tokenizer.analyzer.build_analyzer()(fold.X_val[0])
doc_in_terms = filter( tokenizer._isrel_, doc_in_terms )
doc_in_terms = map( tokenizer._filter_transform_, doc_in_terms )
#doc_in_terms = tokenizer._model_(doc_in_terms)

doc = np.array(list(set(doc_in_terms)))

In [9]:
y_train = tokenizer.le.transform( fold.y_train )
y_val   = tokenizer.le.transform( fold.y_val )
y_test  = tokenizer.le.transform( fold.y_test )

In [10]:
tokenizer.node_mapper['<BLANK>'], fold.X_val[0]

(0,
 'how do computer science lecturers create modules? (poster) john traxler  \n')

In [11]:
def collate_train(param):
    X, y = zip(*param)
    terms_ids, docs_offsets = tokenizer.transform(X, verbose=False)
    return torch.LongTensor(terms_ids), torch.LongTensor(docs_offsets), torch.LongTensor(y)

In [12]:
class Mask(nn.Module):
    def __init__(self, negative_slope=1000, kappa=2.):
        super(Mask, self).__init__()
        self.negative_slope = negative_slope
        self.kappa = kappa
        self.sig = nn.Sigmoid()
    def forward(self, h):
        w = F.leaky_relu( h, negative_slope=self.negative_slope)
        w = self.sig(w-self.kappa)
        return w

In [13]:
help(torch.embedding_bag)

Help on built-in function embedding_bag:

embedding_bag(...)



In [14]:
class SimpleAttentionBag(nn.Module):
    def __init__(self, vocab_size, hiddens, nclass, drop=.5, initrange=.5, negative_slope=99.):
        super(SimpleAttentionBag, self).__init__()
        self.hiddens        = hiddens
        self.dt_emb         = nn.Embedding(vocab_size, hiddens)
        self.tt_source_map  = nn.Linear(hiddens, hiddens)
        self.tt_target_map  = nn.Linear(hiddens, hiddens)
        self.fc             = nn.Linear(hiddens, nclass)
        self.initrange      = initrange 
        self.negative_slope = negative_slope
        self.drop           = nn.Dropout(drop)
        self.drop_          = drop
        self.sig            = nn.Sigmoid()
        self.init_weights()
    def forward(self, terms_idx, docs_offsets, return_mask=False):
        n = terms_idx.shape[0]
        batch_size = docs_offsets.shape[0]
        
        k         = [ terms_idx[ docs_offsets[i-1]:docs_offsets[i] ] for i in range(1, batch_size) ]
        k.append( terms_idx[ docs_offsets[-1]: ] )
        x_packed  = pad_sequence(k, batch_first=True, padding_value=0)

        bx_packed = x_packed == 0
        doc_sizes = bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
        pad_mask  = bx_packed.logical_not()
        pad_mask  = pad_mask.view(*bx_packed.shape, 1)
        pad_mask  = pad_mask.logical_and(pad_mask.transpose(1, 2))
        
        dt_h     = self.dt_emb( x_packed )
        dt_h     = F.dropout( dt_h, p=self.drop_, training=self.training )
        
        source_h = self.tt_source_map( dt_h )
        source_h = F.tanh(source_h)
        #source_h = F.leaky_relu(source_h, negative_slope=self.negative_slope)
        source_h = F.dropout( source_h, p=self.drop_, training=self.training )
        
        target_h = self.tt_target_map( dt_h )
        target_h = F.tanh(target_h)
        #target_h = F.leaky_relu(target_h, negative_slope=self.negative_slope)
        target_h = F.dropout( target_h, p=self.drop_, training=self.training )

        co_weights = torch.bmm( source_h, target_h.transpose( 1, 2 ) )
        co_weights = F.leaky_relu( co_weights, negative_slope=self.negative_slope)
        
        co_weights[pad_mask.logical_not()] = float('-inf') # Set the 3D-pad mask values to -inf (=0 in sigmoid)
        co_weights = F.sigmoid(co_weights)
        #co_weights = torch.where(torch.isnan(co_weights), torch.zeros_like(co_weights), co_weights) #replace nan to zero
        
        weights = co_weights.sum(axis=2) / doc_sizes
        weights[bx_packed] = float('-inf') # Set the 2D-pad mask values to -inf  (=0 in softmax)
        weights = F.softmax(weights, dim=1)
        weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
        weights = weights.view( *weights.shape, 1 )
        
        docs_h = dt_h * weights
        docs_h = docs_h.sum(axis=1)
        docs_h = F.dropout( docs_h, p=self.drop_, training=self.training )
        return self.fc(docs_h), weights, co_weights
    
    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        #self.tt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_source_map.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_target_map.weight.data.uniform_(-self.initrange, self.initrange)
        self.fc.weight.data.uniform_(-self.initrange, self.initrange)

In [52]:
nepochs = 1000
max_epochs = 20
drop=0.3
device = torch.device('cuda:0')
batch_size = 64
k = 100

In [53]:
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f879c14c3b0>

In [54]:
#sc = SimpleClassifier(tokenizer.vocab_size, 300, tokenizer.n_class, dropout=drop).to( device )
ab = SimpleAttentionBag(tokenizer.vocab_size, 300, tokenizer.n_class, drop=drop).to( device )
tokenizer.k = k
optimizer = optim.AdamW( ab.parameters(), lr=6e-3, weight_decay=5e-3)
loss_func_cel = nn.CrossEntropyLoss().to( device )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=.95,
                                                       patience=10, verbose=True)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=.98, verbose=True)

In [55]:
best = 99999.
counter = 1
loss_val = 1.
eps = .9
dl_val = DataLoader(list(zip(fold.X_val, y_val)), batch_size=batch_size,
                         shuffle=False, collate_fn=collate_train, num_workers=12)
for e in tqdm(range(nepochs), total=nepochs):
    dl_train = DataLoader(list(zip(fold.X_train, y_train)), batch_size=batch_size,
                             shuffle=True, collate_fn=collate_train, num_workers=12)
    loss_train  = 0.
    with tqdm(total=len(y_train)+len(y_val), smoothing=0., desc=f"Epoch {e+1}") as pbar:
        total = 0
        correct  = 0
        ab.train()
        tokenizer.model = 'sample'
        for i, (terms_idx, docs_offsets, y) in enumerate(dl_train):
            terms_idx    = terms_idx.to( device )
            docs_offsets = docs_offsets.to( device )
            y            = y.to( device )
            
            pred_docs,_,_ = ab( terms_idx, docs_offsets)
            pred_docs = F.softmax(pred_docs)
            loss = loss_func_cel(pred_docs, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            loss_train += loss.item()
            total      += len(y)
            y_pred      = pred_docs.argmax(axis=1)
            correct    += (y_pred == y).sum().item()
            #ab.drop_ =  np.power((correct/total),loss_val)
            ab.drop_ =  np.power((correct/total),3.)
            
            toprint  = f"Train loss: {loss_train/(i+1):.5}/{loss.item():.5} "
            toprint += f'ACC: {correct/total:.5}'
            
            print(toprint, end=f"{' '*100}\r")
            
            pbar.update( len(y) )
            del pred_docs, loss
            del terms_idx, docs_offsets, y
            del y_pred
        loss_train = loss_train/(i+1)
        print()
        print(f'drop: {ab.drop_:.5}')
        total = 0
        correct  = 0
        ab.eval()
        tokenizer.model = 'topk'
        with torch.no_grad():
            loss_val = 0.
            for i, (terms_idx, docs_offsets, y) in enumerate(dl_val):
                terms_idx    = terms_idx.to( device )
                docs_offsets = docs_offsets.to( device )
                y            = y.to( device )

                pred_docs,weights,co_weights = ab( terms_idx, docs_offsets )
                pred_docs   = F.softmax(pred_docs)

                y_pred      = pred_docs.argmax(axis=1)
                correct    += (y_pred == y).sum().item()
                total      += len(y)
                loss2       = loss_func_cel(pred_docs, y)
                loss_val   += loss2

                print(f'Val loss: {loss_val.item()/(i+1):.5} ACC: {correct/total:.5}', end=f"{' '*100}\r")
   
                pbar.update( len(y) )
            print()

            del terms_idx, docs_offsets, y
            del y_pred
            
            loss_val   = (loss_val/(i+1)).cpu()
            scheduler.step(loss_val)

            if best-loss_val > 0.0001 :
                best = loss_val.item()
                counter = 1
                print(f'New Best Val loss: {best:.5}', end=f"{' '*100}\n")
                best_model = copy.deepcopy(ab).to('cpu')
            elif counter > max_epochs:
                print(f'Best Val loss: {best:.5}', end=f"{' '*100}\n")
                break
            else:
                counter += 1
            del pred_docs, loss2

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.9769/1.9188 ACC: 0.61792                                                                                                     
drop: 0.23594
Val loss: 1.832 ACC: 0.73106                                                                                                     
New Best Val loss: 1.832                                                                                                    


Epoch 2:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.754/1.5847 ACC: 0.81072                                                                                                     
drop: 0.53286
Val loss: 1.8024 ACC: 0.75311                                                                                                    
New Best Val loss: 1.8024                                                                                                    


Epoch 3:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.7212/1.6644 ACC: 0.8383                                                                                                     
drop: 0.58911
Val loss: 1.7933 ACC: 0.75631                                                                                                    
New Best Val loss: 1.7933                                                                                                    


Epoch 4:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.705/1.5451 ACC: 0.84945                                                                                                     
drop: 0.61293
Val loss: 1.7853 ACC: 0.76954                                                                                                    
New Best Val loss: 1.7853                                                                                                    


Epoch 5:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.695/1.8232 ACC: 0.8602                                                                                                      
drop: 0.6365
Val loss: 1.7849 ACC: 0.76473                                                                                                    
New Best Val loss: 1.7849                                                                                                    


Epoch 6:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6891/1.543 ACC: 0.86392                                                                                                     
drop: 0.64479
Val loss: 1.7846 ACC: 0.76353                                                                                                    
New Best Val loss: 1.7846                                                                                                    


Epoch 7:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6829/1.543 ACC: 0.86819                                                                                                     
drop: 0.65439
Val loss: 1.7862 ACC: 0.75711                                                                                                    


Epoch 8:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6807/1.7649 ACC: 0.87025                                                                                                    
drop: 0.65906
Val loss: 1.7811 ACC: 0.76513                                                                                                    
New Best Val loss: 1.7811                                                                                                    


Epoch 9:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6777/1.5438 ACC: 0.87271                                                                                                    
drop: 0.66467
Val loss: 1.7861 ACC: 0.75752                                                                                                    


Epoch 10:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6715/1.5445 ACC: 0.87738                                                                                                    
drop: 0.6754
Val loss: 1.782 ACC: 0.76192                                                                                                     


Epoch 11:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6723/1.8613 ACC: 0.87738                                                                                                    
drop: 0.6754
Val loss: 1.7831 ACC: 0.76152                                                                                                    


Epoch 12:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6665/1.5433 ACC: 0.8814                                                                                                     
drop: 0.68473
Val loss: 1.7798 ACC: 0.76353                                                                                                    
New Best Val loss: 1.7798                                                                                                    


Epoch 13:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6665/1.6186 ACC: 0.88185                                                                                                    
drop: 0.68578
Val loss: 1.7823 ACC: 0.76393                                                                                                    


Epoch 14:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.666/1.8845 ACC: 0.8827                                                                                                      
drop: 0.68777
Val loss: 1.7825 ACC: 0.75912                                                                                                    


Epoch 15:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6642/1.8758 ACC: 0.88301                                                                                                    
drop: 0.68848
Val loss: 1.7821 ACC: 0.76032                                                                                                    


Epoch 16:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6622/1.8644 ACC: 0.88532                                                                                                    
drop: 0.6939
Val loss: 1.7815 ACC: 0.76032                                                                                                    


Epoch 17:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6606/1.543 ACC: 0.88692                                                                                                     
drop: 0.69769
Val loss: 1.7816 ACC: 0.76072                                                                                                    


Epoch 18:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6599/1.5953 ACC: 0.88662                                                                                                    
drop: 0.69697
Val loss: 1.7806 ACC: 0.75952                                                                                                    


Epoch 19:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6584/1.5498 ACC: 0.88833                                                                                                    
drop: 0.70101
Val loss: 1.7772 ACC: 0.76673                                                                                                    
New Best Val loss: 1.7772                                                                                                    


Epoch 20:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6587/1.8856 ACC: 0.88949                                                                                                    
drop: 0.70375
Val loss: 1.778 ACC: 0.76393                                                                                                     


Epoch 21:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6557/1.545 ACC: 0.89089                                                                                                     
drop: 0.70709
Val loss: 1.7798 ACC: 0.76192                                                                                                    


Epoch 22:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6564/1.5441 ACC: 0.89099                                                                                                    
drop: 0.70733
Val loss: 1.7788 ACC: 0.76633                                                                                                    


Epoch 23:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6556/1.5568 ACC: 0.89014                                                                                                    
drop: 0.7053
Val loss: 1.7796 ACC: 0.76473                                                                                                    


Epoch 24:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6555/1.543 ACC: 0.89124                                                                                                     
drop: 0.70793
Val loss: 1.7789 ACC: 0.76553                                                                                                    


Epoch 25:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6545/1.894 ACC: 0.8923                                                                                                      
drop: 0.71045
Val loss: 1.7797 ACC: 0.76273                                                                                                    


Epoch 26:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6529/1.543 ACC: 0.89335                                                                                                     
drop: 0.71297
Val loss: 1.7802 ACC: 0.76553                                                                                                    


Epoch 27:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6518/1.5432 ACC: 0.89386                                                                                                    
drop: 0.71417
Val loss: 1.7809 ACC: 0.76393                                                                                                    


Epoch 28:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6502/1.5451 ACC: 0.89572                                                                                                    
drop: 0.71864
Val loss: 1.7796 ACC: 0.76553                                                                                                    


Epoch 29:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6527/1.5433 ACC: 0.89406                                                                                                    
drop: 0.71465
Val loss: 1.7766 ACC: 0.76754                                                                                                    
New Best Val loss: 1.7766                                                                                                    


Epoch 30:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6528/1.8398 ACC: 0.89381                                                                                                    
drop: 0.71405
Val loss: 1.7784 ACC: 0.76473                                                                                                    


Epoch 31:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6526/1.5431 ACC: 0.8932                                                                                                     
drop: 0.71261
Val loss: 1.7767 ACC: 0.76513                                                                                                    


Epoch 32:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6511/1.5431 ACC: 0.89516                                                                                                    
drop: 0.71731
Val loss: 1.7783 ACC: 0.76553                                                                                                    


Epoch 33:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6499/1.543 ACC: 0.89531                                                                                                     
drop: 0.71767
Val loss: 1.7776 ACC: 0.76433                                                                                                    


Epoch 34:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6491/1.5431 ACC: 0.89667                                                                                                    
drop: 0.72094
Val loss: 1.7774 ACC: 0.76473                                                                                                    


Epoch 35:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6481/1.5432 ACC: 0.89717                                                                                                    
drop: 0.72215
Val loss: 1.7781 ACC: 0.76393                                                                                                    


Epoch 36:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6501/1.5433 ACC: 0.89602                                                                                                    
drop: 0.71936
Val loss: 1.777 ACC: 0.76794                                                                                                     


Epoch 37:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6487/1.543 ACC: 0.89662                                                                                                     
drop: 0.72082
Val loss: 1.7762 ACC: 0.76954                                                                                                    
New Best Val loss: 1.7762                                                                                                    


Epoch 38:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6484/1.5432 ACC: 0.89813                                                                                                    
drop: 0.72446
Val loss: 1.7749 ACC: 0.76874                                                                                                    
New Best Val loss: 1.7749                                                                                                    


Epoch 39:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6468/1.7646 ACC: 0.89883                                                                                                    
drop: 0.72616
Val loss: 1.7779 ACC: 0.76754                                                                                                    


Epoch 40:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6475/1.543 ACC: 0.89677                                                                                                     
drop: 0.72118
Val loss: 1.7757 ACC: 0.76834                                                                                                    


Epoch 41:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6457/1.5491 ACC: 0.89928                                                                                                    
drop: 0.72726
Val loss: 1.7761 ACC: 0.76794                                                                                                    


Epoch 42:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6459/1.5706 ACC: 0.89933                                                                                                    
drop: 0.72738
Val loss: 1.7753 ACC: 0.76633                                                                                                    


Epoch 43:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6454/1.543 ACC: 0.89948                                                                                                     
drop: 0.72774
Val loss: 1.7748 ACC: 0.77074                                                                                                    


Epoch 44:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6458/1.543 ACC: 0.89923                                                                                                     
drop: 0.72713
Val loss: 1.7731 ACC: 0.76954                                                                                                    
New Best Val loss: 1.7731                                                                                                    


Epoch 45:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6461/1.543 ACC: 0.89888                                                                                                     
drop: 0.72628
Val loss: 1.7715 ACC: 0.77074                                                                                                    
New Best Val loss: 1.7715                                                                                                    


Epoch 46:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6456/1.543 ACC: 0.89933                                                                                                     
drop: 0.72738
Val loss: 1.7707 ACC: 0.77275                                                                                                    
New Best Val loss: 1.7707                                                                                                    


Epoch 47:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6457/1.5431 ACC: 0.89978                                                                                                    
drop: 0.72848
Val loss: 1.7705 ACC: 0.77275                                                                                                    
New Best Val loss: 1.7705                                                                                                    


Epoch 48:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6451/1.5522 ACC: 0.90029                                                                                                    
drop: 0.7297
Val loss: 1.7724 ACC: 0.76994                                                                                                    


Epoch 49:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6438/1.543 ACC: 0.90184                                                                                                     
drop: 0.73349
Val loss: 1.7703 ACC: 0.77275                                                                                                    
New Best Val loss: 1.7703                                                                                                    


Epoch 50:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6453/1.7218 ACC: 0.90029                                                                                                    
drop: 0.7297
Val loss: 1.7732 ACC: 0.76874                                                                                                    


Epoch 51:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6457/1.8764 ACC: 0.90029                                                                                                    
drop: 0.7297
Val loss: 1.7727 ACC: 0.77034                                                                                                    


Epoch 52:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6441/1.5859 ACC: 0.90089                                                                                                    
drop: 0.73116
Val loss: 1.775 ACC: 0.76794                                                                                                     


Epoch 53:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6445/1.8908 ACC: 0.90154                                                                                                    
drop: 0.73275
Val loss: 1.7723 ACC: 0.76834                                                                                                    


Epoch 54:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6425/1.5643 ACC: 0.9024                                                                                                     
drop: 0.73484
Val loss: 1.7719 ACC: 0.77154                                                                                                    


Epoch 55:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6426/1.5443 ACC: 0.90285                                                                                                    
drop: 0.73594
Val loss: 1.7692 ACC: 0.77355                                                                                                    
New Best Val loss: 1.7692                                                                                                    


Epoch 56:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.643/1.543 ACC: 0.9025                                                                                                       
drop: 0.73508
Val loss: 1.7712 ACC: 0.77154                                                                                                    


Epoch 57:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6429/1.543 ACC: 0.90204                                                                                                     
drop: 0.73398
Val loss: 1.7709 ACC: 0.77515                                                                                                    


Epoch 58:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6429/1.543 ACC: 0.9027                                                                                                      
drop: 0.73557
Val loss: 1.7718 ACC: 0.77034                                                                                                    


Epoch 59:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6441/1.8573 ACC: 0.90265                                                                                                    
drop: 0.73545
Val loss: 1.7724 ACC: 0.77154                                                                                                    


Epoch 60:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.642/1.543 ACC: 0.9038                                                                                                       
drop: 0.73828
Val loss: 1.7697 ACC: 0.77475                                                                                                    


Epoch 61:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6412/1.5438 ACC: 0.9035                                                                                                     
drop: 0.73754
Val loss: 1.7699 ACC: 0.77275                                                                                                    


Epoch 62:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6434/1.543 ACC: 0.90144                                                                                                     
drop: 0.73251
Val loss: 1.7723 ACC: 0.77114                                                                                                    


Epoch 63:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6452/2.2097 ACC: 0.90174                                                                                                    
drop: 0.73324
Val loss: 1.7724 ACC: 0.77114                                                                                                    


Epoch 64:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6419/1.5441 ACC: 0.90345                                                                                                    
drop: 0.73742
Val loss: 1.7704 ACC: 0.77194                                                                                                    


Epoch 65:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6408/1.5431 ACC: 0.90451                                                                                                    
drop: 0.74
Val loss: 1.7695 ACC: 0.77034                                                                                                    


Epoch 66:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6414/1.543 ACC: 0.9032                                                                                                      
drop: 0.7368
Val loss: 1.7684 ACC: 0.77595                                                                                                    
New Best Val loss: 1.7684                                                                                                    


Epoch 67:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.641/1.6873 ACC: 0.90466                                                                                                     
drop: 0.74037
Val loss: 1.7688 ACC: 0.77154                                                                                                    


Epoch 68:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6433/1.8341 ACC: 0.9031                                                                                                     
drop: 0.73656
Val loss: 1.7704 ACC: 0.77315                                                                                                    


Epoch 69:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6417/1.8764 ACC: 0.90425                                                                                                    
drop: 0.73939
Val loss: 1.7682 ACC: 0.77355                                                                                                    
New Best Val loss: 1.7682                                                                                                    


Epoch 70:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6412/1.5487 ACC: 0.90405                                                                                                    
drop: 0.7389
Val loss: 1.7674 ACC: 0.77595                                                                                                    
New Best Val loss: 1.7674                                                                                                    


Epoch 71:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6406/1.543 ACC: 0.9041                                                                                                      
drop: 0.73902
Val loss: 1.7679 ACC: 0.77595                                                                                                    


Epoch 72:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6413/1.8764 ACC: 0.90451                                                                                                    
drop: 0.74
Val loss: 1.7672 ACC: 0.77715                                                                                                    
New Best Val loss: 1.7672                                                                                                    


Epoch 73:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6416/1.543 ACC: 0.90325                                                                                                     
drop: 0.73693
Val loss: 1.7669 ACC: 0.77555                                                                                                    
New Best Val loss: 1.7669                                                                                                    


Epoch 74:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6404/1.8473 ACC: 0.90516                                                                                                    
drop: 0.74161
Val loss: 1.7689 ACC: 0.77475                                                                                                    


Epoch 75:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6407/1.8759 ACC: 0.90546                                                                                                    
drop: 0.74235
Val loss: 1.7697 ACC: 0.77194                                                                                                    


Epoch 76:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6401/1.543 ACC: 0.90486                                                                                                     
drop: 0.74087
Val loss: 1.7672 ACC: 0.77395                                                                                                    


Epoch 77:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6406/1.5453 ACC: 0.90551                                                                                                    
drop: 0.74247
Val loss: 1.7668 ACC: 0.77595                                                                                                    


Epoch 78:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6416/1.8156 ACC: 0.9037                                                                                                     
drop: 0.73803
Val loss: 1.7662 ACC: 0.77635                                                                                                    
New Best Val loss: 1.7662                                                                                                    


Epoch 79:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6396/1.5431 ACC: 0.90486                                                                                                    
drop: 0.74087
Val loss: 1.7645 ACC: 0.77836                                                                                                    
New Best Val loss: 1.7645                                                                                                    


Epoch 80:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6404/1.543 ACC: 0.90491                                                                                                     
drop: 0.74099
Val loss: 1.7691 ACC: 0.77315                                                                                                    


Epoch 81:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6391/1.5464 ACC: 0.90611                                                                                                    
drop: 0.74396
Val loss: 1.7699 ACC: 0.77114                                                                                                    


Epoch 82:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6413/1.8764 ACC: 0.90456                                                                                                    
drop: 0.74013
Val loss: 1.7715 ACC: 0.77275                                                                                                    


Epoch 83:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6397/1.5553 ACC: 0.90541                                                                                                    
drop: 0.74223
Val loss: 1.771 ACC: 0.77154                                                                                                     


Epoch 84:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6393/1.877 ACC: 0.90692                                                                                                     
drop: 0.74594
Val loss: 1.7722 ACC: 0.77154                                                                                                    


Epoch 85:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6413/1.8708 ACC: 0.9041                                                                                                     
drop: 0.73902
Val loss: 1.7764 ACC: 0.76593                                                                                                    


Epoch 86:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6396/1.543 ACC: 0.90571                                                                                                     
drop: 0.74297
Val loss: 1.7752 ACC: 0.76713                                                                                                    


Epoch 87:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6388/1.8747 ACC: 0.90662                                                                                                    
drop: 0.74519
Val loss: 1.7745 ACC: 0.76713                                                                                                    


Epoch 88:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6386/1.543 ACC: 0.90521                                                                                                     
drop: 0.74173
Val loss: 1.773 ACC: 0.76994                                                                                                     


Epoch 89:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6378/1.5431 ACC: 0.90737                                                                                                    
drop: 0.74705
Val loss: 1.7719 ACC: 0.77154                                                                                                    


Epoch 90:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6406/1.8764 ACC: 0.90471                                                                                                    
drop: 0.7405
Val loss: 1.7713 ACC: 0.77034                                                                                                    
Epoch    90: reducing learning rate of group 0 to 5.7000e-03.


Epoch 91:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6379/1.543 ACC: 0.90682                                                                                                     
drop: 0.74569
Val loss: 1.7715 ACC: 0.77114                                                                                                    


Epoch 92:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6385/1.6764 ACC: 0.90712                                                                                                    
drop: 0.74643
Val loss: 1.7689 ACC: 0.77154                                                                                                    


Epoch 93:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6388/1.5645 ACC: 0.90601                                                                                                    
drop: 0.74371
Val loss: 1.77 ACC: 0.77275                                                                                                      


Epoch 94:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6388/1.5431 ACC: 0.90566                                                                                                    
drop: 0.74284
Val loss: 1.7715 ACC: 0.77114                                                                                                    


Epoch 95:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6396/1.8764 ACC: 0.90621                                                                                                    
drop: 0.7442
Val loss: 1.7695 ACC: 0.77435                                                                                                    


Epoch 96:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6376/1.5431 ACC: 0.90677                                                                                                    
drop: 0.74557
Val loss: 1.7704 ACC: 0.77194                                                                                                    


Epoch 97:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6388/1.8839 ACC: 0.90697                                                                                                    
drop: 0.74606
Val loss: 1.7695 ACC: 0.77315                                                                                                    


Epoch 98:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6395/1.8077 ACC: 0.90616                                                                                                    
drop: 0.74408
Val loss: 1.7685 ACC: 0.77355                                                                                                    


Epoch 99:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6371/1.5431 ACC: 0.90807                                                                                                    
drop: 0.74879
Val loss: 1.7689 ACC: 0.77515                                                                                                    


Epoch 100:   0%|          | 0/22402 [00:00<?, ?it/s]

Train loss: 1.6373/1.9691 ACC: 0.90842                                                                                                    
drop: 0.74966
Val loss: 1.7704 ACC: 0.77355                                                                                                    
Best Val loss: 1.7645                                                                                                    


In [56]:
ab = copy.deepcopy(best_model).to(device)
loss_total = 0
correct_t = 0
total_t = 0
dl_test = DataLoader(list(zip(fold.X_test, y_test)), batch_size=batch_size,
                         shuffle=False, collate_fn=collate_train, num_workers=2)
for i, (terms_idx_t, docs_offsets_t, y_t) in enumerate(dl_test):
    terms_idx_t    = terms_idx_t.to( device )
    docs_offsets_t = docs_offsets_t.to( device )
    y_t            = y_t.to( device )

    pred_docs_t,weigths,coweights = ab( terms_idx_t, docs_offsets_t )
    pred_docs_t = F.softmax(pred_docs_t)

    y_pred_t    = pred_docs_t.argmax(axis=1)
    correct_t  += (y_pred_t == y_t).sum().item()
    total_t    += len(y_t)
    loss_total += loss_func_cel(pred_docs_t, y_t)

print(f'Test loss: {loss_total.item()/(i+1):.5} ACC: {correct_t/total_t:.5}', end=f"{' '*100}\r")

Test loss: 1.7705 ACC: 0.77475                                                                                                    

In [73]:
(y_pred_t == y_t).sum().item(), len(y_t), (y_pred_t == y_t).sum().item()/len(y_t)

(26, 63, 0.4126984126984127)

In [57]:
weigths[:,:,0].shape, coweights.shape, pred_docs_t.shape

(torch.Size([63, 100]), torch.Size([63, 100, 100]), torch.Size([63, 11]))

In [58]:
weigths[:,:,0]

tensor([[0.1169, 0.1110, 0.0961,  ..., 0.0000, 0.0000, 0.0000],
        [0.0085, 0.0114, 0.0101,  ..., 0.0115, 0.0092, 0.0113],
        [0.1655, 0.1046, 0.1434,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0990, 0.1927, 0.1301,  ..., 0.0000, 0.0000, 0.0000],
        [0.2272, 0.1201, 0.3264,  ..., 0.0000, 0.0000, 0.0000],
        [0.2211, 0.1816, 0.2994,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0', grad_fn=<SelectBackward>)

In [66]:
coweights

tensor([[[1.0000e+00, 9.6085e-01, 1.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 8.1252e-01, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 9.9999e-01, 1.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 1.0000e+00, 9.9994e-01,  ..., 1.0000e+00,
          0.0000e+00, 1.0000e+00],
         [0.0000e+00, 0.0000e+00, 1.0000e+00,  ..., 9.0235e-01,
          0.0000e+00, 1.0000e+00],
         ...,
         [0.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
          9.995

In [23]:
"""
acm ####################################################################################
Train loss: 1.6009/1.6089 ACC: 0.94886                                                                                                    
Val loss: 1.7718 ACC: 0.77475                                              ou eu ou a ca                                                      
New Best Val loss: 1.7718                                                                                                    
Test loss: 1.7678 ACC: 0.78557  79.92

Train loss: 1.7209/1.6095 ACC: 0.82338                                                                                                    
Val loss: 1.7595 ACC: 0.78236                                                                                                    
New Best Val loss: 1.7595                                                                                             
Test loss: 1.7585 ACC: 0.78557                                                                                                    

20ng ####################################################################################
Train loss: 2.0907/2.0787 ACC: 0.98845                                                                                                    
Val loss: 2.1869 ACC: 0.90803                                                                                                    
New Best Val loss: 2.1869                                                                                                    
Test loss: 2.178 ACC: 0.91068   92.65

Train loss: 2.1603/2.1249 ACC: 0.92511                                                                                                    
Val loss: 2.1842 ACC: 0.90645                                                                                                    
drop: 0.8436
New Best Val loss: 2.1842                                                                                                   
Test loss: 2.1705 ACC: 0.91226                                                                                                    

reut ####################################################################################
Train loss: 3.7735/3.5191 ACC: 0.74734                                                                                                    
Val loss: 3.8554 ACC: 0.6763                                                                                                    
New Best Val loss: 3.8554                                                                                                    
Test loss: 3.8493 ACC: 0.6837  72.67

Train loss: 3.7443/3.8521 ACC: 0.77727                                                                                                    
Val loss: 3.808 ACC: 0.71037                                                                                                     
New Best Val loss: 3.808
Test loss: 3.8461 ACC: 0.70444                                                                                                    

webkb ####################################################################################
Train loss: 1.2228/1.2037 ACC: 0.9504                                                                                                     
Val loss: 1.3787 ACC: 0.80316                                                                                                    
New Best Val loss: 1.3787                                                                                                    
Test loss: 1.3857 ACC: 0.78858   81.53

Epoch: 45
Train loss: 1.2392/1.2909 ACC: 0.9295                                                                                                     
Val loss: 1.3838 ACC: 0.78736                                                                                                    
New Best Val loss: 1.3838
Test loss: 1.3814 ACC: 0.78372                                                                                                    

"""

'\nacm ####################################################################################\nTrain loss: 1.6009/1.6089 ACC: 0.94886                                                                                                    \nVal loss: 1.7718 ACC: 0.77475                                                                                                    \nNew Best Val loss: 1.7718                                                                                                    \nTest loss: 1.7678 ACC: 0.78557  79.92\n\nTrain loss: 1.7209/1.6095 ACC: 0.82338                                                                                                    \nVal loss: 1.7595 ACC: 0.78236                                                                                                    \nNew Best Val loss: 1.7595                                                                                             \nTest loss: 1.7585 ACC: 0.78557                                               

In [24]:
class AttentionBag(nn.Module):
    def __init__(self, vocab_size, hiddens, nclass, drop=.5, initrange=.5):
        super(AttentionBag, self).__init__()
        self.hiddens    = hiddens
        self.mask       = Mask()
        self.dt_emb     = nn.Embedding(vocab_size, hiddens)
        self.dt_dir_map = nn.Linear(hiddens, hiddens)
        self.drop       = nn.Dropout(drop)
        self.ma_term    = nn.MultiheadAttention(hiddens, 1)
        self.fc         = nn.Linear(hiddens, nclass)
        self.initrange  = initrange 
        self.init_weights()
    def forward(self, terms_idx, docs_offsets):
        n = terms_idx.shape[0]
        batch_size = docs_offsets.shape[0]
        
        k         = [ terms_idx[ docs_offsets[i-1]:docs_offsets[i] ] for i in range(1, batch_size) ]
        k.append( terms_idx[ docs_offsets[-1]: ] )
        x_packed  = pad_sequence(k, batch_first=True, padding_value=0)

        bx_packed = x_packed == 0
        pad_mask  = bx_packed.logical_not()
        pad_mask  = pad_mask.view(*bx_packed.shape, 1)
        pad_mask  = pad_mask.logical_and(pad_mask.transpose(1, 2))
        
        dt_h      = self.dt_emb( x_packed )
        dt_h      = self.drop(dt_h)
        dir_dt_h  = self.dt_dir_map( dt_h )

        weights = torch.bmm( dt_h, dir_dt_h.transpose( 1, 2 ) )
        weights = self.mask(weights)
        
        weights_disc = (weights * pad_mask)
        weights_disc = weights_disc.sum(axis=1)
        weights_disc = F.softmax(weights_disc, dim=1)
        weights_disc = weights_disc.view( *weights_disc.shape, 1 )
        
        attn_mask = weights != 0
        attn_mask = attn_mask.logical_and( pad_mask ).logical_not()
        
        dt_h     = dt_h.transpose(0,1)
        dir_dt_h = dir_dt_h.transpose(0,1)
        docs_att, weigths_att = self.ma_term( dt_h, dir_dt_h, dt_h,
                                  key_padding_mask=bx_packed, 
                                  attn_mask=attn_mask )

        weigths_att = torch.where(torch.isnan(weigths_att), torch.zeros_like(weigths_att), weigths_att)
        weigths_att = (weigths_att * pad_mask)
        weigths_att = weigths_att.sum(axis=1)
        weigths_att = F.softmax(weigths_att, dim=1)
        weigths_att = weigths_att.view( *weigths_att.shape, 1 )
        
        weigths = weights_disc + weigths_att

        docs_att = docs_att.transpose(0,1)
        docs_att = torch.where(torch.isnan(docs_att), torch.zeros_like(docs_att), docs_att)
        
        docs_h = docs_att * weigths
        docs_h = docs_h.sum(axis=1)
        docs_h = docs_h / bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
        docs_h = torch.where(torch.isnan(docs_h), torch.zeros_like(docs_h), docs_h)
        docs_h = self.drop(docs_h)
        return self.fc(docs_h)
    
    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.dt_dir_map.weight.data.uniform_(-self.initrange, self.initrange)
        self.fc.weight.data.uniform_(-self.initrange, self.initrange)
        self.ma_term.in_proj_weight.data.uniform_(-self.initrange, self.initrange)

In [25]:
y_pred_t      = pred_docs_t.argmax(axis=1)
correct_t     = (y_pred_t == y_t).sum().item()
total_t       = len(y_t)
correct_t/total_t

0.4126984126984127

In [26]:
docs_offsets_t

tensor([   0,   10,  342,  349,  358,  375,  379,  383,  387,  391,  395,  502,
         505,  509,  519,  522,  533,  575,  579,  590,  593,  597,  601,  607,
         610,  616,  620,  624,  634,  650,  654,  661,  665,  969,  973,  983,
         985,  990,  996, 1046, 1049, 1053, 1183, 1187, 1191, 1194, 1198, 1202,
        1206, 1211, 1215, 1219, 1231, 1249, 1422, 1426, 1430, 1433, 1437, 1450,
        1454, 1460, 1464], device='cuda:0')

In [27]:
terms_idx_t[:docs_offsets_t[batch_size]]

IndexError: index 64 is out of bounds for dimension 0 with size 63

In [None]:
batch_off_test  = docs_offsets_t[batch_size]
batch_tidx_test = terms_idx_t[:batch_off_test]
h_terms_test    = sc.tt_emb( batch_tidx_test )
dirh_terms_test = sc.tt_dir_map( h_terms_test )

W = torch.matmul( h_terms_test, dirh_terms_test.T )
W = F.leaky_relu( W, negative_slope=sc.mask.negative_slope)
W = F.sigmoid(W)
W

In [None]:
k = [ batch_tidx_test[ docs_offsets_t[i-1]:docs_offsets_t[i] ] for i in range(1, batch_size) ]
k.append( batch_tidx_test[ docs_offsets_t[batch_size-1]:docs_offsets_t[batch_size] ] )
x_packed = pad_sequence(k, batch_first=True, padding_value=0)
tt_emb = sc.tt_emb( x_packed )
len(k)

In [None]:
x_packed = pad_sequence(k, batch_first=True, padding_value=0)
tt_emb = sc.tt_emb( x_packed )

In [None]:
tt_emb.transpose(0,1)

In [None]:
a = [terms_idx, terms_idx]
torch.stack(a)

In [None]:
F.softmax(pred_docs_t).argmax(axis=1)

In [None]:
y_val

In [None]:
(y_t == F.softmax(pred_docs_t).argmax(axis=1)).sum().item()/y_t.shape[0]

In [None]:
terms_idx_t, docs_offsets_t

In [None]:
shifts = sc._get_shift_(docs_offsets_t, terms_idx_t.shape[0])

In [None]:
zipado = zip(docs_offsets_t, shifts)
#next(zipado)
start,size = next(zipado)

In [None]:
w = sc.tt_emb( terms_idx_t[start:start+size] )
w1 = sc.tt_dir_map( w )
w = torch.matmul( w, w1.T )
w = F.leaky_relu( w, negative_slope=sc.negative_slope)
w = F.sigmoid(w)
#w = F.tanh(w)
#w = F.relu(w)
#w = w.mean(axis=1)
#w = F.softmax(w)
w,w1

In [None]:
w.round().sum() / (w.shape[0]*w.shape[1])

In [None]:
inv_mapper = { v:k for (k,v) in tokenizer.node_mapper.items() }

In [None]:
terms_idx[start:start+size]

In [None]:
fold.X_val[0]

In [None]:
bla = w.mean(axis=1)
bla = F.softmax(bla)
#bla = bla/torch.clamp(bla.sum(), 0.0001)
bla

In [None]:
[ (i, tid.item(),inv_mapper[tid.item()], wei.item()) for i, (tid, wei) in enumerate(zip(terms_idx[start:start+size], bla)) ]

In [None]:
w

In [None]:
w.mean(axis=1)

In [None]:
w.shape

In [None]:
norm = nn.BatchNorm1d(num_features=1).to(device)

bla2 = norm(bla.view(-1, 1)).squeeze()
bla2 = F.sigmoid(bla2)
bla2

In [None]:
1

In [None]:
ma = nn.MultiheadAttention(300, 300).to(device)
ma

In [None]:
torch.__version__

In [None]:
a,b = w.shape
w_ = w.view(a,1,b)
w1_ = w1.view(a,1,b)

attn_output = ma(w_, w1_, w_, need_weights=False)

attn_output.view(a,b)
attn_output_weights.view(a,a)

In [None]:
attn_output.view(a,b).shape

In [None]:
attn_output_weights.view(a,a)

In [None]:
F.softmax(torch.Tensor([[0.0718, 0.0716, 0.0712, 0.0714, 0.0721, 0.0710, 0.0712, 0.0714, 0.0710,
         0.0719, 0.0711, 0.0712, 0.0718, 0.0714],
        [0.0722, 0.0709, 0.0710, 0.0709, 0.0728, 0.0723, 0.0709, 0.0709, 0.0719,
         0.0712, 0.0709, 0.0707, 0.0725, 0.0707],
        [0.0711, 0.0715, 0.0710, 0.0713, 0.0710, 0.0712, 0.0718, 0.0713, 0.0709,
         0.0725, 0.0716, 0.0719, 0.0710, 0.0719],
        [0.0721, 0.0717, 0.0714, 0.0713, 0.0717, 0.0714, 0.0711, 0.0710, 0.0713,
         0.0717, 0.0710, 0.0712, 0.0720, 0.0711],
        [0.0722, 0.0711, 0.0710, 0.0710, 0.0737, 0.0716, 0.0709, 0.0705, 0.0714,
         0.0719, 0.0707, 0.0707, 0.0731, 0.0702],
        [0.0714, 0.0716, 0.0708, 0.0711, 0.0718, 0.0713, 0.0712, 0.0717, 0.0712,
         0.0724, 0.0713, 0.0712, 0.0716, 0.0714],
        [0.0712, 0.0714, 0.0713, 0.0713, 0.0722, 0.0716, 0.0709, 0.0714, 0.0714,
         0.0717, 0.0713, 0.0713, 0.0716, 0.0713],
        [0.0716, 0.0707, 0.0712, 0.0714, 0.0717, 0.0715, 0.0714, 0.0715, 0.0712,
         0.0721, 0.0711, 0.0714, 0.0717, 0.0715],
        [0.0717, 0.0713, 0.0708, 0.0710, 0.0725, 0.0717, 0.0714, 0.0709, 0.0712,
         0.0712, 0.0712, 0.0717, 0.0720, 0.0714],
        [0.0709, 0.0712, 0.0713, 0.0712, 0.0713, 0.0713, 0.0714, 0.0719, 0.0712,
         0.0724, 0.0715, 0.0715, 0.0717, 0.0713],
        [0.0715, 0.0716, 0.0712, 0.0708, 0.0725, 0.0717, 0.0712, 0.0714, 0.0714,
         0.0717, 0.0713, 0.0706, 0.0718, 0.0712],
        [0.0726, 0.0711, 0.0713, 0.0706, 0.0737, 0.0721, 0.0709, 0.0707, 0.0714,
         0.0708, 0.0706, 0.0705, 0.0730, 0.0707],
        [0.0718, 0.0711, 0.0714, 0.0715, 0.0725, 0.0707, 0.0707, 0.0711, 0.0712,
         0.0728, 0.0711, 0.0713, 0.0719, 0.0709],
        [0.0712, 0.0716, 0.0713, 0.0712, 0.0717, 0.0707, 0.0708, 0.0721, 0.0709,
         0.0728, 0.0713, 0.0716, 0.0711, 0.0715]]).sum(axis=0))

In [None]:
class NotTooSimpleClassifier(nn.Module):
    def __init__(self, vocab_size, hidden_l, nclass, dropout1=0.1, dropout2=0.1, negative_slope=99,
                 initrange = 0.5, scale_grad_by_freq=False, device='cuda:0'):
        super(NotTooSimpleClassifier, self).__init__()
        
        self.dt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        self.tt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        
        self.undirected_map = nn.Linear(hidden_l, hidden_l)
        
        self.fc = nn.Linear(hidden_l, nclass)
        self.drop1 = nn.Dropout(dropout1)
        self.drop2 = nn.Dropout(dropout2)
        
        self.norm = nn.BatchNorm1d(1)
        
        self.initrange = initrange
        self.nclass = nclass
        self.negative_slope = negative_slope
        
        self.init_weights()
        
        #self.labls_emb = nn.Embedding(graph_builder.n_class, 300)
    
    def forward(self, terms_idxs, docs_offsets):
        n = terms_idxs.shape[0]
        weights = []
        shifts = self._get_shift_(docs_offsets, n)
        
        terms_h1 = self.tt_emb(terms_idxs)
        terms_h1 = self.drop1(terms_h1)
        
        terms_h2 = self.undirected_map( terms_h1 )
        #terms_h2 = self.drop1( terms_h2 )
        for start,size in zip(docs_offsets, shifts):
            w  = terms_h1[start:start+size]
            w1 = terms_h2[start:start+size]
            w = torch.matmul( w, w1.T )
            w = F.leaky_relu( w, negative_slope=self.negative_slope)
            w = F.sigmoid(w-5.5)
            w = w.mean(axis=1)
            w = F.softmax(w)
            #w = w / torch.clamp(w.sum(), 0.0001)
            weights.append( w )
        
        weights = torch.cat(weights)
        #weights = self.norm(weights.view(-1, 1)).squeeze()
        #weights = F.sigmoid(weights)
        
        h_docs  = F.embedding_bag(self.dt_emb.weight, terms_idxs, docs_offsets, per_sample_weights=weights, mode='sum')
        h_docs = self.drop2( h_docs )
        pred_docs = self.fc( h_docs )
        return pred_docs

    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        
    def _get_shift_(self, offsets, lenght):
        shifts = offsets[1:] - offsets[:-1]
        last = torch.LongTensor([lenght - offsets[-1]]).to( offsets.device )
        return torch.cat([shifts, last])