In [1]:
import warnings
warnings.filterwarnings('ignore')

from TGA.utils import Dataset

from tqdm.notebook import tqdm
from TGA.utils import preprocessor

from time import time
import numpy as np
from itertools import repeat
from collections import Counter

from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction import stop_words
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.base import BaseEstimator, TransformerMixin

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

In [3]:
dataset = Dataset('/home/Documentos/datasets/classification/datasets/acm/')
fold = next(dataset.get_fold_instances(10, with_val=True))
fold._fields, len(fold.X_train)

(('X_train', 'y_train', 'X_test', 'y_test', 'X_val', 'y_val'), 19907)

In [4]:
class Tokenizer(BaseEstimator, TransformerMixin):
    def __init__(self, mindf=2, stopwords='remove', lan='english', verbose=False):
        super(Tokenizer, self).__init__()
        self.mindf = mindf
        self.le = LabelEncoder()
        self.verbose = verbose
        self.stopwords = stopwords
        self.stopwordsSet = stop_words.ENGLISH_STOP_WORDS
        self.lan = lan
        self.analyzer = TfidfVectorizer(preprocessor=preprocessor)
    
    def fit(self, X, y):
        self.N = len(X)
        self.le.fit( y )
        self.n_class = len(self.le.classes_)

        self.term_freqs = Counter()
        analyzer = self.analyzer.build_analyzer()
        docs = map(analyzer, X)
        for doc_in_terms in tqdm(docs, total=self.N, disable=not self.verbose):
            self.term_freqs.update(list(set(doc_in_terms)))
        self.node_mapper      = {}
        self.term_freqs       = { term:v for (term,v) in self.term_freqs.items() if v >= self.mindf }    
        self.node_mapper      = { term:self._get_idx_(term) for term in self.term_freqs.keys() if self._isrel_(term) }
        self.node_mapper['<UNK>'] = len(self.node_mapper)
        self.vocab_size = len(self.node_mapper)
        
        return self
    def _isrel_(self, term):
        if self.stopwords == 'remove' and term in self.stopwordsSet:
            return False
        # put here your filter_functions
        return True
    def _get_idx_(self, term):
        # put here your idx_set_functions
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return self.node_mapper.setdefault('<STPW>', len(self.node_mapper))
        return self.node_mapper.setdefault(term, len(self.node_mapper))
    def _filter_transform_(self, term):
        if self.stopwords == 'mark' and term in self.stopwordsSet:
            return '<STPW>'
        return term
    def transform(self, X, verbose=None):
        verbose = verbose if verbose is not None else self.verbose
        n = len(X)
        analyzer = self.analyzer.build_analyzer()
        doc_off = [0]
        terms_idx = []
        for i,doc_in_terms in tqdm(enumerate(map(analyzer, X)), total=n, disable=not verbose):
            doc_in_terms = set(map( self._filter_transform_, doc_in_terms ))
            doc_in_terms = filter( lambda x: x in self.node_mapper, doc_in_terms )
            doc_in_terms = [ self.node_mapper[tid] for tid in doc_in_terms ]
            doc_off.append( len(doc_in_terms) )
            terms_idx.extend( doc_in_terms )
        return np.array( terms_idx ), np.array(doc_off)[:-1].cumsum()

In [5]:
tokenizer = Tokenizer(stopwords='mark', verbose=True)
tokenizer.fit(fold.X_train, fold.y_train)

HBox(children=(FloatProgress(value=0.0, max=19907.0), HTML(value='')))




Tokenizer(stopwords='mark', verbose=True)

In [6]:
y_train = tokenizer.le.transform( fold.y_train )
y_val   = tokenizer.le.transform( fold.y_val )

In [7]:
def collate_train(param):
    X, y = zip(*param)
    terms_ids, docs_offsets = tokenizer.transform(X, verbose=False)
    return torch.LongTensor(terms_ids), torch.LongTensor(docs_offsets), torch.LongTensor(y)

