In [1]:
import warnings
warnings.filterwarnings('ignore')

from TGA.utils import Dataset

from tqdm.notebook import tqdm
from TGA.utils import preprocessor

from time import time
import numpy as np
from itertools import repeat
from collections import Counter
from segtok import tokenizer as tk

from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction import stop_words
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.base import BaseEstimator, TransformerMixin

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

In [3]:
torch.__version__

'1.7.1'

In [4]:
dataset = Dataset('/home/Documentos/datasets/classification/datasets/acm/')
fold = next(dataset.get_fold_instances(10, with_val=True))
fold._fields, len(fold.X_train)

(('X_train', 'y_train', 'X_test', 'y_test', 'X_val', 'y_val'), 19907)

In [5]:
class Tokenizer(BaseEstimator, TransformerMixin):
    def __init__(self, mindf=2, stopwords='remove', model='list', lan='english', verbose=False):
        super(Tokenizer, self).__init__()
        self.mindf = mindf
        self.le = LabelEncoder()
        self.verbose = verbose
        self.stopwords = stopwords
        self.stopwordsSet = set(stop_words.ENGLISH_STOP_WORDS)
        self.lan = lan
        self.model = model
        self.analyzer = TfidfVectorizer(preprocessor=preprocessor).build_analyzer()
        #self.analyzer = tk.web_tokenizer
    
    def fit(self, X, y):
        self.N = len(X)
        self.le.fit( y )
        self.n_class = len(self.le.classes_)

        self.term_freqs = Counter()
        docs = map(self.analyzer, X)
        for doc_in_terms in tqdm(docs, total=self.N, disable=not self.verbose):
            doc_in_terms = list(map( self._filter_fit_, doc_in_terms ))
            self.term_freqs.update(list(set(doc_in_terms)))
        self.node_mapper      = {}
        self.term_freqs       = { term:v for (term,v) in self.term_freqs.items() if v >= self.mindf }    
        self.node_mapper      = { term:self._get_idx_(term) for term in self.term_freqs.keys() if self._isrel_(term) }
        self.node_mapper['<UNK>'] = len(self.node_mapper)
        self.vocab_size = len(self.node_mapper)
        
        return self
    def _isrel_(self, term):
        if self.stopwords == 'remove' and term in self.stopwordsSet:
            return False
        # put here your filter_functions
        return True
    def _get_idx_(self, term):
        # put here your idx_set_functions
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            print('is stop', term)
            return self.node_mapper.setdefault('<STPW>', len(self.node_mapper))
        return self.node_mapper.setdefault(term, len(self.node_mapper))
    def _filter_transform_(self, term):
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return '<STPW>'
        if term not in self.node_mapper:
            return '<UNK>'
        return term
    def _filter_fit_(self, term):
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return '<STPW>'
        return term
    def _model_(self, doc):
        if self.model == 'set':
            return set(doc)
        return list(doc)
    def transform(self, X, verbose=None):
        verbose = verbose if verbose is not None else self.verbose
        n = len(X)
        doc_off = [0]
        terms_idx = []
        for i,doc_in_terms in tqdm(enumerate(map(self.analyzer, X)), total=n, disable=not verbose):
            doc_in_terms = filter( self._isrel_, doc_in_terms )
            doc_in_terms = map( self._filter_transform_, doc_in_terms )
            doc_in_terms = self._model_(doc_in_terms)
            doc_in_terms = [ self.node_mapper[tid] for tid in doc_in_terms ]
            if self.model == 'sorted':
                doc_in_terms = sorted(doc_in_terms)
            doc_off.append( len(doc_in_terms) )
            terms_idx.extend( doc_in_terms )
        return np.array( terms_idx ), np.array(doc_off)[:-1].cumsum()

In [6]:
tokenizer = Tokenizer(mindf=1, stopwords='keep', model='set', verbose=True)
tokenizer.fit(fold.X_train, fold.y_train)

