In [1]:
import warnings
warnings.filterwarnings('ignore')

from TGA.utils import Dataset, Tokenizer

from tqdm.notebook import tqdm
import copy

from time import time
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

In [2]:
import matplotlib.pyplot as plt

In [3]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

In [5]:
dataset = Dataset('/home/Documents/datasets/webkb/')
fold = next(dataset.get_fold_instances(10, with_val=True))
fold._fields, len(fold.X_train)

(('X_train', 'y_train', 'X_test', 'y_test', 'X_val', 'y_val'), 6553)

In [6]:
tokenizer = Tokenizer(mindf=1, stopwords='remove', model='sample', k=128, verbose=True)
tokenizer.fit(fold.X_train, fold.y_train)
tokenizer.vocab_size, tokenizer.N

100%|██████████| 6553/6553 [00:02<00:00, 2725.80it/s]


(23063, 6553)

In [7]:
from multiprocessing import Pool

In [8]:
y_train = tokenizer.le.transform( fold.y_train )
y_val   = tokenizer.le.transform( fold.y_val )
y_test  = tokenizer.le.transform( fold.y_test )

In [9]:
def collate_train(param):
    X, y = zip(*param)
    terms_ids, docs_offsets = tokenizer.transform(X, verbose=False)
    return torch.LongTensor(terms_ids), torch.LongTensor(docs_offsets), torch.LongTensor(y)

In [10]:
class Mask(nn.Module):
    def __init__(self, negative_slope=1000, kappa=2.):
        super(Mask, self).__init__()
        self.negative_slope = negative_slope
        self.kappa = kappa
        self.sig = nn.Sigmoid()
    def forward(self, h):
        w = F.leaky_relu( h, negative_slope=self.negative_slope)
        w = self.sig(w-self.kappa)
        return w

In [17]:
class SimpleAttentionBag(nn.Module):
    def __init__(self, vocab_size, hiddens, nclass, drop=.5, initrange=.5, negative_slope=99.):
        super(SimpleAttentionBag, self).__init__()
        self.hiddens        = hiddens
        self.dt_emb         = nn.Embedding(vocab_size, hiddens, scale_grad_by_freq=True, padding_idx=0)
        self.tt_s_emb       = nn.Embedding(vocab_size, hiddens, scale_grad_by_freq=True, padding_idx=0)
        self.tt_t_emb       = nn.Embedding(vocab_size, hiddens, scale_grad_by_freq=True, padding_idx=0)
        self.fc             = nn.Linear(hiddens, nclass)
        self.initrange      = initrange 
        self.negative_slope = negative_slope
        self.drop           = nn.Dropout(drop)
        self.norm           = nn.BatchNorm1d(hiddens)
        self.drop_          = drop
        self.sig            = nn.Sigmoid()
        self.init_weights()
    def forward(self, terms_idx, docs_offsets, return_mask=False):
        n = terms_idx.shape[0]
        batch_size = docs_offsets.shape[0]
        
        k         = [ terms_idx[ docs_offsets[i-1]:docs_offsets[i] ] for i in range(1, batch_size) ]
        k.append( terms_idx[ docs_offsets[-1]: ] )
        x_packed  = pad_sequence(k, batch_first=True, padding_value=0)

        bx_packed = x_packed == 0
        doc_sizes = bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
        pad_mask  = bx_packed.logical_not()
        pad_mask  = pad_mask.view(*bx_packed.shape, 1)
        pad_mask  = pad_mask.logical_and(pad_mask.transpose(1, 2))
        
        
        tt_h     = self.tt_s_emb( x_packed )
        tt_dir_h = self.tt_t_emb( x_packed )
        
        dt_h     = tt_h + tt_dir_h
        dt_h     = F.dropout( dt_h, p=self.drop_, training=self.training )
        
        tt_h = F.tanh(tt_h)
        tt_h = F.dropout( tt_h, p=self.drop_, training=self.training )
        
        tt_dir_h = F.tanh(tt_dir_h)
        tt_dir_h = F.dropout( tt_dir_h, p=self.drop_, training=self.training )
        
        co_weights = torch.bmm( tt_h, tt_dir_h.transpose( 1, 2 ) )
        co_weights = F.leaky_relu( co_weights, negative_slope=self.negative_slope)
        
        co_weights[pad_mask.logical_not()] = float('-inf') # Set the 3D-pad mask values to -inf (=0 in sigmoid)
        co_weights = F.sigmoid(co_weights)
        
        weights = co_weights.sum(axis=2) / doc_sizes
        weights[bx_packed] = float('-inf') # Set the 2D-pad mask values to -inf  (=0 in softmax)
        weights = F.softmax(weights, dim=1)
        weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
        weights = weights.view( *weights.shape, 1 )
        
        docs_h = dt_h * weights
        docs_h = docs_h.sum(axis=1)
        docs_h = F.dropout( docs_h, p=self.drop_, training=self.training )
        docs_h = self.fc(docs_h)
        return docs_h, weights, co_weights
    
    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_s_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_t_emb.weight.data.uniform_(-self.initrange, self.initrange)

In [18]:
nepochs = 1000
max_epochs = 20
drop=0.3
max_drop=0.85
device = torch.device('cuda:0')
batch_size = 64
k = 500

In [19]:
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7fbaa0171110>

In [20]:
#sc = SimpleClassifier(tokenizer.vocab_size, 300, tokenizer.n_class, dropout=drop).to( device )
ab = SimpleAttentionBag(tokenizer.vocab_size, 300, tokenizer.n_class, drop=drop).to( device )
#ab = AttentionBag(tokenizer.vocab_size, 300, tokenizer.n_class, drop=drop).to( device )
#ab = NotTooSimpleClassifier(tokenizer.vocab_size, 300, tokenizer.n_class, dropout1=drop, dropout2=drop).to( device )
tokenizer.k = k
optimizer = optim.AdamW( ab.parameters(), lr=5e-3, weight_decay=5e-3)
loss_func_cel = nn.CrossEntropyLoss().to( device )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=.95,
                                                       patience=5, verbose=True)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=.98, verbose=True)

