In [1]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.base import BaseEstimator, TransformerMixin
import networkx as nx
from itertools import repeat

from collections import Counter
from TGA.utils import preprocessor

from sklearn.preprocessing import LabelEncoder
import warnings
warnings.filterwarnings('ignore')

from TGA.utils import Dataset
#from tqdm import tqdm
from tqdm.notebook import tqdm
from time import time
import numpy as np

In [2]:
dataset = Dataset('/home/Documentos/datasets/classification/datasets/acm/')
fold = next(dataset.get_fold_instances(10, with_val=True))
fold._fields, len(fold.X_train)

(('X_train', 'y_train', 'X_test', 'y_test', 'X_val', 'y_val'), 19907)

In [3]:
class Graphsize(BaseEstimator, TransformerMixin):
    def __init__(self, mindf=2, w=2, stopwords='remove', encoding='utf-8', verbose=False):
        super(Graphsize, self).__init__()
        self.mindf = mindf
        self.w = w
        self.encoding = encoding
        self.le = LabelEncoder()
        if not verbose:
            self.progress_bar = lambda x: x
        else:
            from tqdm import tqdm
            self.progress_bar = tqdm
            
        self.analyzer = TfidfVectorizer(preprocessor=preprocessor)
    
    def fit(self, X, y):
        self.N = len(X)
        y_train = self.le.fit_transform( y )
        self.n_class = len(self.le.classes_)

        self.term_freqs = Counter()
        docs = map(self.analyzer.build_analyzer(), X)
        for doc_in_terms in self.progress_bar(docs, total=self.N):
            self.term_freqs.update(list(set(doc_in_terms)))
        self.node_mapper      = {}
        self.term_freqs       = { term:v for (term,v) in self.term_freqs.items() if v >= self.mindf }    
        self.node_mapper      = { term:self.node_mapper.setdefault(term, len(self.node_mapper)) for term in self.term_freqs.keys() }
        self.node_mapper['<UNK>'] = len(self.node_mapper)
        self.vocab_size = len(self.node_mapper)
        
        return self
   
    def transform(self, text):
        analy = self.analyzer.build_analyzer()
        n = len(text)
        docs = map(analy, text)
        result = list(map(self._build_graph_, self.progress_bar(docs, total=n)))
        return result
    
    def _build_graph_(self, doc):
        terms        = [ term if term in self.node_mapper else '<UNK>' for term in doc ]
        local_mapper = { self.node_mapper[word]:word for word in set(terms) }
        terms_nids   = [ self.node_mapper[word] for word in terms ]
    

        cooccur_count = Counter()
        for i,nid in enumerate(terms_nids):
            terms_to_add = terms_nids[ max(i-self.w, 0):(i+1) ]
            terms_to_add = list(zip(terms_to_add, repeat(nid)))
            terms_to_add = list(map(sorted,terms_to_add))
            terms_to_add = list(map(tuple,terms_to_add))
            cooccur_count.update( terms_to_add )
        
        G = nx.Graph()
        G.add_nodes_from( [ (nid,{'term': word,'idx':nid}) for (nid,word) in local_mapper.items() ] )
        w_edges = [ (s,t) for ((s,t),w) in cooccur_count.items() ]
        G.add_edges_from( w_edges )
        return G

In [4]:
graphisize = Graphsize(w=5, mindf=2, verbose=True)

In [5]:
graphisize.fit(fold.X_train, fold.y_train)
y_train = graphisize.le.transform(fold.y_train)
y_val = graphisize.le.transform(fold.y_val)

100%|██████████| 19907/19907 [00:05<00:00, 3565.47it/s]


In [6]:
gs_train = graphisize.transform( fold.X_train )
gs_val   = graphisize.transform( fold.X_val )

100%|██████████| 19907/19907 [00:22<00:00, 871.09it/s] 
100%|██████████| 2495/2495 [00:02<00:00, 926.67it/s] 


In [7]:
import torch
import dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F

import networkx as nx
from dgl.nn.pytorch.conv import GraphConv, GATConv
from dgl.nn.pytorch.glob import GlobalAttentionPooling, AvgPooling

