In [1]:
import warnings
warnings.filterwarnings('ignore')

from TGA.utils import Dataset

from tqdm.notebook import tqdm
from TGA.utils import preprocessor
import copy

from time import time
import numpy as np
from itertools import repeat
from collections import Counter
from segtok import tokenizer as tk

from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction import stop_words
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.base import BaseEstimator, TransformerMixin

In [2]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

In [4]:
torch.__version__

'1.7.1'

In [5]:
dataset = Dataset('/home/Documentos/datasets/classification/datasets/reut/')
fold = next(dataset.get_fold_instances(10, with_val=True))
fold._fields, len(fold.X_train)

(('X_train', 'y_train', 'X_test', 'y_test', 'X_val', 'y_val'), 10627)

In [6]:
class Tokenizer(BaseEstimator, TransformerMixin):
    def __init__(self, mindf=2, stopwords='remove', model='list', lan='english', verbose=False):
        super(Tokenizer, self).__init__()
        self.mindf = mindf
        self.le = LabelEncoder()
        self.verbose = verbose
        self.stopwords = stopwords
        self.stopwordsSet = set(stop_words.ENGLISH_STOP_WORDS)
        self.lan = lan
        self.model = model
        self.analyzer = TfidfVectorizer(preprocessor=preprocessor).build_analyzer()
        #self.analyzer = tk.web_tokenizer
    
    def fit(self, X, y):
        self.N = len(X)
        self.le.fit( y )
        self.n_class = len(self.le.classes_)

        self.term_freqs = Counter()
        docs = map(self.analyzer, X)
        for doc_in_terms in tqdm(docs, total=self.N, disable=not self.verbose):
            doc_in_terms = list(map( self._filter_fit_, doc_in_terms ))
            self.term_freqs.update(list(set(doc_in_terms)))
        self.node_mapper      = {'<BLANK>': 0}
        self.term_freqs       = { term:v for (term,v) in self.term_freqs.items() if v >= self.mindf }    
        self.node_mapper      = { term:self._get_idx_(term) for term in self.term_freqs.keys() if self._isrel_(term) }
        self.node_mapper['<UNK>'] = len(self.node_mapper)
        self.node_mapper['<BLANK>'] = 0
        self.vocab_size = len(self.node_mapper)
        
        return self
    def _isrel_(self, term):
        if self.stopwords == 'remove' and term in self.stopwordsSet:
            return False
        # put here your filter_functions
        return True
    def _get_idx_(self, term):
        # put here your idx_set_functions
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            print('is stop', term)
            return self.node_mapper.setdefault('<STPW>', len(self.node_mapper))
        return self.node_mapper.setdefault(term, len(self.node_mapper))
    def _filter_transform_(self, term):
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return '<STPW>'
        if term not in self.node_mapper:
            return '<UNK>'
        return term
    def _filter_fit_(self, term):
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return '<STPW>'
        return term
    def _model_(self, doc):
        if self.model == 'set':
            return set(doc)
        return list(doc)
    def transform(self, X, verbose=None):
        verbose = verbose if verbose is not None else self.verbose
        n = len(X)
        doc_off = [0]
        terms_idx = []
        for i,doc_in_terms in tqdm(enumerate(map(self.analyzer, X)), total=n, disable=not verbose):
            doc_in_terms = filter( self._isrel_, doc_in_terms )
            doc_in_terms = map( self._filter_transform_, doc_in_terms )
            doc_in_terms = self._model_(doc_in_terms)
            doc_in_terms = [ self.node_mapper[tid] for tid in doc_in_terms ]
            if self.model == 'sorted':
                doc_in_terms = sorted(doc_in_terms)
            doc_off.append( len(doc_in_terms) )
            terms_idx.extend( doc_in_terms )
        return np.array( terms_idx ), np.array(doc_off)[:-1].cumsum()

In [7]:
tokenizer = Tokenizer(mindf=1, stopwords='keep', model='set', verbose=True)
tokenizer.fit(fold.X_train, fold.y_train)

HBox(children=(FloatProgress(value=0.0, max=10627.0), HTML(value='')))




Tokenizer(mindf=1, model='set', stopwords='keep', verbose=True)

In [8]:
y_train = tokenizer.le.transform( fold.y_train )
y_val   = tokenizer.le.transform( fold.y_val )
y_test  = tokenizer.le.transform( fold.y_test )

In [9]:
tokenizer.transform(fold.X_val[:2])

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