In [21]:
num_workers=16

In [22]:
best = 99999.
counter = 1
loss_val = 1.
eps = .9
dl_val = DataLoader(list(zip(fold.X_val, y_val)), batch_size=batch_size,
                         shuffle=False, collate_fn=collate_train, num_workers=num_workers)
for e in tqdm(range(nepochs), total=nepochs):
    dl_train = DataLoader(list(zip(fold.X_train, y_train)), batch_size=batch_size,
                             shuffle=True, collate_fn=collate_train, num_workers=num_workers)
    loss_train  = 0.
    with tqdm(total=len(y_train)+len(y_val), smoothing=0., desc=f"Epoch {e+1}") as pbar:
        total = 0
        correct  = 0
        ab.train()
        tokenizer.model = 'sample'
        #tokenizer.model = 'sample'
        for i, (terms_idx, docs_offsets, y) in enumerate(dl_train):
            terms_idx    = terms_idx.to( device )
            docs_offsets = docs_offsets.to( device )
            y            = y.to( device )
            
            pred_docs,_,_ = ab( terms_idx, docs_offsets)
            pred_docs     = F.softmax(pred_docs)
            loss          = loss_func_cel(pred_docs, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            loss_train += loss.item()
            total      += len(y)
            y_pred      = pred_docs.argmax(axis=1)
            correct    += (y_pred == y).sum().item()
            #ab.drop_ =  np.power((correct/total),loss_val)
            #ab.drop_ =  np.power((correct/total),4)
            ab.drop_ =  (correct/total)*max_drop
            
            toprint  = f"Train loss: {loss_train/(i+1):.5}/{loss.item():.5} "
            toprint += f'Drop: {ab.drop_:.5} '
            toprint += f'ACC: {correct/total:.5} '
            
            print(toprint, end=f"{' '*100}\r")
            
            pbar.update( len(y) )
            del pred_docs, loss
            del terms_idx, docs_offsets, y
            del y_pred
        loss_train = loss_train/(i+1)
        print()
        #print(ab.drop_)
        total = 0
        correct  = 0
        ab.eval()
        tokenizer.model = 'topk'
        with torch.no_grad():
            loss_val = 0.
            for i, (terms_idx, docs_offsets, y) in enumerate(dl_val):
                terms_idx    = terms_idx.to( device )
                docs_offsets = docs_offsets.to( device )
                y            = y.to( device )

                pred_docs, weights, co_weights = ab( terms_idx, docs_offsets )
                pred_docs   = F.softmax(pred_docs)

                y_pred      = pred_docs.argmax(axis=1)
                correct    += (y_pred == y).sum().item()
                total      += len(y)
                loss2       = loss_func_cel(pred_docs, y)
                loss_val   += loss2

                print(f'Val loss: {loss_val.item()/(i+1):.5} ACC: {correct/total:.5}', end=f"{' '*100}\r")
   
                pbar.update( len(y) )
            print()

            del terms_idx, docs_offsets, y
            del y_pred
            
            loss_val   = (loss_val/(i+1)).cpu()
            scheduler.step(loss_val)

            if best-loss_val > 0.0001 :
                best = loss_val.item()
                counter = 1
                print(f'New Best Val loss: {best:.5}', end=f"{' '*100}\n")
                best_model = copy.deepcopy(ab).to('cpu')
            elif counter > max_epochs:
                print(f'Best Val loss: {best:.5}', end=f"{' '*100}\n")
                break
            else:
                counter += 1
            del pred_docs, loss2

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.6949/1.5672 Drop: 0.45256 ACC: 0.53243                                                                                                       
Val loss: 1.5495 ACC: 0.64399                                                                                                    
New Best Val loss: 1.5495                                                                                                    


Epoch 2:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.4933/1.5329 Drop: 0.58889 ACC: 0.69281                                                                                                     
Val loss: 1.4903 ACC: 0.68287                                                                                                    
New Best Val loss: 1.4903                                                                                                    


Epoch 3:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.4164/1.5434 Drop: 0.65452 ACC: 0.77003                                                                                                     
Val loss: 1.4378 ACC: 0.74119                                                                                                    
New Best Val loss: 1.4378                                                                                                    


Epoch 4:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.3774/1.4725 Drop: 0.68371 ACC: 0.80436                                                                                                     
Val loss: 1.4287 ACC: 0.74484                                                                                                    
New Best Val loss: 1.4287                                                                                                    


Epoch 5:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.3573/1.407 Drop: 0.69759 ACC: 0.82069                                                                                                      
Val loss: 1.421 ACC: 0.75456                                                                                                     
New Best Val loss: 1.421                                                                                                    


Epoch 6:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.3467/1.2465 Drop: 0.70265 ACC: 0.82664                                                                                                     
Val loss: 1.4169 ACC: 0.76306                                                                                                    
New Best Val loss: 1.4169                                                                                                    


Epoch 7:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.3397/1.3592 Drop: 0.70939 ACC: 0.83458                                                                                                     
Val loss: 1.4154 ACC: 0.76063                                                                                                    
New Best Val loss: 1.4154                                                                                                    


Epoch 8:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.3307/1.4213 Drop: 0.71562 ACC: 0.8419                                                                                                      
Val loss: 1.4144 ACC: 0.76306                                                                                                    
New Best Val loss: 1.4144                                                                                                    


Epoch 9:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.3207/1.2901 Drop: 0.72327 ACC: 0.85091                                                                                                     
Val loss: 1.4128 ACC: 0.76914                                                                                                    
New Best Val loss: 1.4128                                                                                                    


Epoch 10:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.3117/1.2328 Drop: 0.73702 ACC: 0.86708                                                                                                     
Val loss: 1.4079 ACC: 0.77035                                                                                                    
New Best Val loss: 1.4079                                                                                                    


Epoch 11:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.3003/1.3898 Drop: 0.74882 ACC: 0.88097                                                                                                     
Val loss: 1.4019 ACC: 0.77521                                                                                                    
New Best Val loss: 1.4019                                                                                                    


Epoch 12:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2912/1.2331 Drop: 0.75064 ACC: 0.88311                                                                                                     
Val loss: 1.4005 ACC: 0.77886                                                                                                    
New Best Val loss: 1.4005                                                                                                    


Epoch 13:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2862/1.4008 Drop: 0.75622 ACC: 0.88967                                                                                                     
Val loss: 1.3973 ACC: 0.78007                                                                                                    
New Best Val loss: 1.3973                                                                                                    


Epoch 14:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2831/1.3287 Drop: 0.75907 ACC: 0.89303                                                                                                     
Val loss: 1.3937 ACC: 0.78979                                                                                                    
New Best Val loss: 1.3937                                                                                                    


Epoch 15:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2779/1.2298 Drop: 0.76076 ACC: 0.89501                                                                                                     
Val loss: 1.3916 ACC: 0.78372                                                                                                    
New Best Val loss: 1.3916                                                                                                    


Epoch 16:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.275/1.2149 Drop: 0.76426 ACC: 0.89913                                                                                                      
Val loss: 1.3917 ACC: 0.77886                                                                                                    


Epoch 17:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2722/1.2535 Drop: 0.76478 ACC: 0.89974                                                                                                     
Val loss: 1.3891 ACC: 0.78736                                                                                                    
New Best Val loss: 1.3891                                                                                                    


Epoch 18:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2685/1.3183 Drop: 0.76841 ACC: 0.90401                                                                                                     
Val loss: 1.3902 ACC: 0.78615                                                                                                    


Epoch 19:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2674/1.3256 Drop: 0.77023 ACC: 0.90615                                                                                                     
Val loss: 1.3893 ACC: 0.7825                                                                                                     


Epoch 20:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2674/1.1834 Drop: 0.76698 ACC: 0.90233                                                                                                     
Val loss: 1.3892 ACC: 0.7825                                                                                                     


Epoch 21:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2642/1.2086 Drop: 0.76984 ACC: 0.90569                                                                                                     
Val loss: 1.3885 ACC: 0.78736                                                                                                    
New Best Val loss: 1.3885                                                                                                    


Epoch 22:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2599/1.3184 Drop: 0.77269 ACC: 0.90905                                                                                                     
Val loss: 1.3913 ACC: 0.78372                                                                                                    


Epoch 23:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2622/1.3707 Drop: 0.77217 ACC: 0.90844                                                                                                     
Val loss: 1.3898 ACC: 0.78493                                                                                                    


Epoch 24:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2566/1.2144 Drop: 0.77606 ACC: 0.91302                                                                                                     
Val loss: 1.394 ACC: 0.77886                                                                                                     


Epoch 25:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2582/1.3145 Drop: 0.77386 ACC: 0.91042                                                                                                     
Val loss: 1.3932 ACC: 0.77521                                                                                                    


Epoch 26:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.256/1.216 Drop: 0.77736 ACC: 0.91454                                                                                                       
Val loss: 1.3935 ACC: 0.78007                                                                                                    


Epoch 27:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2528/1.2851 Drop: 0.77827 ACC: 0.91561                                                                                                     
Val loss: 1.3928 ACC: 0.77886                                                                                                    
Epoch    27: reducing learning rate of group 0 to 4.7500e-03.


Epoch 28:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2573/1.249 Drop: 0.77529 ACC: 0.9121                                                                                                       
Val loss: 1.3924 ACC: 0.77886                                                                                                    


Epoch 29:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.251/1.2321 Drop: 0.78099 ACC: 0.91882                                                                                                      
Val loss: 1.3855 ACC: 0.78372                                                                                                    
New Best Val loss: 1.3855                                                                                                    


Epoch 30:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2467/1.3006 Drop: 0.78618 ACC: 0.92492                                                                                                     
Val loss: 1.3812 ACC: 0.78979                                                                                                    
New Best Val loss: 1.3812                                                                                                    


Epoch 31:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2478/1.1784 Drop: 0.78488 ACC: 0.92339                                                                                                     
Val loss: 1.3772 ACC: 0.79344                                                                                                    
New Best Val loss: 1.3772                                                                                                    


Epoch 32:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2474/1.2455 Drop: 0.78618 ACC: 0.92492                                                                                                     
Val loss: 1.3758 ACC: 0.79587                                                                                                    
New Best Val loss: 1.3758                                                                                                    


Epoch 33:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2466/1.262 Drop: 0.78437 ACC: 0.92278                                                                                                      
Val loss: 1.3749 ACC: 0.80194                                                                                                    
New Best Val loss: 1.3749                                                                                                    


Epoch 34:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2428/1.3916 Drop: 0.78917 ACC: 0.92843                                                                                                     
Val loss: 1.3764 ACC: 0.79465                                                                                                    


Epoch 35:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2392/1.2318 Drop: 0.79098 ACC: 0.93057                                                                                                     
Val loss: 1.3771 ACC: 0.79708                                                                                                    


Epoch 36:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2415/1.188 Drop: 0.78891 ACC: 0.92812                                                                                                      
Val loss: 1.3775 ACC: 0.78493                                                                                                    


Epoch 37:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.238/1.2286 Drop: 0.79215 ACC: 0.93194                                                                                                      
Val loss: 1.3758 ACC: 0.79465                                                                                                    


Epoch 38:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2354/1.25 Drop: 0.79383 ACC: 0.93392                                                                                                       
Val loss: 1.3773 ACC: 0.79465                                                                                                    


Epoch 39:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2369/1.2028 Drop: 0.7915 ACC: 0.93118                                                                                                      
Val loss: 1.3775 ACC: 0.78979                                                                                                    
Epoch    39: reducing learning rate of group 0 to 4.5125e-03.


Epoch 40:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2387/1.1916 Drop: 0.78955 ACC: 0.92889                                                                                                     
Val loss: 1.3766 ACC: 0.79344                                                                                                    


Epoch 41:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2388/1.1774 Drop: 0.7902 ACC: 0.92965                                                                                                      
Val loss: 1.3751 ACC: 0.79344                                                                                                    


Epoch 42:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2352/1.2067 Drop: 0.79358 ACC: 0.93362                                                                                                     
Val loss: 1.3762 ACC: 0.79101                                                                                                    


Epoch 43:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2359/1.2078 Drop: 0.79241 ACC: 0.93224                                                                                                     
Val loss: 1.376 ACC: 0.79222                                                                                                     


Epoch 44:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2327/1.214 Drop: 0.79526 ACC: 0.9356                                                                                                       
Val loss: 1.3766 ACC: 0.79101                                                                                                    


Epoch 45:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2331/1.2079 Drop: 0.79461 ACC: 0.93484                                                                                                     
Val loss: 1.3771 ACC: 0.78615                                                                                                    
Epoch    45: reducing learning rate of group 0 to 4.2869e-03.


Epoch 46:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2317/1.2703 Drop: 0.79604 ACC: 0.93652                                                                                                     
Val loss: 1.3757 ACC: 0.78979                                                                                                    


Epoch 47:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2314/1.2069 Drop: 0.79604 ACC: 0.93652                                                                                                     
Val loss: 1.3759 ACC: 0.79101                                                                                                    


Epoch 48:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2315/1.279 Drop: 0.79669 ACC: 0.93728                                                                                                      
Val loss: 1.3773 ACC: 0.78979                                                                                                    


Epoch 49:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2312/1.2215 Drop: 0.79604 ACC: 0.93652                                                                                                     
Val loss: 1.3773 ACC: 0.78979                                                                                                    


Epoch 50:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2302/1.2089 Drop: 0.79721 ACC: 0.93789                                                                                                     
Val loss: 1.3755 ACC: 0.79465                                                                                                    


Epoch 51:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2295/1.1692 Drop: 0.79708 ACC: 0.93774                                                                                                     
Val loss: 1.3745 ACC: 0.79222                                                                                                    
New Best Val loss: 1.3745                                                                                                    


Epoch 52:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2295/1.2423 Drop: 0.79682 ACC: 0.93743                                                                                                     
Val loss: 1.3743 ACC: 0.79587                                                                                                    
New Best Val loss: 1.3743                                                                                                    


Epoch 53:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2312/1.2307 Drop: 0.79617 ACC: 0.93667                                                                                                     
Val loss: 1.3739 ACC: 0.79465                                                                                                    
New Best Val loss: 1.3739                                                                                                    


Epoch 54:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2306/1.2304 Drop: 0.79747 ACC: 0.9382                                                                                                      
Val loss: 1.3734 ACC: 0.79465                                                                                                    
New Best Val loss: 1.3734                                                                                                    


Epoch 55:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2273/1.185 Drop: 0.79967 ACC: 0.94079                                                                                                      
Val loss: 1.3733 ACC: 0.79708                                                                                                    


Epoch 56:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2292/1.205 Drop: 0.79773 ACC: 0.9385                                                                                                       
Val loss: 1.3736 ACC: 0.79222                                                                                                    


Epoch 57:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.226/1.3001 Drop: 0.80097 ACC: 0.94232                                                                                                      
Val loss: 1.3733 ACC: 0.80073                                                                                                    
New Best Val loss: 1.3733                                                                                                    


Epoch 58:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2267/1.2673 Drop: 0.79967 ACC: 0.94079                                                                                                     
Val loss: 1.3741 ACC: 0.79344                                                                                                    


Epoch 59:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2292/1.2482 Drop: 0.79825 ACC: 0.93911                                                                                                     
Val loss: 1.3762 ACC: 0.79222                                                                                                    


Epoch 60:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2284/1.2234 Drop: 0.79825 ACC: 0.93911                                                                                                     
Val loss: 1.3769 ACC: 0.78858                                                                                                    


Epoch 61:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2267/1.2548 Drop: 0.80019 ACC: 0.9414                                                                                                      
Val loss: 1.3749 ACC: 0.79101                                                                                                    


Epoch 62:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2255/1.1665 Drop: 0.80097 ACC: 0.94232                                                                                                     
Val loss: 1.3758 ACC: 0.78736                                                                                                    


Epoch 63:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2255/1.2251 Drop: 0.80032 ACC: 0.94155                                                                                                     
Val loss: 1.3768 ACC: 0.78736                                                                                                    
Epoch    63: reducing learning rate of group 0 to 4.0725e-03.


Epoch 64:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2247/1.2584 Drop: 0.80123 ACC: 0.94262                                                                                                     
Val loss: 1.3764 ACC: 0.78615                                                                                                    


Epoch 65:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2235/1.178 Drop: 0.80382 ACC: 0.94567                                                                                                      
Val loss: 1.3766 ACC: 0.79101                                                                                                    


Epoch 66:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2263/1.2948 Drop: 0.80123 ACC: 0.94262                                                                                                     
Val loss: 1.3779 ACC: 0.78615                                                                                                    


Epoch 67:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2192/1.1994 Drop: 0.80668 ACC: 0.94903                                                                                                     
Val loss: 1.3785 ACC: 0.78493                                                                                                    


Epoch 68:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2227/1.207 Drop: 0.80473 ACC: 0.94674                                                                                                      
Val loss: 1.3798 ACC: 0.78615                                                                                                    


Epoch 69:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2247/1.3244 Drop: 0.80304 ACC: 0.94476                                                                                                     
Val loss: 1.3806 ACC: 0.78129                                                                                                    
Epoch    69: reducing learning rate of group 0 to 3.8689e-03.


Epoch 70:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2238/1.2507 Drop: 0.80162 ACC: 0.94308                                                                                                     
Val loss: 1.3811 ACC: 0.7825                                                                                                     


Epoch 71:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2196/1.171 Drop: 0.80564 ACC: 0.94781                                                                                                      
Val loss: 1.3816 ACC: 0.78372                                                                                                    


Epoch 72:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2245/1.247 Drop: 0.80266 ACC: 0.9443                                                                                                       
Val loss: 1.3814 ACC: 0.78493                                                                                                    


Epoch 73:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.22/1.1826 Drop: 0.80616 ACC: 0.94842                                                                                                       
Val loss: 1.3821 ACC: 0.7825                                                                                                     


Epoch 74:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2204/1.2247 Drop: 0.80538 ACC: 0.9475                                                                                                      
Val loss: 1.3825 ACC: 0.7825                                                                                                     


Epoch 75:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2195/1.205 Drop: 0.80655 ACC: 0.94888                                                                                                      
Val loss: 1.3816 ACC: 0.78372                                                                                                    
Epoch    75: reducing learning rate of group 0 to 3.6755e-03.


Epoch 76:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2198/1.2059 Drop: 0.80564 ACC: 0.94781                                                                                                     
Val loss: 1.382 ACC: 0.78129                                                                                                     


Epoch 77:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2215/1.1853 Drop: 0.80538 ACC: 0.9475                                                                                                      
Val loss: 1.3832 ACC: 0.78129                                                                                                    


Epoch 78:   0%|          | 0/7376 [00:00<?, ?it/s]

Train loss: 1.2177/1.1689 Drop: 0.80784 ACC: 0.9504                                                                                                      
Val loss: 1.3819 ACC: 0.7825                                                                                                     
Best Val loss: 1.3733                                                                                                    


In [None]:
tokenizer.vocab_size

In [None]:
# Epoch 69: 100%
# Train loss: 2.1202/2.0784 Drop: 0.81603 ACC: 0.96003                                                                                                     
# Val   loss: 2.1766 ACC: 0.90698 

In [None]:
# batchsize = 64

# 20ng (maxdrop=0.85)
# Val loss: 2.1755 ACC: 0.90645                                                                                                    
# Test loss: 2.1664 ACC: 0.91226

# acm (maxdrop=0.85)
# Val loss: 1.7571 ACC: 0.78798                                                                                                    
# Test loss: 1.7697 ACC: 0.78717                                                                                                    

# webkb (maxdrop=0.85)
# Val loss: 1.3658 ACC: 0.79587
# Test loss: 1.3602 ACC: 0.80316                                                                       


In [23]:
device_test = 'cpu'
ab = copy.deepcopy(best_model).to(device_test)
ab.eval()
loss_total = 0
correct_t = 0
total_t = 0
dl_test = DataLoader(list(zip(fold.X_test, y_test)), batch_size=128,
                         shuffle=False, collate_fn=collate_train, num_workers=num_workers)
for i, (terms_idx_t, docs_offsets_t, y_t) in enumerate(dl_test):
    terms_idx_t    = terms_idx_t.to( device_test )
    docs_offsets_t = docs_offsets_t.to( device_test )
    y_t            = y_t.to( device_test )

    pred_docs_t,weigths,coweights = ab( terms_idx_t, docs_offsets_t )
    sofmax_docs_t = F.softmax(pred_docs_t)

    y_pred_t    = sofmax_docs_t.argmax(axis=1)
    correct_t  += (y_pred_t == y_t).sum().item()
    total_t    += len(y_t)
    loss_total += loss_func_cel(sofmax_docs_t, y_t)

    print(f'Test loss: {loss_total.item()/(i+1):.5} ACC: {correct_t/total_t:.5}', end=f"{' '*100}\r")

Test loss: 1.3677 ACC: 0.79222                                                                                                    

In [None]:
(y_pred_t == y_t).sum().item(), len(y_t), (y_pred_t == y_t).sum().item()/len(y_t)

In [None]:
weigths[:,:,0].shape, coweights.shape, pred_docs_t.shape

In [None]:
weigths[:,:,0]

In [None]:
class NopeAttentionBag(nn.Module):
    def __init__(self, vocab_size, hiddens, nclass, drop=.5, initrange=.5, negative_slope=99.):
        super(SimpleAttentionBag, self).__init__()
        self.hiddens        = hiddens
        self.dt_emb         = nn.Embedding(vocab_size, hiddens, scale_grad_by_freq=True, padding_idx=0)
        self.tt_s_emb       = nn.Embedding(vocab_size, hiddens, scale_grad_by_freq=True, padding_idx=0)
        self.tt_t_emb       = nn.Embedding(vocab_size, hiddens, scale_grad_by_freq=True, padding_idx=0)
        self.fc             = nn.Linear(hiddens, nclass)
        #self.fc             = nn.Sequential(
        #    nn.Linear(hiddens, hiddens),
        #    nn.Sigmoid(),
        #    nn.Linear(hiddens, nclass)
        #)
        self.initrange      = initrange 
        self.negative_slope = negative_slope
        self.drop           = nn.Dropout(drop)
        self.norm           = nn.BatchNorm1d(hiddens)
        self.drop_          = drop
        self.sig            = nn.Sigmoid()
        self.init_weights()
    def forward(self, terms_idx, docs_offsets, return_mask=False):
        n = terms_idx.shape[0]
        batch_size = docs_offsets.shape[0]
        
        k         = [ terms_idx[ docs_offsets[i-1]:docs_offsets[i] ] for i in range(1, batch_size) ]
        k.append( terms_idx[ docs_offsets[-1]: ] )
        x_packed  = pad_sequence(k, batch_first=True, padding_value=0)

        bx_packed = x_packed == 0
        doc_sizes = bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
        pad_mask  = bx_packed.logical_not()
        pad_mask  = pad_mask.view(*bx_packed.shape, 1)
        pad_mask  = pad_mask.logical_and(pad_mask.transpose(1, 2))
        
        dt_h     = self.dt_emb( x_packed )
        dt_h     = F.dropout( dt_h, p=self.drop_, training=self.training )
        
        tt_h     = self.tt_s_emb( x_packed )
        tt_h     = F.tanh(tt_h)
        tt_h     = F.dropout( tt_h, p=self.drop_, training=self.training )
        
        tt_dir_h = self.tt_t_emb( x_packed )
        tt_dir_h = F.tanh(tt_dir_h)
        tt_dir_h = F.dropout( tt_dir_h, p=self.drop_, training=self.training )
        
        co_weights = torch.bmm( tt_h, tt_dir_h.transpose( 1, 2 ) )
        co_weights[pad_mask.logical_not()] = 0. # Set the 3D-pad mask values to -inf (=0 in sigmoid)
        co_weights = F.tanh(co_weights)
        
        weights = co_weights.sum(axis=2) / doc_sizes
        weights[bx_packed] = float('-inf') # Set the 2D-pad mask values to -inf  (=0 in softmax)
        weights = F.softmax(weights, dim=1)
        weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
        weights = weights.view( *weights.shape, 1 )
        
        docs_h = dt_h * weights
        docs_h = docs_h.sum(axis=1)
        docs_h = F.dropout( docs_h, p=self.drop_, training=self.training )
        docs_h = self.fc(docs_h)
        return docs_h, weights, co_weights
    
    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_s_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_t_emb.weight.data.uniform_(-self.initrange, self.initrange)
        #self.fc.weight.data.uniform_(-self.initrange, self.initrange)

In [None]:
coweights.round()

In [None]:
batch_size = docs_offsets_t.shape[0]
        
k         = [ terms_idx_t[ docs_offsets_t[i-1]:docs_offsets_t[i] ] for i in range(1, batch_size) ]
k.append( terms_idx_t[ docs_offsets_t[-1]: ] )
x_packed  = pad_sequence(k, batch_first=True, padding_value=0)

x_packed  

In [None]:
x_packed[0]

In [None]:
import networkx as nx

In [None]:
for i in range(len(y_t)):
    if (y_pred_t == y_t)[i]: 
        #print(i, y_t[i].item(), sofmax_docs_t[i,y_t[i]].item())
        n = (x_packed[i]!=0).sum()
        print(i, y_t[i].item(), n.item(), ((coweights[i, :n, :n] > 0.5).sum()/(n*n)).item(), sofmax_docs_t[i,y_t[i]].item()/sofmax_docs_t[i].max().item())
        #print(coweights[i, :n, :n] != 0)
        #print(x_packed[i, :n], x_packed[i, :n].sum()/(n*n))

In [None]:
from scipy import sparse as sp 

In [None]:
i = 9
n = (x_packed[i]!=0).sum()
dt_h = ab.dt_emb( x_packed[i] )
coweights[i, :n, :n] >= 0.5

In [None]:
pca = TSNE(n_components=2)
dt_h_np = np.array(dt_h.tolist())
x2 = pca.fit_transform(dt_h_np[:n,:])
pos = { i: v for (i,v) in enumerate(x2) }

In [None]:
weigths[i,:n,0]

In [None]:
sp_matrix = sp.csc_matrix(((coweights[i, :n, :n]>.5)*coweights[i, :n, :n]).tolist())
G = nx.from_scipy_sparse_matrix(sp_matrix, create_using=nx.DiGraph)

In [None]:
inv_dict= { v:k for (k,v) in tokenizer.node_mapper.items() }

In [None]:
fig = plt.figure(figsize=(10,10))
#pos=nx.spring_layout(G)
#pos=nx.spiral_layout(G)

doc_weigs = np.array(weigths[i,:n,0].tolist())
doc_weigs /= doc_weigs.mean(axis=0)
doc_weigs = (doc_weigs - doc_weigs.min(axis=0)) / (doc_weigs.max(axis=0) - doc_weigs.min(axis=0))

node_sizes  = [ doc_weigs[nid]*600 for nid in G.nodes ]
node_labels = { nid: inv_dict[nid] for nid in G.nodes }
#edge_alphas = { nid: 1.-doc_weigs[nid] for nid in G.nodes }
edge_widths = { nid: doc_weigs[nid]*3. for nid in G.nodes }

ax = nx.draw_networkx_nodes(G, pos=pos, node_size=node_sizes)
ax = nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10)
ax = nx.draw_networkx_edges(G, pos=pos, edge_color='darkgray', width=edge_widths, alpha=0.1)
#for nid, alpha in tqdm(edge_alphas.items(), total=len(edge_alphas)):
#    ax = nx.draw_networkx_edges(G, pos=pos, nodelist=[nid], edge_color='darkgray', width=edge_widths, alpha=alpha)