from sklearn.preprocessing import LabelEncoder

from itertools import repeat

import torch.optim as optim
from torch.utils.data import DataLoader

Using backend: pytorch


In [8]:
class GenericGAT(nn.Module):
    def __init__(self, vocab_size, hidden_dim, n_class,
                 drop=.5, n_heads=16, attn_drop=.5,
                 activation=F.leaky_relu, n_convs=1, device='cpu:0'):
        super(GenericGAT, self).__init__()
        self.n_hiddens = hidden_dim
        self.device = torch.device(device)
        self.embbedding = nn.Embedding(vocab_size, hidden_dim, scale_grad_by_freq=False).to(self.device)
        
        self.layers = nn.ModuleList([
            GATConv(hidden_dim, hidden_dim, residual=True, num_heads=n_heads, activation=activation,
                    feat_drop=drop, attn_drop=attn_drop).to(self.device) for _ in range(n_convs)
        ])
        self.down_proj = [
            nn.Linear(n_heads*hidden_dim, hidden_dim).to(self.device) for _ in range(n_convs)
        ]
        
        
        self.lin = nn.Linear(hidden_dim, 1).to(self.device)
        self.pooling1 = GlobalAttentionPooling( self.lin ).to(self.device)
        
        self.pooling2 = AvgPooling()
        
        self.fc = nn.Linear(2*hidden_dim, n_class).to(self.device)
        
    def forward(self, gs):
        with gs.local_scope():
            h = self.embbedding(gs.ndata['idx'])
            #for l, conv in enumerate(self.layers):
            #    h = conv(gs, h)
            #    h = h.view(h.shape[0], -1)
            #    h = self.down_proj[l]( h )
            H1 = self.pooling1(gs, h)
            H2 = self.pooling2(gs, h)
        H = torch.cat([H1, H2], axis=1)
        return self.fc(H)

In [9]:
def collate_train(param):
    Gs_nx, y = zip(*param)
    Gs_dgl_list = []
    for g in Gs_nx:
        g_dgl = dgl.DGLGraph()
        if len(g) > 0:
            g_dgl.from_networkx(g, node_attrs=['idx'] )
        Gs_dgl_list.append( g_dgl )

    return dgl.batch(Gs_dgl_list), torch.tensor(y)

In [10]:
nepochs = 2000
max_epochs = 10
drop=0.1
device = torch.device('cuda:0')
batch_size = 128

In [11]:
gat = GenericGAT( graphisize.vocab_size, 300, graphisize.n_class, drop=drop, attn_drop=drop, device=device ).to(device)

optimizer = optim.AdamW( gat.parameters(), lr=5e-3, weight_decay=5e-4)
loss_func_cel = nn.CrossEntropyLoss().to( device )
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