(array([  568, 13239,    73,    75,  1550,     7,    83,   692,  2783,
          579,   300,  3358,  1914,  2871,   967,   100,  1813,   826,
           23,   102,   696,  1558,   108,  1715,   110,   111,   590,
          346,   115,  1975,   116,   594,   117,   118,  1661,  1164,
          837,  3559,   122,  6624,    40,  3984,   842,   600,  1273,
            9,    11, 17510,  5174,  1569,  2331,    14,  2301,   142,
         1812, 10220,  2303,  6054,  6783,  1123,   848,   150,  2393,
         1217,  1992,   155,  2895,   157,  3565,  1574,  4815,  2797,
         5635,  3345,  3191,   386,  4418,  3755,  1223,  2176,   172,
         1513,  5771, 25857, 18131,  2240,   183,   186,   404, 16547,
         2245,  4605,   869,  3460,    20,   929,  1447,   875,   316,
         1585,  6548,    33,   413,   207,  3390,  2314,  3362,    38,
         1473,  1331,   218,   219,  5648,  1333,  3273,  2258,     1,
         2068,   231,   533,  1242,  4365,   661,   302,   664,   240,
      

In [10]:
tokenizer.node_mapper['<BLANK>'], fold.X_val[0]

(0,
 '\r cocoa consumers narrow gap on buffer stock issue\r london march 16 representatives of cocoa consuming\r countries at an international cocoa organization icco council\r meeting here have edged closer to a unified stance on buffer\r stock rules delegates said \r while consumers do not yet have a common position an\r observer said after a consumer meeting they are much more\r fluid and the tone is positive \r european community consumers were split on the question of\r how the cocoa buffer stock should be operated when the icco met\r in january to put the new international cocoa agreement into\r effect delegates said \r at the january meeting france sided with producers on how\r the buffer stock should operate delegates said that meeting\r ended without agreement on new buffer stock rules \r the ec commission met in brussels on friday to see whether\r the 12 ec cocoa consuming nations could narrow their\r differences at this month s meeting \r the commissioners came away from the

In [11]:
def collate_train(param):
    X, y = zip(*param)
    terms_ids, docs_offsets = tokenizer.transform(X, verbose=False)
    return torch.LongTensor(terms_ids), torch.LongTensor(docs_offsets), torch.LongTensor(y)

In [12]:
class Mask(nn.Module):
    def __init__(self, negative_slope=1000, kappa=2.):
        super(Mask, self).__init__()
        self.negative_slope = negative_slope
        self.kappa = kappa
        self.sig = nn.Sigmoid()
    def forward(self, h):
        w = F.leaky_relu( h, negative_slope=self.negative_slope)
        w = self.sig(w-self.kappa)
        return w

In [13]:
help(torch.embedding_bag)

Help on built-in function embedding_bag:

embedding_bag(...)



In [14]:
class SimpleAttentionBag(nn.Module):
    def __init__(self, vocab_size, hiddens, nclass, drop=.5, initrange=.5, negative_slope=99.):
        super(SimpleAttentionBag, self).__init__()
        self.hiddens    = hiddens
        self.dt_emb     = nn.Embedding(vocab_size, hiddens)
        self.tt_emb     = nn.Embedding(vocab_size, hiddens)
        self.tt_dir_map = nn.Linear(hiddens, hiddens)
        self.fc         = nn.Linear(hiddens, nclass)
        self.initrange  = initrange 
        self.negative_slope = negative_slope
        self.drop       = nn.Dropout(drop)
        self.self.drop_ = drop
        self.sig = nn.Sigmoid()
        self.init_weights()
    def forward(self, terms_idx, docs_offsets, return_mask=False):
        n = terms_idx.shape[0]
        batch_size = docs_offsets.shape[0]
        
        k         = [ terms_idx[ docs_offsets[i-1]:docs_offsets[i] ] for i in range(1, batch_size) ]
        k.append( terms_idx[ docs_offsets[-1]: ] )
        x_packed  = pad_sequence(k, batch_first=True, padding_value=0)

        bx_packed = x_packed == 0
        doc_sizes = bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
        pad_mask  = bx_packed.logical_not()
        pad_mask  = pad_mask.view(*bx_packed.shape, 1)
        pad_mask  = pad_mask.logical_and(pad_mask.transpose(1, 2))
        
        dt_h     = self.dt_emb( x_packed )
        dt_h     = F.dropout( dt_h, p=self.drop_, training=self.training )
        #dt_h     = self.drop(dt_h)
        
        tt_h     = self.tt_emb( x_packed )
        tt_h     = F.dropout( tt_h, p=self.drop_, training=self.training )
        #tt_h     = self.drop(tt_h)
        dir_tt_h = self.tt_dir_map( tt_h )

        weights = torch.bmm( tt_h, dir_tt_h.transpose( 1, 2 ) )
        weights = F.leaky_relu( weights, negative_slope=self.negative_slope)
        
        weights[pad_mask.logical_not()] = float('-inf') # Set the 3D-pad mask values to -inf (=0 in softmax)
        weights = F.sigmoid(weights)
        #weights = F.softmax(weights, dim=2) # Normalize the neighbors weights
        #weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights) #replace nan to zero
        weights = weights.sum(axis=2) / doc_sizes
        weights[bx_packed] = float('-inf') # Set the 2D-pad mask values to -inf  (=0 in softmax)
        weights = F.softmax(weights, dim=1)
        #weights = F.sigmoid(weights)
        weights = weights.view( *weights.shape, 1 )
        
        docs_h = dt_h * weights
        docs_h = docs_h.sum(axis=1)
        docs_h = F.dropout( docs_h, p=self.drop_, training=self.training )
        #docs_h = self.drop(docs_h)
        return self.fc(docs_h), weights
    
    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_dir_map.weight.data.uniform_(-self.initrange, self.initrange)
        self.fc.weight.data.uniform_(-self.initrange, self.initrange)

In [20]:
nepochs = 1000
max_epochs = 50
drop=0.95
device = torch.device('cuda:0')
batch_size = 32

In [21]:
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7fa4d8b86270>

In [22]:
#sc = SimpleClassifier(tokenizer.vocab_size, 300, tokenizer.n_class, dropout=drop).to( device )
ab = SimpleAttentionBag(tokenizer.vocab_size, 300, tokenizer.n_class, drop=drop).to( device )

optimizer = optim.AdamW( ab.parameters(), lr=6e-3, weight_decay=5e-3)
loss_func_cel = nn.CrossEntropyLoss().to( device )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=.95,
                                                       patience=10, verbose=True)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=.98, verbose=True)

In [23]:
best = 99999.
counter = 1
dl_val = DataLoader(list(zip(fold.X_val, y_val)), batch_size=batch_size,
                         shuffle=False, collate_fn=collate_train, num_workers=4)
for e in tqdm(range(nepochs), total=nepochs):
    dl_train = DataLoader(list(zip(fold.X_train, y_train)), batch_size=batch_size,
                             shuffle=True, collate_fn=collate_train, num_workers=4)
    total_loss  = 0.
    with tqdm(total=len(y_train)+len(y_val), smoothing=0., desc=f"Epoch {e+1}") as pbar:
        total = 0
        correct  = 0
        ab.train()
        for i, (terms_idx, docs_offsets, y) in enumerate(dl_train):
            terms_idx    = terms_idx.to( device )
            docs_offsets = docs_offsets.to( device )
            y            = y.to( device )
            
            pred_docs,_ = ab( terms_idx, docs_offsets)
            pred_docs = F.softmax(pred_docs)
            loss = loss_func_cel(pred_docs, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total      += len(y)
            y_pred      = pred_docs.argmax(axis=1)
            correct    += (y_pred == y).sum().item()
            
            toprint  = f"Train loss: {total_loss/(i+1):.5}/{loss.item():.5} "
            toprint += f'ACC: {correct/total:.5}'
            
            print(toprint, end=f"{' '*100}\r")
            
            pbar.update( len(y) )
            del pred_docs, loss
            del terms_idx, docs_offsets, y
            del y_pred
            
        total = 0
        correct  = 0
        ab.eval()
        with torch.no_grad():
            loss_total = 0.
            print()
            for i, (terms_idx, docs_offsets, y) in enumerate(dl_val):
                terms_idx    = terms_idx.to( device )
                docs_offsets = docs_offsets.to( device )
                y            = y.to( device )

                pred_docs,weights = ab( terms_idx, docs_offsets )
                pred_docs   = F.softmax(pred_docs)

                y_pred      = pred_docs.argmax(axis=1)
                correct    += (y_pred == y).sum().item()
                total      += len(y)
                loss2       = loss_func_cel(pred_docs, y)
                loss_total += loss2

                print(f'Val loss: {loss_total.item()/(i+1):.5} ACC: {correct/total:.5}', end=f"{' '*100}\r")
   
                pbar.update( len(y) )

            del terms_idx, docs_offsets, y
            del y_pred
            scheduler.step(loss_total)
            print()
            #scheduler.step()

            if best-(loss_total/(i+1)) > 0.0001 :
                best = (loss_total/(i+1)).item()
                counter = 1
                print()
                print(f'New Best Val loss: {best:.5}', end=f"{' '*100}\n")
                best_model = copy.deepcopy(ab).to('cpu')
            elif counter > max_epochs:
                print()
                print(f'Best Val loss: {best:.5}', end=f"{' '*100}\n")
                break
            else:
                counter += 1
            del pred_docs, loss2

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 4.4438/4.5093 ACC: 0.08102                                                                                                      
Val loss: 4.2002 ACC: 0.32519                                                                                                    

New Best Val loss: 4.2002                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 4.2832/4.2443 ACC: 0.24795                                                                                                    
Val loss: 4.1204 ACC: 0.42444                                                                                                    

New Best Val loss: 4.1204                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 4.2123/4.1839 ACC: 0.31411                                                                                                    
Val loss: 4.085 ACC: 0.44296                                                                                                                                                                                                         

New Best Val loss: 4.085                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 4.1823/3.9524 ACC: 0.34092                                                                                                    
Val loss: 4.0717 ACC: 0.44889                                                                                                    

New Best Val loss: 4.0717                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 4.1717/4.2971 ACC: 0.35099                                                                                                    
Val loss: 4.0711 ACC: 0.44593                                                                                                                                                                                                        

New Best Val loss: 4.0711                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 4.1554/4.514 ACC: 0.36652                                                                                                     
Val loss: 4.0694 ACC: 0.44593                                                                                                    

New Best Val loss: 4.0694                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 4.1417/3.8801 ACC: 0.37828                                                                                                    
Val loss: 4.064 ACC: 0.45407                                                                                                     

New Best Val loss: 4.064                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 4.1402/4.1465 ACC: 0.3812                                                                                                     
Val loss: 4.06 ACC: 0.4563                                                                                                       

New Best Val loss: 4.06                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=11977.0, style=ProgressStyle(description_wi…

Train loss: 4.1312/3.7701 ACC: 0.38844                                                                                                    
Val loss: 4.0554 ACC: 0.46296                                                                                                    

New Best Val loss: 4.0554                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.1266/4.1842 ACC: 0.39343                                                                                                    
Val loss: 4.0502 ACC: 0.47037                                                                                                    

New Best Val loss: 4.0502                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 11', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.1233/3.8525 ACC: 0.39597                                                                                                    
Val loss: 4.0483 ACC: 0.47111                                                                                                    

New Best Val loss: 4.0483                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 12', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.118/4.5186 ACC: 0.40303                                                                                                     
Val loss: 4.0466 ACC: 0.47185                                                                                                    

New Best Val loss: 4.0466                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 13', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.1092/3.8513 ACC: 0.40924                                                                                                    
Val loss: 4.0386 ACC: 0.4837                                                                                                     

New Best Val loss: 4.0386                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 14', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.1125/4.1852 ACC: 0.40689                                                                                                    
Val loss: 4.0307 ACC: 0.48963                                                                                                    

New Best Val loss: 4.0307                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 15', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.1061/3.8498 ACC: 0.41188                                                                                                    
Val loss: 4.0298 ACC: 0.48889                                                                                                    

New Best Val loss: 4.0298                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 16', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.1026/3.852 ACC: 0.41583                                                                                                     
Val loss: 4.0309 ACC: 0.48963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 17', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.1057/4.5186 ACC: 0.41545                                                                                                    
Val loss: 4.0296 ACC: 0.48815                                                                                                    

New Best Val loss: 4.0296                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 18', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0962/3.5187 ACC: 0.42166                                                                                                    
Val loss: 4.0273 ACC: 0.49259                                                                                                    

New Best Val loss: 4.0273                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 19', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.1017/4.2025 ACC: 0.41715                                                                                                    
Val loss: 4.0262 ACC: 0.49333                                                                                                    

New Best Val loss: 4.0262                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 20', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0973/4.1854 ACC: 0.42213                                                                                                    
Val loss: 4.0248 ACC: 0.49333                                                                                                    

New Best Val loss: 4.0248                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 21', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0962/3.9235 ACC: 0.42279                                                                                                    
Val loss: 4.024 ACC: 0.49333                                                                                                     

New Best Val loss: 4.024                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 22', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0934/3.852 ACC: 0.42543                                                                                                     
Val loss: 4.0227 ACC: 0.49556                                                                                                    

New Best Val loss: 4.0227                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 23', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0884/3.852 ACC: 0.43041                                                                                                     
Val loss: 4.023 ACC: 0.49481                                                                                                     



HBox(children=(FloatProgress(value=0.0, description='Epoch 24', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0912/4.1851 ACC: 0.42919                                                                                                    
Val loss: 4.0207 ACC: 0.49926                                                                                                    

New Best Val loss: 4.0207                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 25', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0854/3.8563 ACC: 0.4322                                                                                                     
Val loss: 4.018 ACC: 0.50296                                                                                                     

New Best Val loss: 4.018                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 26', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0878/3.8518 ACC: 0.42994                                                                                                    
Val loss: 4.015 ACC: 0.50444                                                                                                     

New Best Val loss: 4.015                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 27', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0874/4.1781 ACC: 0.43117                                                                                                    
Val loss: 4.0148 ACC: 0.50593                                                                                                    

New Best Val loss: 4.0148                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 28', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0849/3.5202 ACC: 0.4323                                                                                                     
Val loss: 4.0155 ACC: 0.50444                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 29', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0841/3.8751 ACC: 0.43484                                                                                                    
Val loss: 4.0147 ACC: 0.50519                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 30', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0833/4.1835 ACC: 0.43578                                                                                                    
Val loss: 4.014 ACC: 0.50519                                                                                                     

New Best Val loss: 4.014                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 31', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.083/4.1852 ACC: 0.43549                                                                                                     
Val loss: 4.0143 ACC: 0.50519                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 32', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0877/4.1828 ACC: 0.43192                                                                                                    
Val loss: 4.0127 ACC: 0.50667                                                                                                    

New Best Val loss: 4.0127                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 33', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0794/4.1844 ACC: 0.4402                                                                                                     
Val loss: 4.0133 ACC: 0.50519                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 34', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0768/3.8504 ACC: 0.44189                                                                                                    
Val loss: 4.0131 ACC: 0.50519                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 35', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0751/4.1851 ACC: 0.44434                                                                                                    
Val loss: 4.0125 ACC: 0.50593                                                                                                    

New Best Val loss: 4.0125                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 36', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.077/3.852 ACC: 0.44114                                                                                                      
Val loss: 4.0127 ACC: 0.50519                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 37', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0785/4.0756 ACC: 0.44133                                                                                                    
Val loss: 4.0108 ACC: 0.50815                                                                                                    

New Best Val loss: 4.0108                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 38', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.078/3.853 ACC: 0.44011                                                                                                      
Val loss: 4.0105 ACC: 0.50815                                                                                                    

New Best Val loss: 4.0105                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 39', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0743/4.5187 ACC: 0.44632                                                                                                    
Val loss: 4.0099 ACC: 0.50963                                                                                                    

New Best Val loss: 4.0099                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 40', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0715/4.1852 ACC: 0.44754                                                                                                    
Val loss: 4.0098 ACC: 0.50963                                                                                                    

New Best Val loss: 4.0098                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 41', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0778/4.5176 ACC: 0.44255                                                                                                    
Val loss: 4.0099 ACC: 0.50963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 42', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0768/4.5176 ACC: 0.44312                                                                                                    
Val loss: 4.009 ACC: 0.50963                                                                                                     

New Best Val loss: 4.009                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 43', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0735/4.1854 ACC: 0.44603                                                                                                    
Val loss: 4.0091 ACC: 0.50963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 44', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0681/3.8498 ACC: 0.45074                                                                                                    
Val loss: 4.0095 ACC: 0.50889                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 45', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0692/3.8519 ACC: 0.4481                                                                                                     
Val loss: 4.0086 ACC: 0.50963                                                                                                    

New Best Val loss: 4.0086                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 46', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0687/4.5161 ACC: 0.4514                                                                                                     
Val loss: 4.0086 ACC: 0.50963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 47', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0678/4.4607 ACC: 0.45177                                                                                                    
Val loss: 4.0088 ACC: 0.50963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 48', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0678/3.852 ACC: 0.45036                                                                                                     
Val loss: 4.0089 ACC: 0.50963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 49', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0687/4.5187 ACC: 0.45149                                                                                                    
Val loss: 4.0093 ACC: 0.50963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 50', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0699/4.5165 ACC: 0.45008                                                                                                    
Val loss: 4.0083 ACC: 0.50963                                                                                                    

New Best Val loss: 4.0083                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 51', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0692/3.85 ACC: 0.44942                                                                                                      
Val loss: 4.0094 ACC: 0.50963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 52', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0651/4.5187 ACC: 0.45516                                                                                                    
Val loss: 4.0072 ACC: 0.51111                                                                                                    

New Best Val loss: 4.0072                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 53', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0693/3.8507 ACC: 0.44886                                                                                                    
Val loss: 4.0082 ACC: 0.50963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 54', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0672/4.1847 ACC: 0.45243                                                                                                    
Val loss: 4.0077 ACC: 0.50963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 55', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0641/3.5197 ACC: 0.45319                                                                                                    
Val loss: 4.0074 ACC: 0.50963                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 56', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0667/4.1854 ACC: 0.45206                                                                                                    
Val loss: 4.0055 ACC: 0.51333                                                                                                    

New Best Val loss: 4.0055                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 57', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.068/4.1837 ACC: 0.45149                                                                                                     
Val loss: 4.0046 ACC: 0.51407                                                                                                    

New Best Val loss: 4.0046                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 58', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0637/3.5187 ACC: 0.45356                                                                                                    
Val loss: 4.0047 ACC: 0.51407                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 59', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.062/4.1854 ACC: 0.45704                                                                                                     
Val loss: 4.0047 ACC: 0.51481                                                                                                    



HBox(children=(FloatProgress(value=0.0, description='Epoch 60', max=11977.0, style=ProgressStyle(description_w…

Train loss: 4.0677/4.0133 ACC: 0.45062                                                                                                    

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Train loss: 4.0673/4.0168 ACC: 0.45099                                                                                                    Train loss: 4.0681/4.1744 ACC: 0.45019                                                                                                    

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3417, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-23-3a7ec1ab9dd0>", line 23, in <module>
    loss.backward()
  File "/usr/local/lib/python3.8/dist-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local

TypeError: object of type 'NoneType' has no len()

In [19]:
ab = copy.deepcopy(best_model).to(device)
loss_total = 0
correct_t = 0
total_t = 0
dl_test = DataLoader(list(zip(fold.X_test, y_test)), batch_size=batch_size,
                         shuffle=False, collate_fn=collate_train, num_workers=2)
for i, (terms_idx_t, docs_offsets_t, y_t) in enumerate(dl_test):
    terms_idx_t    = terms_idx_t.to( device )
    docs_offsets_t = docs_offsets_t.to( device )
    y_t            = y_t.to( device )

    pred_docs_t,weigths = ab( terms_idx_t, docs_offsets_t )
    pred_docs_t = F.softmax(pred_docs_t)

    y_pred_t    = pred_docs_t.argmax(axis=1)
    correct_t  += (y_pred_t == y_t).sum().item()
    total_t    += len(y_t)
    loss_total += loss_func_cel(pred_docs_t, y_t)

print(f'Test loss: {loss_total.item()/(i+1):.5} ACC: {correct_t/total_t:.5}', end=f"{' '*100}\r")

Test loss: 3.8242 ACC: 0.70963                                                                                                    

In [20]:
"""
acm ####################################################################################
Train loss: 1.6009/1.6089 ACC: 0.94886                                                                                                    
Val loss: 1.7718 ACC: 0.77475                                                                                                    
New Best Val loss: 1.7718                                                                                                    
Test loss: 1.7678 ACC: 0.78557  79.92

Train loss: 1.6915/1.5888 ACC: 0.85292                                                                                                    
Val loss: 1.7657 ACC: 0.78076                                                                                                    
Adjusting learning rate of group 0 to 5.5342e-03.
New Best Val loss: 1.7657
Test loss: 1.7627 ACC: 0.78196                                                                                                    


20ng ####################################################################################
Train loss: 2.0907/2.0787 ACC: 0.98845                                                                                                    
Val loss: 2.1869 ACC: 0.90803                                                                                                    
New Best Val loss: 2.1869                                                                                                    
Test loss: 2.178 ACC: 0.91068   92.65

reut ####################################################################################
Train loss: 3.7735/3.5191 ACC: 0.74734                                                                                                    
Val loss: 3.8554 ACC: 0.6763                                                                                                    
New Best Val loss: 3.8554                                                                                                    
Test loss: 3.8493 ACC: 0.6837  72.67

Train loss: 3.7489/3.8717 ACC: 0.77265                                                                                                    
Val loss: 3.8084 ACC: 0.71037                                                                                                    
Adjusting learning rate of group 0 to 5.1046e-03.
New Best Val loss: 3.8084 

webkb ####################################################################################
Train loss: 1.2228/1.2037 ACC: 0.9504                                                                                                     
Val loss: 1.3787 ACC: 0.80316                                                                                                    
New Best Val loss: 1.3787                                                                                                    
Test loss: 1.3857 ACC: 0.78858   81.53

"""

'\nacm ####################################################################################\nTrain loss: 1.6009/1.6089 ACC: 0.94886                                                                                                    \nVal loss: 1.7718 ACC: 0.77475                                                                                                    \nNew Best Val loss: 1.7718                                                                                                    \nTest loss: 1.7678 ACC: 0.78557  79.92\n\nTrain loss: 1.6915/1.5888 ACC: 0.85292                                                                                                    \nVal loss: 1.7657 ACC: 0.78076                                                                                                    \nAdjusting learning rate of group 0 to 5.5342e-03.\nNew Best Val loss: 1.7657\nTest loss: 1.7627 ACC: 0.78196                                                                                         

In [21]:
class AttentionBag(nn.Module):
    def __init__(self, vocab_size, hiddens, nclass, drop=.5, initrange=.5):
        super(AttentionBag, self).__init__()
        self.hiddens    = hiddens
        self.mask       = Mask()
        self.dt_emb     = nn.Embedding(vocab_size, hiddens)
        self.dt_dir_map = nn.Linear(hiddens, hiddens)
        self.drop       = nn.Dropout(drop)
        self.ma_term    = nn.MultiheadAttention(hiddens, 1)
        self.fc         = nn.Linear(hiddens, nclass)
        self.initrange  = initrange 
        self.init_weights()
    def forward(self, terms_idx, docs_offsets):
        n = terms_idx.shape[0]
        batch_size = docs_offsets.shape[0]
        
        k         = [ terms_idx[ docs_offsets[i-1]:docs_offsets[i] ] for i in range(1, batch_size) ]
        k.append( terms_idx[ docs_offsets[-1]: ] )
        x_packed  = pad_sequence(k, batch_first=True, padding_value=0)

        bx_packed = x_packed == 0
        pad_mask  = bx_packed.logical_not()
        pad_mask  = pad_mask.view(*bx_packed.shape, 1)
        pad_mask  = pad_mask.logical_and(pad_mask.transpose(1, 2))
        
        dt_h      = self.dt_emb( x_packed )
        dt_h      = self.drop(dt_h)
        dir_dt_h  = self.dt_dir_map( dt_h )

        weights = torch.bmm( dt_h, dir_dt_h.transpose( 1, 2 ) )
        weights = self.mask(weights)
        
        weights_disc = (weights * pad_mask)
        weights_disc = weights_disc.sum(axis=1)
        weights_disc = F.softmax(weights_disc, dim=1)
        weights_disc = weights_disc.view( *weights_disc.shape, 1 )
        
        attn_mask = weights != 0
        attn_mask = attn_mask.logical_and( pad_mask ).logical_not()
        
        dt_h     = dt_h.transpose(0,1)
        dir_dt_h = dir_dt_h.transpose(0,1)
        docs_att, weigths_att = self.ma_term( dt_h, dir_dt_h, dt_h,
                                  key_padding_mask=bx_packed, 
                                  attn_mask=attn_mask )

        weigths_att = torch.where(torch.isnan(weigths_att), torch.zeros_like(weigths_att), weigths_att)
        weigths_att = (weigths_att * pad_mask)
        weigths_att = weigths_att.sum(axis=1)
        weigths_att = F.softmax(weigths_att, dim=1)
        weigths_att = weigths_att.view( *weigths_att.shape, 1 )
        
        weigths = weights_disc + weigths_att

        docs_att = docs_att.transpose(0,1)
        docs_att = torch.where(torch.isnan(docs_att), torch.zeros_like(docs_att), docs_att)
        
        docs_h = docs_att * weigths
        docs_h = docs_h.sum(axis=1)
        docs_h = docs_h / bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
        docs_h = torch.where(torch.isnan(docs_h), torch.zeros_like(docs_h), docs_h)
        docs_h = self.drop(docs_h)
        return self.fc(docs_h)
    
    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.dt_dir_map.weight.data.uniform_(-self.initrange, self.initrange)
        self.fc.weight.data.uniform_(-self.initrange, self.initrange)
        self.ma_term.in_proj_weight.data.uniform_(-self.initrange, self.initrange)

In [22]:
y_pred_t      = pred_docs_t.argmax(axis=1)
correct_t     = (y_pred_t == y_t).sum().item()
total_t       = len(y_t)
correct_t/total_t

0.0

In [23]:
docs_offsets_t

tensor([  0, 331, 365, 522, 832, 974], device='cuda:0')

In [24]:
terms_idx_t[:docs_offsets_t[batch_size]]

IndexError: index 32 is out of bounds for dimension 0 with size 6

In [None]:
batch_off_test  = docs_offsets_t[batch_size]
batch_tidx_test = terms_idx_t[:batch_off_test]
h_terms_test    = sc.tt_emb( batch_tidx_test )
dirh_terms_test = sc.tt_dir_map( h_terms_test )

W = torch.matmul( h_terms_test, dirh_terms_test.T )
W = F.leaky_relu( W, negative_slope=sc.mask.negative_slope)
W = F.sigmoid(W)
W

In [None]:
k = [ batch_tidx_test[ docs_offsets_t[i-1]:docs_offsets_t[i] ] for i in range(1, batch_size) ]
k.append( batch_tidx_test[ docs_offsets_t[batch_size-1]:docs_offsets_t[batch_size] ] )
x_packed = pad_sequence(k, batch_first=True, padding_value=0)
tt_emb = sc.tt_emb( x_packed )
len(k)

In [None]:
x_packed = pad_sequence(k, batch_first=True, padding_value=0)
tt_emb = sc.tt_emb( x_packed )

In [None]:
tt_emb.transpose(0,1)

In [None]:
a = [terms_idx, terms_idx]
torch.stack(a)

In [None]:
F.softmax(pred_docs_t).argmax(axis=1)

In [None]:
y_val

In [None]:
(y_t == F.softmax(pred_docs_t).argmax(axis=1)).sum().item()/y_t.shape[0]

In [None]:
terms_idx_t, docs_offsets_t

In [None]:
shifts = sc._get_shift_(docs_offsets_t, terms_idx_t.shape[0])

In [None]:
zipado = zip(docs_offsets_t, shifts)
#next(zipado)
start,size = next(zipado)

In [None]:
w = sc.tt_emb( terms_idx_t[start:start+size] )
w1 = sc.tt_dir_map( w )
w = torch.matmul( w, w1.T )
w = F.leaky_relu( w, negative_slope=sc.negative_slope)
w = F.sigmoid(w)
#w = F.tanh(w)
#w = F.relu(w)
#w = w.mean(axis=1)
#w = F.softmax(w)
w,w1

In [None]:
w.round().sum() / (w.shape[0]*w.shape[1])

In [None]:
inv_mapper = { v:k for (k,v) in tokenizer.node_mapper.items() }

In [None]:
terms_idx[start:start+size]

In [None]:
fold.X_val[0]

In [None]:
bla = w.mean(axis=1)
bla = F.softmax(bla)
#bla = bla/torch.clamp(bla.sum(), 0.0001)
bla

In [None]:
[ (i, tid.item(),inv_mapper[tid.item()], wei.item()) for i, (tid, wei) in enumerate(zip(terms_idx[start:start+size], bla)) ]

In [None]:
w

In [None]:
w.mean(axis=1)

In [None]:
w.shape

In [None]:
norm = nn.BatchNorm1d(num_features=1).to(device)

bla2 = norm(bla.view(-1, 1)).squeeze()
bla2 = F.sigmoid(bla2)
bla2

In [None]:
1

In [None]:
ma = nn.MultiheadAttention(300, 300).to(device)
ma

In [None]:
torch.__version__

In [None]:
a,b = w.shape
w_ = w.view(a,1,b)
w1_ = w1.view(a,1,b)

attn_output = ma(w_, w1_, w_, need_weights=False)

attn_output.view(a,b)
attn_output_weights.view(a,a)

In [None]:
attn_output.view(a,b).shape

In [None]:
attn_output_weights.view(a,a)

In [None]:
F.softmax(torch.Tensor([[0.0718, 0.0716, 0.0712, 0.0714, 0.0721, 0.0710, 0.0712, 0.0714, 0.0710,
         0.0719, 0.0711, 0.0712, 0.0718, 0.0714],
        [0.0722, 0.0709, 0.0710, 0.0709, 0.0728, 0.0723, 0.0709, 0.0709, 0.0719,
         0.0712, 0.0709, 0.0707, 0.0725, 0.0707],
        [0.0711, 0.0715, 0.0710, 0.0713, 0.0710, 0.0712, 0.0718, 0.0713, 0.0709,
         0.0725, 0.0716, 0.0719, 0.0710, 0.0719],
        [0.0721, 0.0717, 0.0714, 0.0713, 0.0717, 0.0714, 0.0711, 0.0710, 0.0713,
         0.0717, 0.0710, 0.0712, 0.0720, 0.0711],
        [0.0722, 0.0711, 0.0710, 0.0710, 0.0737, 0.0716, 0.0709, 0.0705, 0.0714,
         0.0719, 0.0707, 0.0707, 0.0731, 0.0702],
        [0.0714, 0.0716, 0.0708, 0.0711, 0.0718, 0.0713, 0.0712, 0.0717, 0.0712,
         0.0724, 0.0713, 0.0712, 0.0716, 0.0714],
        [0.0712, 0.0714, 0.0713, 0.0713, 0.0722, 0.0716, 0.0709, 0.0714, 0.0714,
         0.0717, 0.0713, 0.0713, 0.0716, 0.0713],
        [0.0716, 0.0707, 0.0712, 0.0714, 0.0717, 0.0715, 0.0714, 0.0715, 0.0712,
         0.0721, 0.0711, 0.0714, 0.0717, 0.0715],
        [0.0717, 0.0713, 0.0708, 0.0710, 0.0725, 0.0717, 0.0714, 0.0709, 0.0712,
         0.0712, 0.0712, 0.0717, 0.0720, 0.0714],
        [0.0709, 0.0712, 0.0713, 0.0712, 0.0713, 0.0713, 0.0714, 0.0719, 0.0712,
         0.0724, 0.0715, 0.0715, 0.0717, 0.0713],
        [0.0715, 0.0716, 0.0712, 0.0708, 0.0725, 0.0717, 0.0712, 0.0714, 0.0714,
         0.0717, 0.0713, 0.0706, 0.0718, 0.0712],
        [0.0726, 0.0711, 0.0713, 0.0706, 0.0737, 0.0721, 0.0709, 0.0707, 0.0714,
         0.0708, 0.0706, 0.0705, 0.0730, 0.0707],
        [0.0718, 0.0711, 0.0714, 0.0715, 0.0725, 0.0707, 0.0707, 0.0711, 0.0712,
         0.0728, 0.0711, 0.0713, 0.0719, 0.0709],
        [0.0712, 0.0716, 0.0713, 0.0712, 0.0717, 0.0707, 0.0708, 0.0721, 0.0709,
         0.0728, 0.0713, 0.0716, 0.0711, 0.0715]]).sum(axis=0))

In [None]:
class NotTooSimpleClassifier(nn.Module):
    def __init__(self, vocab_size, hidden_l, nclass, dropout1=0.1, dropout2=0.1, negative_slope=99,
                 initrange = 0.5, scale_grad_by_freq=False, device='cuda:0'):
        super(NotTooSimpleClassifier, self).__init__()
        
        self.dt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        self.tt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        
        self.undirected_map = nn.Linear(hidden_l, hidden_l)
        
        self.fc = nn.Linear(hidden_l, nclass)
        self.drop1 = nn.Dropout(dropout1)
        self.drop2 = nn.Dropout(dropout2)
        
        self.norm = nn.BatchNorm1d(1)
        
        self.initrange = initrange
        self.nclass = nclass
        self.negative_slope = negative_slope
        
        self.init_weights()
        
        #self.labls_emb = nn.Embedding(graph_builder.n_class, 300)
    
    def forward(self, terms_idxs, docs_offsets):
        n = terms_idxs.shape[0]
        weights = []
        shifts = self._get_shift_(docs_offsets, n)
        
        terms_h1 = self.tt_emb(terms_idxs)
        terms_h1 = self.drop1(terms_h1)
        
        terms_h2 = self.undirected_map( terms_h1 )
        #terms_h2 = self.drop1( terms_h2 )
        for start,size in zip(docs_offsets, shifts):
            w  = terms_h1[start:start+size]
            w1 = terms_h2[start:start+size]
            w = torch.matmul( w, w1.T )
            w = F.leaky_relu( w, negative_slope=self.negative_slope)
            w = F.sigmoid(w-5.5)
            w = w.mean(axis=1)
            w = F.softmax(w)
            #w = w / torch.clamp(w.sum(), 0.0001)
            weights.append( w )
        
        weights = torch.cat(weights)
        #weights = self.norm(weights.view(-1, 1)).squeeze()
        #weights = F.sigmoid(weights)
        
        h_docs  = F.embedding_bag(self.dt_emb.weight, terms_idxs, docs_offsets, per_sample_weights=weights, mode='sum')
        h_docs = self.drop2( h_docs )
        pred_docs = self.fc( h_docs )
        return pred_docs

    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        
    def _get_shift_(self, offsets, lenght):
        shifts = offsets[1:] - offsets[:-1]
        last = torch.LongTensor([lenght - offsets[-1]]).to( offsets.device )
        return torch.cat([shifts, last])