In [None]:
doc_weigs

In [None]:
doc_weigs, doc_weigs2

In [None]:
G = nx.DiGraph()

In [None]:
len(sp_matrix.nonzero()[0])

In [None]:
"""
acm ####################################################################################
Train loss: 1.6009/1.6089 ACC: 0.94886                                                                                                    
Val loss: 1.7718 ACC: 0.77475                                                                                                    
New Best Val loss: 1.7718                                                                                                    
Test loss: 1.7678 ACC: 0.78557  79.92

Train loss: 1.7209/1.6095 ACC: 0.82338                                                                                                    
Val loss: 1.7595 ACC: 0.78236                                                                                                    
New Best Val loss: 1.7595                                                                                             
Test loss: 1.7585 ACC: 0.78557                                                                                                    

20ng ####################################################################################
Train loss: 2.0907/2.0787 ACC: 0.98845                                                                                                    
Val loss: 2.1869 ACC: 0.90803                                                                                                    
New Best Val loss: 2.1869                                                                                                    
Test loss: 2.178 ACC: 0.91068   92.65

Train loss: 2.1603/2.1249 ACC: 0.92511                                                                                                    
Val loss: 2.1842 ACC: 0.90645                                                                                                    
drop: 0.8436
New Best Val loss: 2.1842                                                                                                   
Test loss: 2.1705 ACC: 0.91226                                                                                                    

reut ####################################################################################
Train loss: 3.7735/3.5191 ACC: 0.74734            acm                                                                                        
Val loss: 3.8554 ACC: 0.6763                                                                                                    
New Best Val loss: 3.8554                                                                                                    
Test loss: 3.8493 ACC: 0.6837  72.67

Train loss: 3.7443/3.8521 ACC: 0.77727                                                                                                    
Val loss: 3.808 ACC: 0.71037                                                                                                     
New Best Val loss: 3.808
Test loss: 3.8461 ACC: 0.70444                                                                                                    

webkb ####################################################################################
Train loss: 1.2228/1.2037 ACC: 0.9504                                                                                                     
Val loss: 1.3787 ACC: 0.80316                                                                                                    
New Best Val loss: 1.3787                                                                                                    
Test loss: 1.3857 ACC: 0.78858   81.53

Epoch: 45
Train loss: 1.2392/1.2909 ACC: 0.9295                                                                                                     
Val loss: 1.3838 ACC: 0.78736                                                                                                    
New Best Val loss: 1.3838
Test loss: 1.3814 ACC: 0.78372                                                                                                    

"""

