In [1]:
import warnings
warnings.filterwarnings('ignore')

from TGA.utils import Dataset, GraphsizePretrained
#from tqdm import tqdm
from tqdm.notebook import tqdm
from time import time
import numpy as np

from sklearn.preprocessing import LabelEncoder

In [2]:
import nltk
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [3]:
%%time
graph_builder = GraphsizePretrained(w=2, verbose=True,
                   pretrained_vec='/home/Documentos/Universidade/LBD/pretrained_vectors/glove/glove.6B.300d.txt')

400000it [00:25, 15386.38it/s]


CPU times: user 25.6 s, sys: 685 ms, total: 26.3 s
Wall time: 26.2 s


In [4]:
dataset = Dataset('/home/Documentos/datasets/classification/datasets/webkb/')
fold = next(dataset.get_fold_instances(10, with_val=True))
fold._fields, len(fold.X_train)

le = LabelEncoder()
y_train = le.fit_transform(fold.y_train)
y_val = le.transform(fold.y_val)

In [5]:
import torch
import dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F

import networkx as nx
from dgl.nn.pytorch.conv import GraphConv, GATConv
from dgl.nn.pytorch.glob import GlobalAttentionPooling

from sklearn.preprocessing import LabelEncoder

from itertools import repeat

import torch.optim as optim
from torch.utils.data import DataLoader

Using backend: pytorch


In [6]:
%%time
graph_builder.fit(fold.X_train, fold.y_train)

100%|██████████| 6553/6553 [00:05<00:00, 1164.76it/s]


CPU times: user 6.96 s, sys: 95.9 ms, total: 7.06 s
Wall time: 7.04 s


GraphsizePretrained(encoding=None,
                    pretrained_vec='/home/Documentos/Universidade/LBD/pretrained_vectors/glove/glove.6B.300d.txt',
                    verbose=None)

In [7]:
len(graph_builder.g.edges), len(graph_builder.g)

(88285, 20749)

In [8]:
list(map(lambda x: (x,graph_builder.g.degree()[x]), graph_builder.label_ids))

[(0, 11585),
 (1, 19771),
 (2, 8994),
 (3, 3714),
 (4, 8478),
 (5, 12273),
 (6, 2735)]

In [21]:
class GenericGAT(nn.Module):
    def __init__(self, in_dim, hidden_dim,
                 drop=.5, n_heads=8, attn_drop=.5,
                 activation=F.leaky_relu, n_convs=2,
                 first_hidden='emb', encoders={'term','label'},
                 device='cpu:0'):
        super(GenericGAT, self).__init__()
        self.n_hiddens = hidden_dim
        self.device = torch.device(device)
        self.first_hidden = first_hidden
        
        self.encoders = nn.ModuleDict({
            k: nn.Linear(in_dim, hidden_dim).to(self.device) for k in encoders
        })
        #self.norm = nn.BatchNorm1d(hidden_dim).to(self.device)
        
        self.layers = nn.ModuleList([
            GATConv(hidden_dim, hidden_dim, residual=True, num_heads=n_heads, activation=activation,
                    feat_drop=drop, attn_drop=attn_drop).to(self.device) for _ in range(n_convs)
        ])
        self.down_proj = [
            nn.Linear(n_heads*hidden_dim, hidden_dim).to(self.device) for _ in range(n_convs)
        ]
        self.norm_projs = [
            nn.BatchNorm1d(hidden_dim).to(self.device) for _ in range(n_convs)
        ]
        
    def forward(self, G, **kwargs):
        with G.local_scope():
            h = G.ndata[self.first_hidden].float()
            H = torch.zeros( h.shape[0], self.n_hiddens )
            for (k, mask) in kwargs.items():
                if k in self.encoders:
                    if mask is not None:
                        bla = self.encoders[k]( h[ mask ] )
                        h[ mask ] = bla
                    else:
                        h = self.encoders[k]( h )

            for l, conv in enumerate(self.layers):
                h = conv(G, h)
                h = h.view(h.shape[0], -1)
                h = self.down_proj[l]( h )
                h = self.norm_projs[l]( h )
        return h
        