In [12]:
class SimpleClassifier(nn.Module):
    def __init__(self, vocab_size, hidden_l, nclass, dropout=0.1, initrange = 0.5, device='cuda:0'):
        super(SimpleClassifier, self).__init__()
        
        self.doc_terms_emb = nn.EmbeddingBag(vocab_size, hidden_l, mode='mean', scale_grad_by_freq=False)
        
        self.fc = nn.Linear(hidden_l, nclass)
        self.drop = nn.Dropout(dropout)
        
        self.initrange = initrange
        self.nclass = nclass
        
        self.init_weights()
        
        #self.labls_emb = nn.Embedding(graph_builder.n_class, 300)
    
    def forward(self, terms_idxs, docs_offsets):
        h_docs = self.doc_terms_emb( terms_idxs, docs_offsets )
        h_docs = self.drop( h_docs )
        pred_docs = self.fc( h_docs )
        return pred_docs

    def init_weights(self):
        self.doc_terms_emb.weight.data.uniform_(-self.initrange, self.initrange)
        
class NotTooSimpleClassifier(nn.Module):
    def __init__(self, vocab_size, hidden_l, nclass, dropout=0.1, negative_slope=99, mode='mean', initrange = 0.5, scale_grad_by_freq=False, device='cuda:0'):
        super(NotTooSimpleClassifier, self).__init__()
        
        self.dt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        self.tt_emb = nn.Embedding(vocab_size, hidden_l, scale_grad_by_freq=scale_grad_by_freq)
        
        self.fc = nn.Linear(hidden_l, nclass)
        self.drop = nn.Dropout(dropout)
        
        self.initrange = initrange
        self.nclass = nclass
        self.mode = mode
        self.negative_slope = negative_slope
        
        self.init_weights()
        
        #self.labls_emb = nn.Embedding(graph_builder.n_class, 300)
    
    def forward(self, terms_idxs, docs_offsets):
        n = terms_idxs.shape[0]
        weights = []
        shifts = self._get_shift_(docs_offsets, n)
        for start,size in zip(docs_offsets, shifts):
            w = self.tt_emb( terms_idx[start:start+size] )
            w = self.drop( w )
            w = torch.matmul( w, w.T )
            w = F.leaky_relu( w, negative_slope=self.negative_slope)
            w = F.sigmoid(w)
            w = F.softmax(w.mean(axis=1))
            weights.append( w )
        
        weights = torch.cat(weights)
        
        h_docs  = F.embedding_bag(self.dt_emb.weight, terms_idxs, docs_offsets, per_sample_weights=weights, mode='sum')
        h_docs = self.drop( h_docs )
        pred_docs = self.fc( h_docs )
        return pred_docs

    def init_weights(self):
        self.dt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        self.tt_emb.weight.data.uniform_(-self.initrange, self.initrange)
        
    def _get_shift_(self, offsets, lenght):
        shifts = offsets[1:] - offsets[:-1]
        last = torch.LongTensor([lenght - offsets[-1]]).to( offsets.device )
        return torch.cat([shifts, last])

In [13]:
nepochs = 2000
max_epochs = 10
drop=0.5
device = torch.device('cuda:0')
batch_size = 16

In [14]:
#sc = SimpleClassifier(tokenizer.vocab_size, 300, tokenizer.n_class, dropout=drop).to( device )
sc = NotTooSimpleClassifier(tokenizer.vocab_size, 300, tokenizer.n_class, negative_slope=1000, dropout=drop).to( device )

optimizer = optim.AdamW( sc.parameters(), lr=5e-3, weight_decay=5e-3)
loss_func_cel = nn.CrossEntropyLoss().to( device )
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