In [None]:
dt_h = ab.dt_emb( x_packed )
dt_h.shape

In [None]:
bx_packed = x_packed == 0
doc_sizes = bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
pad_mask  = bx_packed.logical_not()
pad_mask  = pad_mask.view(*bx_packed.shape, 1)
pad_mask  = pad_mask.logical_and(pad_mask.transpose(1, 2))
pad_mask.shape, doc_sizes.shape, bx_packed.shape 

In [None]:
doc_means = (bx_packed.logical_not().view(*bx_packed.shape, 1) * dt_h).sum(axis=1) / doc_sizes
doc_means = doc_means.view(*doc_means.shape, 1).transpose(1,2)
doc_means.shape

In [None]:
n_dt_h = dt_h/doc_sizes.view(*doc_sizes.shape, 1)
n_dt_h = (doc_means - n_dt_h)
n_dt_h.shape

In [None]:
y_pred_t      = pred_docs_t.argmax(axis=1)
correct_t     = (y_pred_t == y_t).sum().item()
total_t       = len(y_t)
correct_t/total_t

In [None]:
docs_offsets_t

In [None]:
terms_idx_t[:docs_offsets_t[batch_size]]

In [None]:
batch_off_test  = docs_offsets_t[batch_size]
batch_tidx_test = terms_idx_t[:batch_off_test]
h_terms_test    = sc.tt_emb( batch_tidx_test )
dirh_terms_test = sc.tt_dir_map( h_terms_test )