In [22]:

class ClassifierGAT(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes, n_heads=16, drop=.5, attn_drop=.5, device='cuda:0'):
        super(ClassifierGAT, self).__init__()

        self.encoder = nn.Linear(in_dim, hidden_dim).to(torch.device(device))
        
        self.layers = nn.ModuleList([
            GATConv(hidden_dim, hidden_dim, num_heads=n_heads, activation=F.leaky_relu,
                    feat_drop=drop, attn_drop=attn_drop).to(torch.device(device)),
            GATConv(n_heads*hidden_dim, hidden_dim, num_heads=n_heads, activation=F.leaky_relu,
                    feat_drop=drop, attn_drop=attn_drop).to(torch.device(device))
        ])
        
        self.lin = nn.Linear(n_heads*hidden_dim + hidden_dim, 1).to(torch.device(device))
        self.pooling = GlobalAttentionPooling( self.lin ).to(torch.device(device))
        
        self.norm = nn.BatchNorm1d( n_heads*hidden_dim + hidden_dim )
        self.drop = nn.Dropout(drop)
        
        self.classify = nn.Linear( n_heads*hidden_dim + hidden_dim, n_classes).to(torch.device(device))

    def forward(self, G):
        h = G.ndata['emb'].float()
        he = self.encoder(h)
        h = he
        for conv in self.layers:
            h = conv(G, h)
            h = h.view(h.shape[0], -1)
        
        # CONCAT he E hg
        hg = torch.cat((h,he), 1)
        hg = self.norm( hg )
        hg = self.drop( hg )
        hg = self.pooling(G, hg)
        
        pred = self.classify( hg )
        return pred

In [23]:
def collate(param):
    X, y = zip(*param)
    Gs_nx = graph_builder.transform(X)
    
    Gs_dgl_list = []
    for g in Gs_nx:
        g_dgl = dgl.DGLGraph()
        if len(g) > 0:
            g_dgl.from_networkx(g, node_attrs=['emb', 'idx'] )
        Gs_dgl_list.append( g_dgl )
    
    Gs_dgl = dgl.batch(Gs_dgl_list)
    
    big_graph_dgl = dgl.DGLGraph()
    big_graph_dgl.from_networkx(graph_builder.g, node_attrs=['emb', 'label', 'idx'] )
    
    #subgraph = graph_builder.g.subgraph(idx_terms)
    #big_graph_dgl.from_networkx(subgraph, node_attrs=['emb', 'label', 'idx'] )
    
    return big_graph_dgl, Gs_dgl, torch.tensor(y)