HBox(children=(FloatProgress(value=0.0, max=19907.0), HTML(value='')))




Tokenizer(mindf=1, model='set', stopwords='keep', verbose=True)

In [7]:
y_train = tokenizer.le.transform( fold.y_train )
y_val   = tokenizer.le.transform( fold.y_val )
y_test  = tokenizer.le.transform( fold.y_test )

In [8]:
tokenizer.transform(fold.X_val[:2])

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




(array([  240,    80,  1723,  1708, 49334,   634,  2954,   203, 12227,
         4533,    11,  2457,    38,   382,   962,    41,    46,  7125,
          965,   165,   636,    53,  2067,    18,   325,  7660,   282,
           56,   236,    58,  7313,  3028,    62,    21, 11004,    65,
        11883,   474,  1384,  5044,   336,   212, 49334,   298, 15089,
          429,  3083,   852,    80,   808,  2817,  2684,   106,  1392,
         2166,   539,  8644,  4140,    29,   155,   181,  3112]),
 array([ 0, 10]))

In [9]:
tokenizer.node_mapper['<UNK>'], fold.X_val[0]

(49334,
 'how do computer science lecturers create modules? (poster) john traxler  \n')

In [10]:
def collate_train(param):
    X, y = zip(*param)
    terms_ids, docs_offsets = tokenizer.transform(X, verbose=False)
    return torch.LongTensor(terms_ids), torch.LongTensor(docs_offsets), torch.LongTensor(y)

In [11]:
class Mask(nn.Module):
    def __init__(self, negative_slope=1000, kappa=0.):
        super(Mask, self).__init__()
        self.negative_slope = negative_slope
        self.kappa = kappa
    def forward(self, h, h2=None):
        if h2 is None:
            h2 = h
        w = torch.matmul( h, h2.T )
        w = F.leaky_relu( w, negative_slope=self.negative_slope)
        w = F.sigmoid(w-self.kappa)
        return w

In [12]:
class AttClassifier(nn.Module):
    def __init__(self, vocab_size, hidden_l, nclass, n_heads=8, dropout=0.1, negative_slope=1000,
                 initrange = 0.5, scale_grad_by_freq=False, device='cuda:0'):
        super(AttClassifier, self).__init__()
        
        self.drop = nn.Dropout(dropout)
        self.hidden_l = hidden_l
        
        self.dt_emb     = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        self.dt_dir_map = nn.Linear(hidden_l, hidden_l)
        #self.ma_term    = nn.MultiheadAttention(hidden_l, n_heads, dropout=dropout)
        
        self.tt_dir_map = nn.Linear(hidden_l, hidden_l)
        self.tt_emb     = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        
        self.fc = nn.Linear(hidden_l, nclass)
        
        self.learn = nn.Bilinear( hidden_l, hidden_l, hidden_l )
        
        self.initrange = initrange
        self.nclass = nclass
        
        self.mask = Mask( negative_slope )
        
        self.init_weights()
    
    def forward(self, terms_idxs, docs_offsets):
        n = terms_idxs.shape[0]
        shifts = self._get_shift_(docs_offsets, n)
        
        tt_h     = self.tt_emb(terms_idxs)
        tt_h     = self.drop(tt_h)
        dir_tt_h = self.tt_dir_map( tt_h )
        
        dt_h     = self.dt_emb(terms_idxs)
        dt_h     = self.drop(dt_h)
        #dir_dt_h = self.dt_dir_map( dt_h )
        
        weights = []
        docs = []
        for start,size in zip(docs_offsets, shifts):
            ### Co-occurence model
            w = self.mask( tt_h[start:start+size], dir_tt_h[start:start+size] )
            
            w = w.mean(axis=1)
            w = F.softmax(w)
            
            weights.append( w )
            
            ### Document model
            #term_emb = dt_h[start:start+size].view(size,1,self.hidden_l)
            #dir_term_emb = dir_dt_h[start:start+size].view(doc_size,1,h_dims)

            #attn_output, attn_w = self.ma_term(term_emb, dir_term_emb, term_emb, need_weights=False)
            #attn_output = attn_output.view(doc_size, h_dims)
            
            #term_emb = term_emb.view(size, self.hidden_l)
            
            #docs.append( term_emb  )
        
        weights       = torch.cat(weights)
        """
        docs          = torch.cat(docs)
        new_terms_idx = torch.LongTensor(range(n)).to(terms_idxs.device)
        
        h_docs  = F.embedding_bag(docs, new_terms_idx, docs_offsets,
                                  per_sample_weights=weights, mode='sum')
        """
        h_docs  = F.embedding_bag(self.dt_emb.weight, terms_idxs, docs_offsets,
                                  per_sample_weights=weights, mode='sum')
        
        h_docs = self.drop( h_docs )
        pred_docs = self.fc( h_docs )
        return pred_docs

    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        #self.dt_dir_map.weight.data.uniform_(-self.initrange, self.initrange)
        
        #self.ma_term.in_proj_weight.data.uniform_(-self.initrange, self.initrange)
        
        self.tt_dir_map.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.fc.weight.data.uniform_(-self.initrange, self.initrange)
        
    def _get_shift_(self, offsets, lenght):
        shifts = offsets[1:] - offsets[:-1]
        last = torch.LongTensor([lenght - offsets[-1]]).to( offsets.device )
        return torch.cat([shifts, last])