W = torch.matmul( h_terms_test, dirh_terms_test.T )
W = F.leaky_relu( W, negative_slope=sc.mask.negative_slope)
W = F.sigmoid(W)
W

In [None]:
k = [ batch_tidx_test[ docs_offsets_t[i-1]:docs_offsets_t[i] ] for i in range(1, batch_size) ]
k.append( batch_tidx_test[ docs_offsets_t[batch_size-1]:docs_offsets_t[batch_size] ] )
x_packed = pad_sequence(k, batch_first=True, padding_value=0)
tt_emb = sc.tt_emb( x_packed )
len(k)

In [None]:
x_packed = pad_sequence(k, batch_first=True, padding_value=0)
tt_emb = sc.tt_emb( x_packed )

In [None]:
tt_emb.transpose(0,1)

In [None]:
a = [terms_idx, terms_idx]
torch.stack(a)

In [None]:
F.softmax(pred_docs_t).argmax(axis=1)

In [None]:
y_val

In [None]:
(y_t == F.softmax(pred_docs_t).argmax(axis=1)).sum().item()/y_t.shape[0]

In [None]:
terms_idx_t, docs_offsets_t

In [None]:
shifts = sc._get_shift_(docs_offsets_t, terms_idx_t.shape[0])

In [None]:
zipado = zip(docs_offsets_t, shifts)
#next(zipado)
start,size = next(zipado)