In [24]:
class TGA(torch.nn.Module):
    def __init__(self, input_l, hidden_l, nclass, n_heads=1,
                drop=0.5, attn_drop=0.5, loss=None, n_convs=1,activation=None,
                 device='cuda:0'):
        
        super(TGA, self).__init__()
        
        self.gat_global = GenericGAT( input_l, hidden_l, n_heads=n_heads,
                 drop=drop, attn_drop=attn_drop, n_convs=n_convs,
                 activation=activation, device=device ).to(device)
        
        
        self.gat_local = GenericGAT( input_l, hidden_l, n_heads=n_heads,
                 drop=drop, attn_drop=attn_drop, n_convs=n_convs, encoders={'terms'},
                 activation=activation, device=device ).to(device)
        
        self.norm_label = nn.BatchNorm1d(hidden_l).to(device)
        #self.norm_docs = nn.BatchNorm1d(hidden_l).to(device)

        #self.gate = nn.Linear( input_l+hidden_l, 1 ).to(device)
        self.gate = nn.Sequential(
          nn.Linear( input_l+hidden_l, hidden_l ),
          nn.ReLU(),
          nn.Linear( hidden_l, 1 )
        ).to(device)
        self.feat = nn.Linear( input_l+hidden_l, hidden_l ).to(device)
        self.gap  = GlobalAttentionPooling(self.gate, feat_nn=self.feat).to(device)
        
        #self.nclass  = nclass
        #self.fc1     = nn.Linear( hidden_l, hidden_l ).to(device)
        #self.fc2     = nn.Linear(  hidden_l//2, self.nclass ).to(device)
        #self.softmax = nn.Softmax(dim=1)
        
        self.fc_global = nn.Sequential(
          nn.Linear( hidden_l, hidden_l ),
          nn.Tanh(),
          nn.Linear( hidden_l, hidden_l )
        )
        
        self.fc_local = nn.Sequential(
          nn.Linear( hidden_l, hidden_l ),
          nn.ReLU(),
          nn.Linear( hidden_l, hidden_l )
        )
        
        self.fc_local_classifier = nn.Sequential(
          nn.Linear( hidden_l, hidden_l ),
          nn.ReLU(),
          nn.Linear( hidden_l, nclass )
        )
        
        self.loss = loss

    def forward(self, G, gs, label_idx=None):
        if label_idx is None:
            label_idx = G.ndata['label'].nonzero().flatten()
            
        terms_idx = range(len(label_idx),len(graph_builder.g))
        
        h_global  = self.gat_global(G, label=label_idx, term=terms_idx)

        h_labels  = h_global[label_idx]
        #h_labels  = self.norm_label(h_labels)
        h_labels  = self.fc_global(h_labels)

        gs.ndata['emb'] = h_global[gs.ndata['idx'].reshape(-1)]
        h_local         = self.gat_local(gs, terms=None)
        
        terms_to_pred   = G.ndata['emb'][gs.ndata['idx'].reshape(-1)].float()
        
        h_docs          = self.gap( gs, torch.cat((h_local, terms_to_pred), 1 ) )
        h_docs          = self.fc_local(h_docs)
        #h_docs          = self.norm_docs(h_docs)
        pred_docs       = self.fc_local_classifier(h_docs)
        
        #h_docs_pred = self.fc1(h_docs)
        #h_docs_pred = self.fc2(h_docs_pred)
        #h_docs_pred = nn.softmax(h_docs_pred, 1)
        
        return h_docs, pred_docs, h_labels

In [25]:
hidden_l = 64
input_l = 300
n_heads = 4
drop=0.3
batch_size=64
attn_drop=0.3
device=torch.device('cuda:0')

In [26]:
torch.cat( (torch.Tensor([[1,2,3],[4,5,6]]),torch.Tensor([[4,5,6],[7,8,9]])), 1 )

tensor([[1., 2., 3., 4., 5., 6.],
        [4., 5., 6., 7., 8., 9.]])

In [27]:
tga = TGA(input_l, hidden_l, nclass=graph_builder.n_class,
          activation=None,
          n_heads=n_heads, drop=drop, attn_drop=attn_drop, n_convs=2).to(device)
tga