In [15]:
best = 0.
counter = 1
for e in tqdm(range(nepochs), total=nepochs):
    dl_train = DataLoader(list(zip(fold.X_train, y_train)), batch_size=batch_size,
                             shuffle=True, collate_fn=collate_train, num_workers=5)
    dl_val = DataLoader(list(zip(fold.X_val, y_val)), batch_size=batch_size,
                             shuffle=False, collate_fn=collate_train, num_workers=5)
    total_loss  = 0.
    with tqdm(total=len(y_train)+len(y_val), smoothing=0., desc=f"Epoch {e+1}") as pbar:
        total = 0
        correct  = 0
        sc.train()
        for i, (terms_idx, docs_offsets, y) in enumerate(dl_train):
            terms_idx    = terms_idx.to( device )
            docs_offsets = docs_offsets.to( device )
            y            = y.to( device )
            
            pred_docs = sc( terms_idx, docs_offsets)
            pred_docs = F.softmax(pred_docs)
            loss = loss_func_cel(pred_docs, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total      += len(y)
            y_pred      = pred_docs.argmax(axis=1)
            correct    += (y_pred == y).sum().item()
            
            toprint  = f"Train loss: {total_loss/(i+1):.4}/{loss.item():.4} "
            toprint += f'ACC: {correct/total:.4}'
            
            print(toprint, end=f"{' '*100}\r")
            
            pbar.update( len(y) )
            
        scheduler.step()
        total = 0
        correct  = 0
        sc.eval()
        print()
        for i, (terms_idx, docs_offsets, y) in enumerate(dl_val):
            terms_idx    = terms_idx.to( device )
            docs_offsets = docs_offsets.to( device )
            y            = y.to( device )
            
            pred_docs = sc( terms_idx, docs_offsets )
            pred_docs = F.softmax(pred_docs)
            
            y_pred      = pred_docs.argmax(axis=1)
            correct    += (y_pred == y).sum().item()
            total      += len(y)
            
            print(f'Val ACC: {correct/total:.4}/{best:.4}', end=f"{' '*100}\r")
            
            pbar.update( len(y) )
        if (correct/total) > best:
            best = (correct/total)
            counter = 1
            print(f'New Best Val ACC: {best:.4}')
        elif counter > max_epochs:
            print()
            print(f'Best Val ACC: {best:.4}')
            break
        else:
            counter += 1

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.929/1.875 ACC: 0.6486                                                                                                     
New Best Val ACC: 0.7523                                                                                               



HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.711/1.543 ACC: 0.8406                                                                                                    
New Best Val ACC: 0.7619                                                                                                  



HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.659/1.566 ACC: 0.8881                                                                                                    
New Best Val ACC: 0.7627                                                                                                  



HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.635/1.545 ACC: 0.9118                                                                                                    
Val ACC: 0.7527/0.7627                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.621/1.548 ACC: 0.9237                                                                                                    
Val ACC: 0.7579/0.7627                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.611/1.543 ACC: 0.9332                                                                                                    
Val ACC: 0.7559/0.7627                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.604/1.545 ACC: 0.9404                                                                                                    
Val ACC: 0.7547/0.7627                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.598/1.543 ACC: 0.9456                                                                                                    
Val ACC: 0.7543/0.7627                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.593/1.543 ACC: 0.951                                                                                                     
Val ACC: 0.7571/0.7627                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.589/1.549 ACC: 0.9545                                                                                                    
Val ACC: 0.7535/0.7627                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 11', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.587/1.543 ACC: 0.9568                                                                                                    
Val ACC: 0.7555/0.7627                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 12', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.585/1.543 ACC: 0.9588                                                                                                    
Val ACC: 0.7587/0.7627                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 13', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.583/1.543 ACC: 0.9601                                                                                                    
Val ACC: 0.7583/0.7627                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 14', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.582/1.543 ACC: 0.9616                                                                                                    
Val ACC: 0.7583/0.7627                                                                                                    
Best Val ACC: 0.7627




In [12]:
# acm stopword=mark
# Best Val ACC: 0.7784


# acm stopword=remove
# Best Val ACC: 0.7776

# acm stopword=keep
# Best Val ACC: 0.7675


In [13]:
terms_idx.shape, docs_offsets.shape

(torch.Size([627]), torch.Size([7]))

In [14]:
W = sc.tt_emb.weight
W.shape

torch.Size([22643, 300])

In [15]:
W = torch.matmul(w, w.T)
#W = F.leaky_relu(W, negative_slope=10)

NameError: name 'w' is not defined

In [None]:
W.shape

In [None]:
terms_idx[docs_offsets[i]:docs_offsets[i]+shifts[i]]

In [None]:
sum(shifts)

In [None]:
i=-1
h_terms = emb( terms_idx[docs_offsets[i]:docs_offsets[i]+shifts[i]] )
h_terms.shape

In [None]:
docs_offsets[i],docs_offsets[i]+shifts[i], shifts[i]

In [None]:
terms_idx.shape

In [None]:
self_coor = torch.matmul( h_terms, h_terms.T )
self_coor = F.leaky_relu(self_coor, negative_slope=100000)
self_coor = F.sigmoid(self_coor)
self_coor = F.softmax(self_coor.mean(axis=1))
self_coor

In [None]:
F.embedding_bag(emb.weight, terms_idx[], docs_offsets, mode='mean').shape

In [None]:
torch.stack([ torch.LongTensor([1,2,3]), torch.LongTensor([4,5,6]) ])