In [None]:
w = sc.tt_emb( terms_idx_t[start:start+size] )
w1 = sc.tt_dir_map( w )
w = torch.matmul( w, w1.T )
w = F.leaky_relu( w, negative_slope=sc.negative_slope)
w = F.sigmoid(w)
#w = F.tanh(w)
#w = F.relu(w)
#w = w.mean(axis=1)
#w = F.softmax(w)
w,w1

In [None]:
w.round().sum() / (w.shape[0]*w.shape[1])

In [None]:
inv_mapper = { v:k for (k,v) in tokenizer.node_mapper.items() }

In [None]:
terms_idx[start:start+size]

In [None]:
fold.X_val[0]

In [None]:
bla = w.mean(axis=1)
bla = F.softmax(bla)
#bla = bla/torch.clamp(bla.sum(), 0.0001)
bla

In [None]:
[ (i, tid.item(),inv_mapper[tid.item()], wei.item()) for i, (tid, wei) in enumerate(zip(terms_idx[start:start+size], bla)) ]

In [None]:
w

In [None]:
w.mean(axis=1)

In [None]:
w.shape

In [None]:
norm = nn.BatchNorm1d(num_features=1).to(device)

bla2 = norm(bla.view(-1, 1)).squeeze()
bla2 = F.sigmoid(bla2)
bla2