TGA(
  (gat_global): GenericGAT(
    (encoders): ModuleDict(
      (label): Linear(in_features=300, out_features=64, bias=True)
      (term): Linear(in_features=300, out_features=64, bias=True)
    )
    (layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=64, out_features=256, bias=False)
        (feat_drop): Dropout(p=0.3, inplace=False)
        (attn_drop): Dropout(p=0.3, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (res_fc): Identity()
      )
      (1): GATConv(
        (fc): Linear(in_features=64, out_features=256, bias=False)
        (feat_drop): Dropout(p=0.3, inplace=False)
        (attn_drop): Dropout(p=0.3, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (res_fc): Identity()
      )
    )
  )
  (gat_local): GenericGAT(
    (encoders): ModuleDict(
      (terms): Linear(in_features=300, out_features=64, bias=True)
    )
    (layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=6

In [28]:
from TGA.lossweight import cross_entropy

In [29]:
from TGA.lossweight import cross_entropy
class NpairLoss(nn.Module):
    """the multi-class n-pair loss"""
    def __init__(self, l2_reg=0.02):
        super(NpairLoss, self).__init__()
        self.l2_reg = l2_reg

    def forward(self, anchor, target=None, positive=None):
        batch_size = anchor.size(0)
        
        if target is not None:
            target = target.view(target.size(0), 1)
            target = (target == torch.transpose(target, 0, 1)).float()
            target = target / torch.sum(target, dim=1, keepdim=True).float()
        else:
            target = torch.eye(batch_size).to(anchor.device)

        if positive is not None:
            logit = torch.matmul(anchor, torch.transpose(positive, 0, 1))
            l2_loss = torch.sum(anchor**2) / batch_size + torch.sum(positive**2) / batch_size
        else:
            logit = torch.matmul(anchor, torch.transpose(anchor, 0, 1))
            l2_loss = torch.sum(anchor**2) / batch_size
        
        loss_ce = cross_entropy(logit, target)

        loss = loss_ce + self.l2_reg*l2_loss*0.25
        return loss.float()
class SelfDistLoss(nn.Module):
    def __init__(self, l2_reg=0.02, eps = 0.00003, margin=-0.3):
        super(SelfDistLoss, self).__init__()
        self.l2_reg = l2_reg
        self.eps = eps
        self.margin = margin
        
    def forward(self, hiddens):
        L = torch.matmul(hiddens, hiddens.T)
        L = F.sigmoid(L)
        L = (L - (L.diag()-self.margin)).float()
        L = F.relu(L)
        #L = torch.exp( L )
        L = ( L > 0. ).float() * torch.exp( L )
        #L = F.normalize(L)
        
        values = L.sum(axis=1)
        svalue = max((values > 0.).sum(), self.eps)

        return values.sum()/svalue # AVG of non-zero values
    
class SelfDistLoss2(nn.Module):
    def __init__(self, l2_reg=0.02, eps = 0.00003):
        super(SelfDistLoss, self).__init__()
        self.l2_reg = l2_reg
        self.eps = eps
        
    def forward(self, hiddens):
        L = torch.matmul(hiddens, hiddens.T)
        L = (L - L.diag()).float()
        L = F.sigmoid(L)
        #L = F.relu(L)
        #L = torch.exp( L )
        L = torch.exp( L )
        #L = F.normalize(L)

        return L.mean(axis=1).mean()/np.e # AVG of non-zero values

In [30]:

optimizer = optim.AdamW( tga.parameters(), lr=5e-3, weight_decay=5e-4)

loss_func_npl = NpairLoss(l2_reg=5e-4)
loss_func_cel = nn.CrossEntropyLoss()
loss_func_spc = SelfDistLoss()
loss_func_self_npl = NpairLoss(l2_reg=5e-4)

#RMSprop

In [31]:
torch.eye(10).device

device(type='cpu')

In [32]:

best = None
nepochs = 25
for e in tqdm(range(nepochs), total=nepochs):
    epoch_loss = 0
    data_loader = DataLoader(list(zip(fold.X_train, y_train)), batch_size=batch_size,
                             shuffle=True, collate_fn=collate, num_workers=3)
    total_loss  = 0.
    total_loss1 = 0.
    total_loss2 = 0.
    total_loss3 = 0.
    with tqdm(total=len(y_train)+len(y_val), smoothing=0.) as pbar:
        total = 1
        correct_class = 0
        correct_repre = 0
        correct_both  = 0
        tga.train()
        for i, (G, gs, y) in enumerate(data_loader):
            G = G.to( device )
            gs = gs.to( device )
            y = y.to( device )
            
            h_docs, pred_docs, h_labels = tga( G, gs )
            #h_docs, pred_docs = tga( G, gs, y )
            
            
            pred_docs = F.softmax(pred_docs)
            pred_docs2 = F.softmax(torch.matmul(h_docs, h_labels.T))
            pred_docs3 = pred_docs+pred_docs2
            
            loss1 = loss_func_npl( h_docs, y, positive=h_labels[y] )
            loss2 = loss_func_cel(pred_docs, y)
            loss3 = loss_func_self_npl(h_labels)
            #loss3 = loss_func_spc(h_labels)
            
            loss = loss1 + loss2 + loss3
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total      += len(y)
            
            y_pred            = pred_docs.argmax(axis=1)
            correct_repre    += (y_pred == y).sum()
            total_loss1      += loss1.item()
            
            y_pred = pred_docs2.argmax(axis=1)
            correct_class    += (y_pred == y).sum()
            total_loss2      += loss2.item()
            
            y_pred = pred_docs3.argmax(axis=1)
            correct_both     += (y_pred == y).sum()
            total_loss3      += loss3.item()
            
            to_print   = f'(L)oss: {loss.item():.3}/{total_loss /(i+1):.3} '
            to_print  += f'NP-L: {loss1.item():.3}/{total_loss1/(i+1):.3} '
            to_print  += f'CE-L: {loss2.item():.3}/{total_loss2/(i+1):.3} '
            to_print  += f'Sp-L: {loss3.item():.3}/{total_loss3/(i+1):.3} '
            to_print  += f'Acc Cls: {(1.*correct_class/total).item():.3} '
            to_print  += f'Repr: {(1.*correct_repre/total).item():.3} '
            to_print  += f'Both: {(1.*correct_both/total).item():.3}'
            pbar.update( len(y) )
            pbar.set_description_str(f'iter {e} Acc: {(1.*correct_both/total).item():.3}')
            
            #break
            if best is None or best > (total_loss/(i+1)):
                hiddens_labels = h_labels
                hiddens_docs = h_docs
                best = total_loss/(i+1)
            #del loss, h_labels, G, gs, loss1, loss2, pred_docs, h_docs
            print(to_print, end=f"{' '*100}\r")
            del G, gs, y
            del h_docs, pred_docs, h_labels
            del  pred_docs2, pred_docs3
            del loss1, loss2, loss3
            #break
        to_print   = f'(L)oss: {total_loss /(i+1):.3} '
        to_print  += f'NP-L: {total_loss1/(i+1):.3} '
        to_print  += f'CE-L: {total_loss2/(i+1):.3} '
        to_print  += f'Sp-L: {total_loss3/(i+1):.3} '
        to_print  += f'Acc Cls: {(1.*correct_class/total).item():.3} '
        to_print  += f'Repr: {(1.*correct_repre/total).item():.3} '
        to_print  += f'Both: {(1.*correct_both/total).item():.3}'
        print(to_print, end=f"{' '*100}\n")
        
        ############################################# EVALUATION #############################################
        data_loader = DataLoader(list(zip(fold.X_val, y_val)), batch_size=batch_size,
                                 shuffle=False, collate_fn=collate, num_workers=3)
        total = 0
        correct_class = 0
        correct_repre = 0
        correct_both  = 0
        tga.eval()
        for i, (G, gs, y) in enumerate(data_loader):
            G = G.to( device )
            gs = gs.to( device )
            y = y.to( device )
            
            h_docs, pred_docs, h_labels = tga( G, gs )
            
            pred_docs = F.softmax(pred_docs)
            pred_docs2 = F.softmax(torch.matmul(h_docs, h_labels.T))
            pred_docs3 = pred_docs+pred_docs2
            
            total      += len(y)
            
            y_pred            = pred_docs.argmax(axis=1)
            correct_repre    += (y_pred == y).sum()
            
            y_pred = pred_docs2.argmax(axis=1)
            correct_class    += (y_pred == y).sum()
            
            y_pred = pred_docs3.argmax(axis=1)
            correct_both     += (y_pred == y).sum()
            
            pbar.update( len(y) )
            pbar.set_description_str(f'iter {e} Acc: {(1.*correct_both/total).item():.3}')
            
            del G, gs, y
            del h_docs, pred_docs, h_labels
            del pred_docs2, pred_docs3
        to_print   = f'EVAL:  Acc Cls: {(1.*correct_class/total).item():.3} '
        to_print  += f'Repr: {(1.*correct_repre/total).item():.3} '
        to_print  += f'Both: {(1.*correct_both/total).item():.3}'
        print(to_print, end=f"{' '*100}\r")
    del data_loader

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=7376.0), HTML(value='')))