In [12]:
best = 0.
counter = 1
for e in tqdm(range(nepochs), total=nepochs):
    dl_train = DataLoader(list(zip(gs_train, y_train)), batch_size=batch_size,
                             shuffle=True, collate_fn=collate_train, num_workers=3)
    dl_val = DataLoader(list(zip(gs_val, y_val)), batch_size=batch_size,
                             shuffle=False, collate_fn=collate_train, num_workers=3)
    total_loss  = 0.
    with tqdm(total=len(y_train)+len(y_val), smoothing=0., desc=f"Epoch {e+1}") as pbar:
        total = 0
        correct  = 0
        gat.train()
        for i, (gs, y) in enumerate(dl_train):
            gs = gs.to( device )
            y  = y.to( device )
            
            pred_docs = gat( gs )
            pred_docs = F.softmax(pred_docs)
            loss = loss_func_cel(pred_docs, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total      += len(y)
            y_pred      = pred_docs.argmax(axis=1)
            correct    += (y_pred == y).sum().item()
            
            toprint  = f"Train loss: {total_loss/(i+1):.4}/{loss.item():.4} "
            toprint += f'ACC: {correct/total:.4}'
            
            print(toprint, end=f"{' '*100}\r")
            
            pbar.update( len(y) )
            
        scheduler.step()
        total = 0
        correct  = 0
        gat.eval()
        print()
        for i, (gs, y) in enumerate(dl_val):
            gs = gs.to( device )
            y  = y.to( device )
            
            pred_docs = gat( gs )
            pred_docs = F.softmax(pred_docs)
            
            y_pred      = pred_docs.argmax(axis=1)
            correct    += (y_pred == y).sum().item()
            total      += len(y)
            
            print(f'Val ACC: {correct/total:.4}/{best:.4}', end=f"{' '*100}\r")
            
            pbar.update( len(y) )
        if (correct/total) > best:
            best = (correct/total)
            counter = 1
            print(f'New Best Val ACC: {best:.4}')
        elif counter > max_epochs:
            print()
            print(f'Best Val ACC: {best:.4}')
            break
        else:
            counter += 1

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 2.086/1.961 ACC: 0.4708                                                                                                                                                                                                        
New Best Val ACC: 0.612                                                                                                



HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.854/1.835 ACC: 0.6984                                                                                                    
New Best Val ACC: 0.677                                                                                                  



HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.766/1.86 ACC: 0.7835                                                                                                     
New Best Val ACC: 0.6962                                                                                                 



HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.722/1.704 ACC: 0.8263                                                                                                    
New Best Val ACC: 0.7126                                                                                                  



HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.693/1.738 ACC: 0.855                                                                                                     
New Best Val ACC: 0.7186                                                                                                  



HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.67/1.636 ACC: 0.8772                                                                                                                                                                                                         
New Best Val ACC: 0.721                                                                                                   



HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.656/1.71 ACC: 0.8909                                                                                                                                                                                                         
Val ACC: 0.7194/0.721                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.644/1.596 ACC: 0.9027                                                                                                    
New Best Val ACC: 0.7251                                                                                                 



HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=22402.0, style=ProgressStyle(description_wi…

Train loss: 1.636/1.609 ACC: 0.9094                                                                                                    
Val ACC: 0.7194/0.7251                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.63/1.632 ACC: 0.9143                                                                                                                                                                                                         
Val ACC: 0.7222/0.7251                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 11', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.626/1.606 ACC: 0.9183                                                                                                    
Val ACC: 0.7226/0.7251                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 12', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.622/1.638 ACC: 0.9218                                                                                                                                                                                                        
Val ACC: 0.7238/0.7251                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 13', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.619/1.612 ACC: 0.925                                                                                                                                                                                                         
New Best Val ACC: 0.7255                                                                                                  



HBox(children=(FloatProgress(value=0.0, description='Epoch 14', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.617/1.62 ACC: 0.9272                                                                                                                                                                                                         
Val ACC: 0.7234/0.7255                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 15', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.615/1.619 ACC: 0.9289                                                                                                                                                                                                        
Val ACC: 0.7238/0.7255                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 16', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.613/1.617 ACC: 0.9303                                                                                                                                                                                                        
Val ACC: 0.7238/0.7255                                                                                                                                                                                                        


HBox(children=(FloatProgress(value=0.0, description='Epoch 17', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.612/1.663 ACC: 0.9316                                                                                                    
Val ACC: 0.7246/0.7255                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 18', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.611/1.589 ACC: 0.9325                                                                                                    
Val ACC: 0.7255/0.7255                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 19', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.61/1.633 ACC: 0.9333                                                                                                                                                                                                         
Val ACC: 0.7238/0.7255                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 20', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.609/1.63 ACC: 0.934                                                                                                      
Val ACC: 0.7255/0.7255                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 21', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.608/1.579 ACC: 0.9345                                                                                                                                                                                                        
Val ACC: 0.7226/0.7255                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 22', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.608/1.602 ACC: 0.9354                                                                                                                                                                                                        
Val ACC: 0.7222/0.7255                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 23', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.607/1.588 ACC: 0.9361                                                                                                                                                                                                        
Val ACC: 0.7234/0.7255                                                                                                    


HBox(children=(FloatProgress(value=0.0, description='Epoch 24', max=22402.0, style=ProgressStyle(description_w…

Train loss: 1.607/1.603 ACC: 0.9365                                                                                                                                                                                                                                                                                                            
Val ACC: 0.7234/0.7255                                                                                                    
Best Val ACC: 0.7255