In [None]:
1

In [None]:
ma = nn.MultiheadAttention(300, 300).to(device)
ma

In [None]:
torch.__version__

In [None]:
a,b = w.shape
w_ = w.view(a,1,b)
w1_ = w1.view(a,1,b)

attn_output = ma(w_, w1_, w_, need_weights=False)

attn_output.view(a,b)
attn_output_weights.view(a,a)

In [None]:
attn_output.view(a,b).shape

In [None]:
attn_output_weights.view(a,a)

In [None]:
F.softmax(torch.Tensor([[0.0718, 0.0716, 0.0712, 0.0714, 0.0721, 0.0710, 0.0712, 0.0714, 0.0710,
         0.0719, 0.0711, 0.0712, 0.0718, 0.0714],
        [0.0722, 0.0709, 0.0710, 0.0709, 0.0728, 0.0723, 0.0709, 0.0709, 0.0719,
         0.0712, 0.0709, 0.0707, 0.0725, 0.0707],
        [0.0711, 0.0715, 0.0710, 0.0713, 0.0710, 0.0712, 0.0718, 0.0713, 0.0709,
         0.0725, 0.0716, 0.0719, 0.0710, 0.0719],
        [0.0721, 0.0717, 0.0714, 0.0713, 0.0717, 0.0714, 0.0711, 0.0710, 0.0713,
         0.0717, 0.0710, 0.0712, 0.0720, 0.0711],
        [0.0722, 0.0711, 0.0710, 0.0710, 0.0737, 0.0716, 0.0709, 0.0705, 0.0714,
         0.0719, 0.0707, 0.0707, 0.0731, 0.0702],
        [0.0714, 0.0716, 0.0708, 0.0711, 0.0718, 0.0713, 0.0712, 0.0717, 0.0712,
         0.0724, 0.0713, 0.0712, 0.0716, 0.0714],
        [0.0712, 0.0714, 0.0713, 0.0713, 0.0722, 0.0716, 0.0709, 0.0714, 0.0714,
         0.0717, 0.0713, 0.0713, 0.0716, 0.0713],
        [0.0716, 0.0707, 0.0712, 0.0714, 0.0717, 0.0715, 0.0714, 0.0715, 0.0712,
         0.0721, 0.0711, 0.0714, 0.0717, 0.0715],
        [0.0717, 0.0713, 0.0708, 0.0710, 0.0725, 0.0717, 0.0714, 0.0709, 0.0712,
         0.0712, 0.0712, 0.0717, 0.0720, 0.0714],
        [0.0709, 0.0712, 0.0713, 0.0712, 0.0713, 0.0713, 0.0714, 0.0719, 0.0712,
         0.0724, 0.0715, 0.0715, 0.0717, 0.0713],
        [0.0715, 0.0716, 0.0712, 0.0708, 0.0725, 0.0717, 0.0712, 0.0714, 0.0714,
         0.0717, 0.0713, 0.0706, 0.0718, 0.0712],
        [0.0726, 0.0711, 0.0713, 0.0706, 0.0737, 0.0721, 0.0709, 0.0707, 0.0714,
         0.0708, 0.0706, 0.0705, 0.0730, 0.0707],
        [0.0718, 0.0711, 0.0714, 0.0715, 0.0725, 0.0707, 0.0707, 0.0711, 0.0712,
         0.0728, 0.0711, 0.0713, 0.0719, 0.0709],
        [0.0712, 0.0716, 0.0713, 0.0712, 0.0717, 0.0707, 0.0708, 0.0721, 0.0709,
         0.0728, 0.0713, 0.0716, 0.0711, 0.0715]]).sum(axis=0))