RuntimeError: shape mismatch: value tensor of shape [7, 64] cannot be broadcast to indexing result of shape [7, 300]

In [None]:
# ACM
# (L)oss: 6.05 NP-L: 4.41 CE-L: 1.64 Sp-L: 0.0 Acc Cls: 0.922 Repr: 0.907 Both: 0.932                                                                                                                       
# ALL: [ *77.84, *79.92, 70.70, 66.85, 67.33, 50.46, 76.20, 74.10, 74.20, 76.73, 73.10, 76.60, *78.08, 76.99]

#reut
# (L)oss: 9.44 NP-L: 3.34 CE-L: 3.88 Sp-L: 2.22 Acc Cls: 0.732 Repr: 0.639 Both: 0.662                                                                                                                        

# 20ng
# (L)oss: 4.41 NP-L: 2.2 CE-L: 2.18 Sp-L: 0.0226 Acc Cls: 0.917 Repr: 0.896 Both: 0.915                

# WebKB
# (L)oss: 4.94 NP-L: 3.64 CE-L: 1.31 Sp-L: 0.0 Acc Cls: 0.867 Repr: 0.858 Both: 0.865                                                                                                                       



In [None]:
# webkb (MaxF1=85.54)
# exp     (b=128) webkb (L)oss: 5.53 NP-L: 3.68 CE-L: 1.34 Sp-L: 0.521 Acc Cls: 0.847 Repr: 0.831 Both: 0.837                                                                                                                        
# exp     (b=16)  webkb (L)oss: 4.39 NP-L: 1.90 CE-L: 1.38 Sp-L: 1.110 Acc Cls: 0.743 Repr: 0.782 Both: 0.784                                                                                                                         
# (b=512, nh=1)   webkb (L)oss: 8.96 NP-L: 5.00 CE-L: 1.33 Sp-L: 2.620 Acc Cls: 0.858 Repr: 0.833 Both: 0.843                                                                                                                         