In [13]:
nepochs = 25
max_epochs = 10
drop=0.75
device = torch.device('cuda:0')
batch_size = 128

In [14]:
#sc = SimpleClassifier(tokenizer.vocab_size, 300, tokenizer.n_class, dropout=drop).to( device )
sc = AttClassifier(tokenizer.vocab_size, 300, tokenizer.n_class, n_heads=15, dropout=drop).to( device )

optimizer = optim.AdamW( sc.parameters(), lr=7e-3, weight_decay=5e-3)
loss_func_cel = nn.CrossEntropyLoss().to( device )
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=.9,
#                                                       patience=1, verbose=True)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=.95, verbose=True)

Adjusting learning rate of group 0 to 7.0000e-03.


In [15]:
best = 99999.
counter = 1
dl_val = DataLoader(list(zip(fold.X_val, y_val)), batch_size=len(y_val),
                         shuffle=False, collate_fn=collate_train, num_workers=4)
for e in tqdm(range(nepochs), total=nepochs):
    dl_train = DataLoader(list(zip(fold.X_train, y_train)), batch_size=batch_size,
                             shuffle=True, collate_fn=collate_train, num_workers=4)
    total_loss  = 0.
    with tqdm(total=len(y_train)+len(y_val), smoothing=0., desc=f"Epoch {e+1}") as pbar:
        total = 0
        correct  = 0
        sc.train()
        for i, (terms_idx, docs_offsets, y) in enumerate(dl_train):
            terms_idx    = terms_idx.to( device )
            docs_offsets = docs_offsets.to( device )
            y            = y.to( device )
            
            pred_docs = sc( terms_idx, docs_offsets)
            pred_docs = F.softmax(pred_docs)
            loss = loss_func_cel(pred_docs, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total      += len(y)
            y_pred      = pred_docs.argmax(axis=1)
            correct    += (y_pred == y).sum().item()
            
            toprint  = f"Train loss: {total_loss/(i+1):.5}/{loss.item():.5} "
            toprint += f'ACC: {correct/total:.5}'
            
            print(toprint, end=f"{' '*100}\r")
            
            pbar.update( len(y) )
            if e < (nepochs-1):
                del pred_docs, loss
            del terms_idx, docs_offsets, y
            del y_pred
            
        total = 0
        correct  = 0
        sc.eval()
        with torch.no_grad():
            print()
            for i, (terms_idx, docs_offsets, y) in enumerate(dl_val):
                terms_idx    = terms_idx.to( device )
                docs_offsets = docs_offsets.to( device )
                y            = y.to( device )

                pred_docs = sc( terms_idx, docs_offsets )
                pred_docs = F.softmax(pred_docs)

                y_pred      = pred_docs.argmax(axis=1)
                correct    += (y_pred == y).sum().item()
                total      += len(y)
                loss2 = loss_func_cel(pred_docs, y)

                print(f'Val loss: {loss2.item():.5} ACC: {correct/total:.5}', end=f"{' '*100}\r")

                pbar.update( len(y) )

            del terms_idx, docs_offsets, y
            del y_pred
            #scheduler.step(loss2)
            print()
            scheduler.step()

            if best-loss2.item() > 0.0001 :
                best = loss2.item()
                counter = 1
                print()
                print(f'New Best Val loss: {best:.5}', end=f"{' '*100}\n")

                dl_test = DataLoader(list(zip(fold.X_test, y_test)), batch_size=len(y_test),
                                         shuffle=False, collate_fn=collate_train, num_workers=2)
                for i, (terms_idx_t, docs_offsets_t, y_t) in enumerate(dl_test):
                    terms_idx_t    = terms_idx_t.to( device )
                    docs_offsets_t = docs_offsets_t.to( device )
                    y_t            = y_t.to( device )

                    pred_docs_t = sc( terms_idx_t, docs_offsets_t )
                    pred_docs_t = F.softmax(pred_docs_t)

                    y_pred_t      = pred_docs_t.argmax(axis=1)
                    correct_t     = (y_pred_t == y_t).sum().item()
                    total_t       = len(y_t)
                    loss2_t = loss_func_cel(pred_docs_t, y_t)

                    print(f'Test loss: {loss2_t.item():.5} ACC: {correct_t/total_t:.5}', end=f"{' '*100}\r")
            elif counter > max_epochs:
                print()
                print(f'Best Val loss: {best:.5}', end=f"{' '*100}\n")
                break
            else:
                counter += 1
            del pred_docs, loss2

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 2.1113/1.8681 ACC: 0.47009                                                                                                    
Val loss: 1.9093 ACC: 0.68337                                                                                                    
Adjusting learning rate of group 0 to 7.0000e-03.

New Best Val loss: 1.9093                                                                                                    
Test loss: 1.9007 ACC: 0.68818                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.8375/1.8142 ACC: 0.7422                                                                                                     
Val loss: 1.8303 ACC: 0.73226                                                                                                    
Adjusting learning rate of group 0 to 6.6500e-03.

New Best Val loss: 1.8303                                                                                                    
Test loss: 1.8222 ACC: 0.73828                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.7465/1.7068 ACC: 0.81876                                                                                                    
Val loss: 1.803 ACC: 0.76232                                                                                                    
Adjusting learning rate of group 0 to 6.6500e-03.

New Best Val loss: 1.803                                                                                                    
Test loss: 1.7912 ACC: 0.76834                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.6964/1.6495 ACC: 0.86507                                                                                                    
Val loss: 1.7912 ACC: 0.77034                                                                                                    
Adjusting learning rate of group 0 to 6.3175e-03.

New Best Val loss: 1.7912                                                                                                    
Test loss: 1.7805 ACC: 0.77876                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.6619/1.6095 ACC: 0.89496                                                                                                    
Val loss: 1.7869 ACC: 0.76834                                                                                                    
Adjusting learning rate of group 0 to 6.3175e-03.

New Best Val loss: 1.7869                                                                                                    
Test loss: 1.7764 ACC: 0.77715                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.6396/1.5773 ACC: 0.91435                                                                                                    
Val loss: 1.7853 ACC: 0.76713                                                                                                    
Adjusting learning rate of group 0 to 6.0016e-03.

New Best Val loss: 1.7853                                                                                                    
Test loss: 1.7741 ACC: 0.77796                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.6272/1.6646 ACC: 0.92324                                                                                                    
Val loss: 1.7828 ACC: 0.77114                                                                                                    
Adjusting learning rate of group 0 to 6.0016e-03.

New Best Val loss: 1.7828                                                                                                    
Test loss: 1.7723 ACC: 0.78236                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.6157/1.6278 ACC: 0.93409                                                                                                    
Val loss: 1.786 ACC: 0.76513                                                                                                    
Adjusting learning rate of group 0 to 5.7015e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.6076/1.585 ACC: 0.94103                                                                                                     
Val loss: 1.7849 ACC: 0.76874                                                                                                    
Adjusting learning rate of group 0 to 5.7015e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.6008/1.5781 ACC: 0.94685                                                                                                    
Val loss: 1.7849 ACC: 0.76473                                                                                                    
Adjusting learning rate of group 0 to 5.4165e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 11', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5961/1.6366 ACC: 0.95147                                                                                                    
Val loss: 1.7848 ACC: 0.76633                                                                                                    
Adjusting learning rate of group 0 to 5.4165e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 12', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5928/1.566 ACC: 0.95358                                                                                                     
Val loss: 1.7833 ACC: 0.76713                                                                                                    
Adjusting learning rate of group 0 to 5.1456e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 13', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5871/1.5943 ACC: 0.95901                                                                                                    
Val loss: 1.7818 ACC: 0.76994                                                                                                    
Adjusting learning rate of group 0 to 5.1456e-03.

New Best Val loss: 1.7818                                                                                                    
Test loss: 1.7745 ACC: 0.77756                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 14', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5843/1.551 ACC: 0.96142                                                                                                     
Val loss: 1.7842 ACC: 0.76633                                                                                                    
Adjusting learning rate of group 0 to 4.8884e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 15', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5838/1.6193 ACC: 0.96238                                                                                                    
Val loss: 1.7853 ACC: 0.76633                                                                                                    
Adjusting learning rate of group 0 to 4.8884e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 16', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5798/1.5597 ACC: 0.96599                                                                                                    
Val loss: 1.7844 ACC: 0.76513                                                                                                    
Adjusting learning rate of group 0 to 4.6439e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 17', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5785/1.5602 ACC: 0.9667                                                                                                     
Val loss: 1.7854 ACC: 0.76433                                                                                                    
Adjusting learning rate of group 0 to 4.6439e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 18', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5759/1.5734 ACC: 0.96946                                                                                                    
Val loss: 1.7828 ACC: 0.76593                                                                                                    
Adjusting learning rate of group 0 to 4.4117e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 19', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5744/1.5684 ACC: 0.97061                                                                                                    
Val loss: 1.7808 ACC: 0.76754                                                                                                    
Adjusting learning rate of group 0 to 4.4117e-03.

New Best Val loss: 1.7808                                                                                                    
Test loss: 1.774 ACC: 0.77675                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 20', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5736/1.5738 ACC: 0.97137                                                                                                    
Val loss: 1.7813 ACC: 0.76433                                                                                                    
Adjusting learning rate of group 0 to 4.1912e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 21', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.572/1.6136 ACC: 0.97313                                                                                                     
Val loss: 1.7816 ACC: 0.76553                                                                                                    
Adjusting learning rate of group 0 to 4.1912e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 22', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5704/1.5581 ACC: 0.97448                                                                                                    
Val loss: 1.7825 ACC: 0.76513                                                                                                    
Adjusting learning rate of group 0 to 3.9816e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 23', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5697/1.5738 ACC: 0.97493                                                                                                    
Val loss: 1.7805 ACC: 0.76754                                                                                                    
Adjusting learning rate of group 0 to 3.9816e-03.

New Best Val loss: 1.7805                                                                                                    
Test loss: 1.7748 ACC: 0.77475                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 24', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5692/1.5584 ACC: 0.97539                                                                                                    
Val loss: 1.7806 ACC: 0.76834                                                                                                    
Adjusting learning rate of group 0 to 3.7825e-03.



HBox(children=(FloatProgress(value=0.0, description='Epoch 25', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.5681/1.544 ACC: 0.97624                                                                                                     
Val loss: 1.7788 ACC: 0.76874                                                                                                    
Adjusting learning rate of group 0 to 3.7825e-03.

New Best Val loss: 1.7788                                                                                                    
Test loss: 1.7756 ACC: 0.77475                                                                                                    



In [16]:
y_pred_t      = pred_docs_t.argmax(axis=1)
correct_t     = (y_pred_t == y_t).sum().item()
total_t       = len(y_t)
correct_t/total_t

0.774749498997996

In [17]:
docs_offsets_t

tensor([    0,    14,    44,  ..., 88569, 88575, 88579], device='cuda:0')

In [18]:
terms_idx_t[:docs_offsets_t[batch_size]]

tensor([1946, 1959, 1963,  ...,  525, 2726, 4074], device='cuda:0')

In [19]:
batch_off_test  = docs_offsets_t[batch_size]
batch_tidx_test = terms_idx_t[:batch_off_test]
h_terms_test    = sc.tt_emb( batch_tidx_test )
dirh_terms_test = sc.tt_dir_map( h_terms_test )

W = torch.matmul( h_terms_test, dirh_terms_test.T )
W = F.leaky_relu( W, negative_slope=sc.mask.negative_slope)
W = F.sigmoid(W)
W

tensor([[1.0000, 0.5159, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.9715, 0.0000,  ..., 0.9953, 1.0000, 1.0000],
        [0.9999, 1.0000, 0.0000,  ..., 0.6067, 0.9835, 1.0000],
        ...,
        [0.0000, 0.0000, 0.9590,  ..., 1.0000, 0.0000, 0.0000],
        [0.9971, 0.0000, 0.0000,  ..., 1.0000, 1.0000, 0.0000],
        [0.9999, 0.0000, 0.6372,  ..., 1.0000, 0.0000, 0.9284]],
       device='cuda:0', grad_fn=<SigmoidBackward>)

In [26]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

In [31]:
x_packed = pad_sequence(batch_tidx_test, batch_first=True, padding_value=0)

IndexError: dimension specified as 0 but tensor has no dimensions

In [30]:
batch_tidx_test

tensor([1946, 1959, 1963,  ...,  525, 2726, 4074], device='cuda:0')

In [None]:
"""
acm ####################################################################################
Train loss: 1.6009/1.6089 ACC: 0.94886                                                                                                    
Val loss: 1.7718 ACC: 0.77475                                                                                                    
New Best Val loss: 1.7718                                                                                                    
Test loss: 1.7678 ACC: 0.78557  79.92

20ng ####################################################################################
Train loss: 2.0907/2.0787 ACC: 0.98845                                                                                                    
Val loss: 2.1869 ACC: 0.90803                                                                                                    
New Best Val loss: 2.1869                                                                                                    
Test loss: 2.178 ACC: 0.91068   92.65

reut ####################################################################################
Train loss: 3.7735/3.5191 ACC: 0.74734                                                                                                    
Val loss: 3.8554 ACC: 0.6763                                                                                                    
New Best Val loss: 3.8554                                                                                                    
Test loss: 3.8493 ACC: 0.6837  72.67

webkb ####################################################################################
Train loss: 1.2228/1.2037 ACC: 0.9504                                                                                                     
Val loss: 1.3787 ACC: 0.80316                                                                                                    
New Best Val loss: 1.3787                                                                                                    
Test loss: 1.3857 ACC: 0.78858   81.53

"""

In [None]:
a = [terms_idx, terms_idx]
torch.stack(a)

In [None]:
F.softmax(pred_docs_t).argmax(axis=1)

In [None]:
y_val

In [None]:
(y_t == F.softmax(pred_docs_t).argmax(axis=1)).sum().item()/y_t.shape[0]

In [None]:
terms_idx_t, docs_offsets_t

In [None]:
shifts = sc._get_shift_(docs_offsets_t, terms_idx_t.shape[0])

In [None]:
zipado = zip(docs_offsets_t, shifts)
#next(zipado)
start,size = next(zipado)

In [None]:
w = sc.tt_emb( terms_idx_t[start:start+size] )
w1 = sc.tt_dir_map( w )
w = torch.matmul( w, w1.T )
w = F.leaky_relu( w, negative_slope=sc.negative_slope)
w = F.sigmoid(w)
#w = F.tanh(w)
#w = F.relu(w)
#w = w.mean(axis=1)
#w = F.softmax(w)
w,w1

In [None]:
w.round().sum() / (w.shape[0]*w.shape[1])

In [None]:
inv_mapper = { v:k for (k,v) in tokenizer.node_mapper.items() }

In [None]:
terms_idx[start:start+size]

In [None]:
fold.X_val[0]

In [None]:
bla = w.mean(axis=1)
bla = F.softmax(bla)
#bla = bla/torch.clamp(bla.sum(), 0.0001)
bla

In [None]:
[ (i, tid.item(),inv_mapper[tid.item()], wei.item()) for i, (tid, wei) in enumerate(zip(terms_idx[start:start+size], bla)) ]

In [None]:
w

In [None]:
w.mean(axis=1)

In [None]:
w.shape

In [None]:
norm = nn.BatchNorm1d(num_features=1).to(device)

bla2 = norm(bla.view(-1, 1)).squeeze()
bla2 = F.sigmoid(bla2)
bla2

In [None]:
1

In [None]:
ma = nn.MultiheadAttention(300, 300).to(device)
ma

In [None]:
torch.__version__

In [None]:
a,b = w.shape
w_ = w.view(a,1,b)
w1_ = w1.view(a,1,b)

attn_output = ma(w_, w1_, w_, need_weights=False)

attn_output.view(a,b)
attn_output_weights.view(a,a)

In [None]:
attn_output.view(a,b).shape

In [None]:
attn_output_weights.view(a,a)

In [None]:
F.softmax(torch.Tensor([[0.0718, 0.0716, 0.0712, 0.0714, 0.0721, 0.0710, 0.0712, 0.0714, 0.0710,
         0.0719, 0.0711, 0.0712, 0.0718, 0.0714],
        [0.0722, 0.0709, 0.0710, 0.0709, 0.0728, 0.0723, 0.0709, 0.0709, 0.0719,
         0.0712, 0.0709, 0.0707, 0.0725, 0.0707],
        [0.0711, 0.0715, 0.0710, 0.0713, 0.0710, 0.0712, 0.0718, 0.0713, 0.0709,
         0.0725, 0.0716, 0.0719, 0.0710, 0.0719],
        [0.0721, 0.0717, 0.0714, 0.0713, 0.0717, 0.0714, 0.0711, 0.0710, 0.0713,
         0.0717, 0.0710, 0.0712, 0.0720, 0.0711],
        [0.0722, 0.0711, 0.0710, 0.0710, 0.0737, 0.0716, 0.0709, 0.0705, 0.0714,
         0.0719, 0.0707, 0.0707, 0.0731, 0.0702],
        [0.0714, 0.0716, 0.0708, 0.0711, 0.0718, 0.0713, 0.0712, 0.0717, 0.0712,
         0.0724, 0.0713, 0.0712, 0.0716, 0.0714],
        [0.0712, 0.0714, 0.0713, 0.0713, 0.0722, 0.0716, 0.0709, 0.0714, 0.0714,
         0.0717, 0.0713, 0.0713, 0.0716, 0.0713],
        [0.0716, 0.0707, 0.0712, 0.0714, 0.0717, 0.0715, 0.0714, 0.0715, 0.0712,
         0.0721, 0.0711, 0.0714, 0.0717, 0.0715],
        [0.0717, 0.0713, 0.0708, 0.0710, 0.0725, 0.0717, 0.0714, 0.0709, 0.0712,
         0.0712, 0.0712, 0.0717, 0.0720, 0.0714],
        [0.0709, 0.0712, 0.0713, 0.0712, 0.0713, 0.0713, 0.0714, 0.0719, 0.0712,
         0.0724, 0.0715, 0.0715, 0.0717, 0.0713],
        [0.0715, 0.0716, 0.0712, 0.0708, 0.0725, 0.0717, 0.0712, 0.0714, 0.0714,
         0.0717, 0.0713, 0.0706, 0.0718, 0.0712],
        [0.0726, 0.0711, 0.0713, 0.0706, 0.0737, 0.0721, 0.0709, 0.0707, 0.0714,
         0.0708, 0.0706, 0.0705, 0.0730, 0.0707],
        [0.0718, 0.0711, 0.0714, 0.0715, 0.0725, 0.0707, 0.0707, 0.0711, 0.0712,
         0.0728, 0.0711, 0.0713, 0.0719, 0.0709],
        [0.0712, 0.0716, 0.0713, 0.0712, 0.0717, 0.0707, 0.0708, 0.0721, 0.0709,
         0.0728, 0.0713, 0.0716, 0.0711, 0.0715]]).sum(axis=0))

In [None]:
class NotTooSimpleClassifier(nn.Module):
    def __init__(self, vocab_size, hidden_l, nclass, dropout1=0.1, dropout2=0.1, negative_slope=99,
                 initrange = 0.5, scale_grad_by_freq=False, device='cuda:0'):
        super(NotTooSimpleClassifier, self).__init__()
        
        self.dt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        self.tt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        
        self.undirected_map = nn.Linear(hidden_l, hidden_l)
        
        self.fc = nn.Linear(hidden_l, nclass)
        self.drop1 = nn.Dropout(dropout1)
        self.drop2 = nn.Dropout(dropout2)
        
        self.norm = nn.BatchNorm1d(1)
        
        self.initrange = initrange
        self.nclass = nclass
        self.negative_slope = negative_slope
        
        self.init_weights()
        
        #self.labls_emb = nn.Embedding(graph_builder.n_class, 300)
    
    def forward(self, terms_idxs, docs_offsets):
        n = terms_idxs.shape[0]
        weights = []
        shifts = self._get_shift_(docs_offsets, n)
        
        terms_h1 = self.tt_emb(terms_idxs)
        terms_h1 = self.drop1(terms_h1)
        
        terms_h2 = self.undirected_map( terms_h1 )
        #terms_h2 = self.drop1( terms_h2 )
        for start,size in zip(docs_offsets, shifts):
            w  = terms_h1[start:start+size]
            w1 = terms_h2[start:start+size]
            w = torch.matmul( w, w1.T )
            w = F.leaky_relu( w, negative_slope=self.negative_slope)
            w = F.sigmoid(w-5.5)
            w = w.mean(axis=1)
            w = F.softmax(w)
            #w = w / torch.clamp(w.sum(), 0.0001)
            weights.append( w )
        
        weights = torch.cat(weights)
        #weights = self.norm(weights.view(-1, 1)).squeeze()
        #weights = F.sigmoid(weights)
        
        h_docs  = F.embedding_bag(self.dt_emb.weight, terms_idxs, docs_offsets, per_sample_weights=weights, mode='sum')
        h_docs = self.drop2( h_docs )
        pred_docs = self.fc( h_docs )
        return pred_docs

    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        
    def _get_shift_(self, offsets, lenght):
        shifts = offsets[1:] - offsets[:-1]
        last = torch.LongTensor([lenght - offsets[-1]]).to( offsets.device )
        return torch.cat([shifts, last])