In [None]:
class NotTooSimpleClassifier(nn.Module):
    def __init__(self, vocab_size, hidden_l, nclass, dropout1=0.1, dropout2=0.1, negative_slope=99,
                 initrange = 0.5, scale_grad_by_freq=False, device='cuda:0'):
        super(NotTooSimpleClassifier, self).__init__()
        
        self.dt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        self.tt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        
        self.undirected_map = nn.Linear(hidden_l, hidden_l)
        
        self.fc = nn.Linear(hidden_l, nclass)
        self.drop1 = nn.Dropout(dropout1)
        self.drop2 = nn.Dropout(dropout2)
        
        self.norm = nn.BatchNorm1d(1)
        
        self.initrange = initrange
        self.nclass = nclass
        self.negative_slope = negative_slope
        
        self.init_weights()
        
        #self.labls_emb = nn.Embedding(graph_builder.n_class, 300)
    
    def forward(self, terms_idxs, docs_offsets):
        n = terms_idxs.shape[0]
        weights = []
        shifts = self._get_shift_(docs_offsets, n)
        
        terms_h1 = self.tt_emb(terms_idxs)
        terms_h1 = self.drop1(terms_h1)
        
        terms_h2 = self.undirected_map( terms_h1 )
        #terms_h2 = self.drop1( terms_h2 )
        for start,size in zip(docs_offsets, shifts):
            w  = terms_h1[start:start+size]
            w1 = terms_h2[start:start+size]
            w = torch.matmul( w, w1.T )
            w = F.leaky_relu( w, negative_slope=self.negative_slope)
            w = F.sigmoid(w-5.5)
            w = w.mean(axis=1)
            w = F.softmax(w)
            #w = w / torch.clamp(w.sum(), 0.0001)
            weights.append( w )
        
        weights = torch.cat(weights)
        #weights = self.norm(weights.view(-1, 1)).squeeze()
        #weights = F.sigmoid(weights)
        
        h_docs  = F.embedding_bag(self.dt_emb.weight, terms_idxs, docs_offsets, per_sample_weights=weights, mode='sum')
        h_docs = self.drop2( h_docs )
        pred_docs = self.fc( h_docs )
        return pred_docs

    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        
    def _get_shift_(self, offsets, lenght):
        shifts = offsets[1:] - offsets[:-1]
        last = torch.LongTensor([lenght - offsets[-1]]).to( offsets.device )
        return torch.cat([shifts, last])