# ACM (MaxF1=79.92)
# exp           acm   (L)oss: 4.94 NP-L: 3.17 CE-L: 1.69 Sp-L: 0.085 Acc Cls: 0.870 Repr: 0.856 Both: 0.885    
# exp           acm   (L)oss: 4.88 NP-L: 3.05 CE-L: 1.61 Sp-L: 0.220 Acc Cls: 0.915 Repr: 0.929 Both: 0.931    
# (b=512, nh=1) acm   (L)oss: 6.15 NP-L: 4.51 CE-L: 1.64 Sp-L: 0.000 Acc Cls: 0.882 Repr: 0.902 Both: 0.904    
# (b=512, nh=2) acm   

# reut (MaxF1=79.11)
# exp           reut (L)oss: 8.15 NP-L: 2.80 CE-L: 3.87 Sp-L: 1.48 Acc Cls: 0.665 Repr: 0.651 Both: 0.658
# (b=512, nh=1) reut (L)oss: 10.9 NP-L: 4.09 CE-L: 3.92 Sp-L: 2.90 Acc Cls: 0.673 Repr: 0.600 Both: 0.620                                                                                                                            
# (b=512, nh=2) reut (L)oss: 11.2 NP-L: 4.06 CE-L: 3.87 Sp-L: 3.32 Acc Cls: 0.690 Repr: 0.650 Both: 0.668                                                                                                                          




In [None]:
hiddens = h_labels


L = torch.matmul(hiddens, hiddens.T)
L = F.sigmoid(L)
L_mapper = (L >= L.diag()).float()
ind = torch.eye(L_mapper.size(0),L_mapper.size(1)).to(L_mapper.device)
L_mapper *= (1-ind)
L_mapper = F.normalize(L_mapper)
values = (L_mapper * L).sum(axis=1)
svalue = max((values > 0.).sum(), 0.0000001)

values.sum()/svalue

In [None]:
hiddens = h_labels


L = torch.matmul(hiddens, hiddens.T)
L = F.sigmoid(L)
L = (L - L.diag()).float()
L = F.relu(L)
L = F.normalize(L)
values = L.sum(axis=1)
svalue = max((values > 0.).sum(), 0.0000001)

values.sum()/svalue

In [None]:
L

In [None]:
svalue = max((values > 0.).sum(), 0.)
svalue

In [None]:
svalues = (values > 0.).sum()
values = values.sum()/svalues
values

In [None]:
svalue = torch.max(, 0.002)

In [None]:
(values > 0.).sum()

In [None]:
ReLU