In [1]:
import warnings
warnings.filterwarnings('ignore')

from utils import Dataset, GraphsizePretrained
from tqdm import tqdm
from tqdm.notebook import tqdm
from time import time
import numpy as np

Using backend: pytorch


In [2]:
webkb = Dataset('/home/Documentos/datasets/classification/datasets/acm/')

dataset = webkb

In [3]:
fold = next(dataset.get_fold_instances(10))
fold._fields

('X_train', 'y_train', 'X_test', 'y_test', 'X_val', 'y_val')

In [4]:
%%time
graph_builder = GraphsizePretrained(w=2, verbose=True,
                   pretrained_vec='/home/Documentos/Universidade/LBD/pretrained_vectors/glove/glove.6B.300d.txt')
#Gs_train = graph_builder.fit_transform(fold.X_train)
#Gs_val   = graph_builder.transform(fold.X_val)

400000it [00:26, 14963.80it/s]


CPU times: user 26.3 s, sys: 740 ms, total: 27.1 s
Wall time: 26.9 s


In [5]:
import torch
import dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
from dgl.nn.pytorch.conv import GraphConv, GATConv
from dgl.nn.pytorch.glob import GlobalAttentionPooling

from sklearn.preprocessing import LabelEncoder

from itertools import repeat

import torch.optim as optim
from torch.utils.data import DataLoader

In [6]:
graph_builder.fit(fold.X_train, fold.y_train)

100%|██████████| 19907/19907 [00:05<00:00, 3826.55it/s]


GraphsizePretrained(pretrained_vec='/home/Documentos/Universidade/LBD/pretrained_vectors/glove/glove.6B.300d.txt',
                    verbose=None)

In [7]:
len(graph_builder.g.edges), len(graph_builder.g)

(126449, 34676)

In [8]:
list(map(lambda x: (x,graph_builder.g.degree()[x]), graph_builder.label_ids))

[(0, 2921),
 (1, 8317),
 (2, 10134),
 (3, 15852),
 (4, 1086),
 (5, 5801),
 (6, 5148),
 (7, 14561),
 (8, 13217),
 (9, 3465),
 (10, 11293)]

In [9]:
class GenericGAT(nn.Module):
    def __init__(self, in_dim, hidden_dim,
                 n_heads=8, n_convs=2, drop=.5, first_hidden='emb', attn_drop=.5,
                 encoders={'term','label'}, device='cuda:0'):
        super(GenericGAT, self).__init__()
        self.device = torch.device(device)
        self.first_hidden = first_hidden
        
        self.encoders = nn.ModuleDict({
            k: nn.Linear(in_dim, hidden_dim).to(self.device) for k in encoders
        })
        
        self.layers = nn.ModuleList([
            GATConv(hidden_dim, hidden_dim, residual=True, num_heads=n_heads, activation=F.leaky_relu,
                    feat_drop=drop, attn_drop=attn_drop).to(self.device) for _ in range(n_convs)
        ])
        self.down_proj = [
            nn.Linear(n_heads*hidden_dim, hidden_dim).to(self.device) for _ in range(n_convs)
        ]
        
    def forward(self, G, **kwargs):
        h = G.ndata[self.first_hidden].float()
        for (k, mask) in kwargs.items():
            if k in self.encoders:
                if mask is not None:
                    h[ mask ] = self.encoders[k]( h[ mask ] )
                else:
                    h = self.encoders[k]( h )
        
        for l, conv in enumerate(self.layers):
            h = conv(G, h)
            h = h.view(h.shape[0], -1)
            h = self.down_proj[l]( h )
        
        return h
        

In [10]:
class TGA(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_class,
                  n_heads=8, drop=.5, attn_drop=.5,
                  device='cuda:0'):
        super(TGA, self).__init__()
        self.n_class = n_class
        self.device = torch.device(device)
        self.gat_global = GenericGAT(in_dim, hidden_dim, 
                                     encoders={'label'}, 
                                     n_heads=n_heads, drop=drop,
                                     attn_drop=attn_drop, device=self.device)
        self.gat_local  = GenericGAT(hidden_dim, hidden_dim, 
                                     encoders={'term'}, 
                                     n_heads=n_heads, drop=drop,
                                     first_hidden='emb',
                                     attn_drop=attn_drop, device=self.device)

        self.lin = nn.Linear( 2*hidden_dim, 1).to(self.device)
        # Depois tentar alguma ativação (ReLU, por exemplo, pode "desativar" alguns termos no softmax)
        self.pooling = GlobalAttentionPooling( self.lin ).to(self.device)

        # Fully Connected
        self.fc1 = nn.Linear( 2*hidden_dim, self.n_class).to(self.device)
        #self.fc2 = nn.Linear( hidden_dim, self.n_class).to(self.device)
        #self.fc3 = nn.Linear( hidden_dim, self.n_class).to(self.device)
    def forward(self, G, gs):
        h_global           = self.gat_global( G, label=G.ndata['label'].nonzero().flatten() )
        #gs.ndata['weight'] = h_global[ gs.ndata['idx'] ] # Tentar concatenando
        h                  = h_global[ gs.ndata['idx'] ]
        h_local            = self.gat_local(gs, term=None)
        h_local            = torch.cat((h, h_local), 1)
        h_docs             = self.pooling( gs, h_local )
        return self.fc1( h_docs )
# torch.Size([3652, 300]) torch.Size([3652, 300]) torch.Size([128, 300])
        

In [11]:
in_dim=300
hidden_dim=300
n_heads=4
drop=0.3
attn_drop=0.5
batch_size=32
device='cuda:0'

In [12]:
model = TGA( in_dim, hidden_dim, graph_builder.n_class,
            n_heads=n_heads, drop=drop, attn_drop=attn_drop )
model

TGA(
  (gat_global): GenericGAT(
    (encoders): ModuleDict(
      (label): Linear(in_features=300, out_features=300, bias=True)
    )
    (layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=300, out_features=1200, bias=False)
        (feat_drop): Dropout(p=0.3, inplace=False)
        (attn_drop): Dropout(p=0.5, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (res_fc): Identity()
      )
      (1): GATConv(
        (fc): Linear(in_features=300, out_features=1200, bias=False)
        (feat_drop): Dropout(p=0.3, inplace=False)
        (attn_drop): Dropout(p=0.5, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (res_fc): Identity()
      )
    )
  )
  (gat_local): GenericGAT(
    (encoders): ModuleDict(
      (term): Linear(in_features=300, out_features=300, bias=True)
    )
    (layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=300, out_features=1200, bias=False)
        (feat_drop): Dropo

In [13]:
def collate(param):
    X, y = zip(*param)
    Gs_nx = graph_builder.transform(X)
    
    Gs_dgl = []
    for g in Gs_nx:
        g_dgl = dgl.DGLGraph()
        g_dgl.from_networkx(g, node_attrs=['emb', 'idx'] )
        Gs_dgl.append( g_dgl )
        
    big_graph_dgl = dgl.DGLGraph()
    big_graph_dgl.from_networkx(graph_builder.g, node_attrs=['emb', 'label', 'idx'] )
    
    return big_graph_dgl, dgl.batch(Gs_dgl), torch.tensor(y)

In [14]:
loss_func = nn.CrossEntropyLoss()

optimizer = optim.Adam( model.parameters(), lr=1e-3, weight_decay=1e-3)
#optimizer = optim.AdamW( model.parameters(), lr=1e-2, weight_decay=1e-3)

#optimizer = optim.RMSprop( model.parameters(), lr=1e-2, weight_decay=1e-4)
#optimizer = optim.RMSprop( model.parameters(), lr=0.0001 )

model.train()
torch.cuda.synchronize()

In [15]:
n_epochs = 3

for epoch in range(n_epochs):
    epoch_loss = 0
    data_loader = DataLoader(list(zip(fold.X_train, fold.y_train)), batch_size=batch_size,
                             shuffle=True, collate_fn=collate, num_workers=2)
    with tqdm(total=len(fold.y_train)) as pbar:
        total = 0
        correct = 0
        model.train()
        for G, gs, y in data_loader:
            G = G.to( torch.device('cuda:0') )
            gs = gs.to( torch.device('cuda:0') )
            y = y.to( torch.device('cuda:0') )
            outputs = model( G, gs )
            probs_Y = torch.softmax(outputs, 1)
            sampled_Y = torch.argmax(probs_Y, 1).reshape(-1)
            
            total += y.size(0)
            correct += (sampled_Y == y).sum().item()
            
            # NN backprop phase
            loss = loss_func(outputs, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.detach().item()
            
            pbar.update( len(y) )
            pbar.set_description_str(f'iter {epoch} Acc train: {correct/total:.3}')

HBox(children=(FloatProgress(value=0.0, max=19907.0), HTML(value='')))




RuntimeError: Tensors must have same number of dimensions: got 2 and 3

In [None]:
ydf sdfg sdfg

In [None]:
outputs.shape

In [None]:
G

In [None]:
gs

In [None]:
bigG.ndata['label'].nonzero().flatten()

In [None]:
g = dgl.unbatch(Gs)[0]

In [None]:
Gs.ndata['idx'].max(),Gs.ndata['idx'].shape

In [None]:
bigG.ndata['emb'].shape

In [None]:
bigG.ndata['emb'][Gs.ndata['idx']].shape

In [None]:
gat_global = GenericGAT(in_dim, hidden_dim, 
                        encoders={'term','label'}, 
                        n_heads=n_heads, drop=drop,
                        attn_drop=attn_drop, device=device)
gat_local  = GenericGAT(in_dim, hidden_dim, 
                        encoders={'term'}, 
                        n_heads=n_heads, drop=drop,
                        attn_drop=attn_drop, device=device)
          

In [None]:
loss_func = nn.CrossEntropyLoss()

optimizer = optim.Adam( model.parameters(), lr=1e-5, weight_decay=1e-3)
#optimizer = optim.AdamW( model.parameters(), lr=1e-2, weight_decay=1e-3)

#optimizer = optim.RMSprop( model.parameters(), lr=1e-2, weight_decay=1e-4)
#optimizer = optim.RMSprop( model.parameters(), lr=0.0001 )

torch.cuda.synchronize()

In [None]:
graph_builder.le.classes_

In [None]:
probs_Y = torch.softmax(outputs, 1)
sampled_Y = torch.argmax(probs_Y, 1).reshape(-1)

total += labels.size(0)
correct += (sampled_Y == labels).sum().item()

# NN backprop phase
loss = loss_func(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.detach().item()

In [None]:
for epoch in range(n_epochs):
    data_loader = DataLoader(zip(fold.X_train, fold.y_train), batch_size=batch_size,
                             shuffle=True, collate_fn=graph_builder.collate)
    epoch_loss = 0
    with tqdm(total=len(data_loader.dataset), smoothing=0.) as pbar:
        t0 = time()
        total = 0
        correct = 0
        model.train()
        #  g_dgl, labels, node_idx, docs_idx, global_terms_idx, range(nclass)
        for i, (bg, labels, node_idx, docs_idx, global_terms_idx, terms_idx, labels_idx) in enumerate(data_loader):
            # model(G, docs_idx, terms_idx, global_terms_idx, labels_idx)
            outputs, L_hiddens, GT_hiddens, doc_dists = model(bg, docs_idx, terms_idx, global_terms_idx, labels_idx)
            

In [None]:
class TGA(nn.Module):
    def __init__(self, in_dim, hidden_dim, graphsize=None,
                 n_heads=8, drop=.5, attn_drop=.5, device='cuda:0'):
        
        self.device = torch.device(device)
        
        if graphsize is not None:
            self.graphsize = graphsize
        
        self.gat_global = GenericGAT(in_dim, hidden_dim, 
                                     encoders={'term','label'}, 
                                     n_heads=n_heads, drop=drop,
                                     attn_drop=attn_drop, device=self.device)
        
        self.gat_local = GenericGAT(in_dim, hidden_dim,
                                     encoders={'term'}, 
                                     n_heads=n_heads, drop=drop,
                                     attn_drop=attn_drop, device=self.device)
        
        
        self.lin = nn.Linear( hidden_dim, 1).to(self.device)
        self.pooling = GlobalAttentionPooling( self.lin ).to(self.device)
        
    def forward(self, G, gs):
        H = self.gat_global(G, term=self.graphsize.global_term_ids.values(), label=self.graphsize.label_ids)
        h = self.gat_local(gs, term=None)

In [None]:
def fit(X_train, y_train):
    G = graphisize.big_graph( X_train, y_train )
    for e in epochs:
        G = gat_global

In [None]:
le = LabelEncoder()

y_train = le.fit_transform( fold.y_train )
y_val   = le.transform( fold.y_val )

graph_builder.embbeding = { t:v for (t,v) in graph_builder.embeddings_dict.items() }
graph_builder.label_ids = [ y for y in le.classes_ ]
for y in graph_builder.label_ids:
    hotenc = np.zeros(300)
    hotenc[y] = 1
    graph_builder.embbeding[y] = hotenc

In [None]:
def get_big_graph(self, X, y):
    docs = list(map(self.analyzer.build_analyzer(), self.progress_bar(X)))
    edges_to_add = set()
    self.node_mapper = { y: y for y in label_ids }
    for (doc,y) in zip( docs, y_train ):

        doc_in_terms = set(filter( lambda x: x in self.embeddings_dict, doc))
        terms_by_id = list(map(lambda x: self.node_mapper.setdefault(x, len(self.node_mapper)), doc_in_terms))

        list_of_edges = list(map( lambda x: (y, x), terms_by_id ))
        list(map(edges_to_add.add, list_of_edges))
        
    g = nx.Graph()
    g.add_nodes_from( [ (idx, {'emb': self.embbeding[t], 'term': t} ) for (t,idx) in self.node_mapper.items() ] )
    g.add_edges_from( edges_to_add )
    return g
    #g_dgl = dgl.DGLGraph()
    #g_dgl.from_networkx(g, node_attrs=['emb'] )
    #g_dgl = g_dgl.to(torch.device('cuda:0'))
    #return g_dgl

In [None]:
g = get_big_graph(graph_builder, fold.X_train, y_train)

In [None]:
g_dgl

In [None]:
gat = GenericGAT(hidden_dim, hidden_dim, 
        n_heads=4, drop=0.01, attn_drop=0.01, device='cuda:0')

In [None]:
h = gat(g_dgl)
h = h.view(h.shape[0], -1)

In [None]:
h.shape

In [None]:
h[label_ids]

In [None]:
del h, gat

In [None]:
class SamplerCollate(object):
    def __init__(self, label_vectors, terms_vectors, nclass, device='cuda:0'):
        self.terms_vectors = terms_vectors
        self.label_vectors = label_vectors
        self.nclass = nclass
        self.device = torch.device(device)
    def collate(self, samples):
        Gs_Fs, labels = map(list, zip(*samples))
        big_graph = nx.Graph()
        node_idx = { }
        for y in range(self.nclass):
            node_idx.setdefault( ('L', y), len(node_idx) )

        big_graph.add_nodes_from( [ (node_idx[('L',y)], {'idx': node_idx[('L',y)], 'emb': self.label_vectors[y]}) for y in range(self.nclass)] )

        docs_idx = []
        terms_idx = []
        global_terms_idx = set()

        for i,(g,y) in enumerate(list(zip(Gs_Fs, y_train))):
            # Get words idxs and add in the big graph
            nodes = [ ( node_idx.setdefault((i,w), len(node_idx) ), {'idx': node_idx[(i,w)], 'emb': att['emb']}) for w,att in g.nodes(data=True) ]
            big_graph.add_nodes_from( nodes )

            # Build document words idxs
            if len(nodes) > 0:
                nodes,_ = list(zip(*nodes))
                terms_idx.extend( list(nodes) )
            docs_idx.append( list(nodes) )

            # Add terms -> terms edges in the big graph (Local terms co-occurs)
            w_edges = [ (node_idx.setdefault((i,s), len(node_idx) ), node_idx.setdefault((i,t), len(node_idx) )) for (s,t) in g.edges ]
            big_graph.add_edges_from( w_edges )

            # Get global terms nodes format=[(id_global_term,{ key_att: value_att })]
            filtered_nodes = []
            for (w,att) in g.nodes(data=True):
                id_w = node_idx.setdefault(('gt', w), len(node_idx) )
                if id_w not in big_graph:
                    #filtered_nodes.append( ( id_w,{'idx': id_w, 'emb': self.terms_vectors[att['word']]} ) )
                    filtered_nodes.append( ( id_w,{'idx': id_w, 'emb': att['emb']} ) )
            #filtered_nodes = [ ( node_idx.setdefault(('gt', w), len(node_idx) ),{'idx': node_idx[('gt', w)], 'emb': self.terms_vectors['emb']} ) for (w,att) in g.nodes(data=True) if node_idx.setdefault(('gt',w), len(node_idx) ) not in big_graph ]
            big_graph.add_nodes_from( filtered_nodes )

            # Add labels -> terms edges in the big graph
            big_graph.add_edges_from( [(node_idx[('L', y)], node_idx[('gt', w)]) for w in g.nodes] )

            # Add words -> terms edges in the big graph
            big_graph.add_edges_from( [(node_idx[(i, w)], node_idx[('gt', w)]) for w in g.nodes] )

            # Add global terms idx
            if len(filtered_nodes) > 0:
                filtered_nodes,_ = list(zip(*filtered_nodes))
                global_terms_idx = global_terms_idx.union( set(filtered_nodes) )

        g_dgl = dgl.DGLGraph()
        g_dgl.from_networkx(big_graph, node_attrs=['idx','emb'] )

        #g_dgl.ndata['emb'] = torch.FloatTensor(g_dgl.ndata['emb'])
        g_dgl.to(torch.device('cuda:0'))

        labels = torch.tensor(labels).to(self.device)
        
        del big_graph, nodes

        return g_dgl, labels, { v:k for (k,v) in node_idx.items() }, list(docs_idx), list(global_terms_idx), list(terms_idx), range(self.nclass)

In [None]:
class ClassifierGAT(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes,
                 n_heads=8, drop=.5, attn_drop=.5,
                 last=None, dist_func='diff',
                 device='cuda:0'):
        super(ClassifierGAT, self).__init__()
        
        self.device = torch.device(device)

        self.encoder_global_term = nn.Linear(in_dim, hidden_dim).to(self.device)
        self.encoder_label       = nn.Linear(in_dim, hidden_dim).to(self.device)
        self.encoder_term        = nn.Linear(in_dim, hidden_dim).to(self.device)
        
        self.layers = nn.ModuleList([
            GATConv(hidden_dim, hidden_dim, residual=True, num_heads=n_heads, activation=F.leaky_relu,
                    feat_drop=drop, attn_drop=attn_drop).to(self.device),
            GATConv(hidden_dim, hidden_dim, residual=True, num_heads=n_heads, activation=F.leaky_relu,
                    feat_drop=drop, attn_drop=attn_drop).to(self.device)
        ])
        
        self.pool_labels_old = nn.Linear(2*hidden_dim, hidden_dim).to(self.device)
        self.pool_global_terms_old = nn.Linear(2*hidden_dim, hidden_dim).to(self.device)
        

        self.down_proj = [
            nn.Linear(n_heads*hidden_dim, hidden_dim).to(self.device),
            nn.Linear(n_heads*hidden_dim, hidden_dim).to(self.device)
        ]
        
        
        self.last = last
        
        self._doc_weights_ = self._doc_weights_att_
        
        self.lin = nn.Linear( hidden_dim, 1).to(self.device)
        self.pooling = GlobalAttentionPooling( self.lin ).to(self.device)
        
        if dist_func.lower() == 'cat':
            self._dist_func_ = self.dist_cat_func
            self.att_w = nn.Linear( 2*hidden_dim, 1).to(self.device)
        else:
            self.att_w = nn.Linear( hidden_dim, 1).to(self.device)
            if dist_func.lower() == 'diff':
                self._dist_func_ = self.dist_diff_func
            elif dist_func.lower() == 'mult':
                self._dist_func_ = self.dist_mult_func
            elif dist_func.lower() == 'pool':
                self._doc_weights_ = self._doc_weights_pool_att_
                self.att_w = nn.Linear( 2*hidden_dim, 1).to(self.device)
                self._dist_func_ = self.dist_mult_func
                
    
    def forward(self, G, docs_idx, terms_idx, global_terms_idx, labels_idx):
        
        self.n_class = len(labels_idx)
        self.n_docs_batch = len(docs_idx)
  
        h = G.ndata['emb'].float()
        h[global_terms_idx] = self.encoder_global_term( h[global_terms_idx] )
        h[labels_idx]       = self.encoder_label( h[labels_idx] )
        h[terms_idx]        = self.encoder_term( h[terms_idx] )
        
        
        old_labels_hiddens       = h[labels_idx]
        old_global_terms_hiddens = h[global_terms_idx]
        
        G.ndata['emb'] = h
        
        for l, conv in enumerate(self.layers):
            h = conv(G, h)
            h = h.view(h.shape[0], -1)
            h = self.down_proj[l]( h )
        
        Gs = list(map(G.subgraph, docs_idx))
        batch = dgl.batch( Gs )
        
        docs_hiddens = self.pooling( batch, h[terms_idx] )
        labels_hiddens = h[labels_idx]
        global_terms_idx_hiddens = h[global_terms_idx]
        
        concated_labels = torch.cat( (old_labels_hiddens, labels_hiddens), 1 )
        label_hiddens   = self.pool_labels_old( concated_labels )
        
        concated_global_terms    = torch.cat( (old_global_terms_hiddens, global_terms_idx_hiddens), 1 )
        global_terms_idx_hiddens = self.pool_global_terms_old( concated_global_terms )
        
        weights, doc_dists = self._doc_weights_( docs_hiddens, labels_hiddens )
        
        if self.last is not None:
            weights = self.last(weights)
            
        del docs_hiddens, batch, Gs, old_labels_hiddens
        
        return weights, labels_hiddens, global_terms_idx_hiddens, doc_dists
    def _doc_weights_att_(self, docs_hiddens, labels_hiddens):
        
        result = []
        for d in range(self.n_docs_batch):
            doc_vec = docs_hiddens[d]
            for c in range(self.n_class):
                class_vec = labels_hiddens[c]
                
                dist = self._dist_func_( doc_vec, class_vec )
                result.append(dist)
                
        doc_dists = torch.stack( result )
        unstacked_weights = self.att_w( doc_dists )
        weights = torch.reshape( unstacked_weights, (self.n_docs_batch, self.n_class) )
        
        del result, unstacked_weights
        
        return weights, doc_dists
    
    def _doc_weights_pool_att_(self, docs_hiddens, labels_hiddens):
        
        result = []
        labels_hiddens = labels_hiddens.flatten()
        for d in range(self.n_docs_batch):
            doc_vec = docs_hiddens[d]
            for c in range(self.n_class):
                class_vec = labels_hiddens[c]
                dist      = self._dist_func_( doc_vec, class_vec )
                dist      = torch.cat((dist, labels_hiddens))
            result.append(dist)
            dist = self._dist_func_( doc_vec, class_vec )
                
        doc_dists = torch.stack( result )
        weights = self.att_w( doc_dists )
        
        return weights, doc_dists
        
    def dist_diff_func(self, vec1, vec2):
        return (vec1-vec2)
        
    def dist_mult_func(self, vec1, vec2):
        return (vec1*vec2)
        
    def dist_cat_func(self, vec1, vec2):
        return torch.cat((vec1, vec2))


In [None]:
if model:
    del model

In [None]:
model = ClassifierGAT(graph_builder.ndim, hidden_dim, dataset.nclass,
                      n_heads=n_heads, drop=.0, attn_drop=.0, dist_func='cat').to(torch.device('cuda:0'))

In [None]:
best_score = None
n_iters = 0
n_epochs = 3



sampler = SamplerCollate(
    label_vectors = { k: np.random.uniform( size=graph_builder.ndim ) for k in le.classes_ },
    terms_vectors = graph_builder.embeddings_dict.copy(),
    nclass = len(le.classes_)
)

for epoch in range(n_epochs):
    data_loader = DataLoader(list(zip(Gs_train, y_train)), batch_size=batch_size,
                             shuffle=True, collate_fn=sampler.collate)
    epoch_loss = 0
    with tqdm(total=len(data_loader.dataset), smoothing=0.) as pbar:
        t0 = time()
        total = 0
        correct = 0
        model.train()
        #  g_dgl, labels, node_idx, docs_idx, global_terms_idx, range(nclass)
        for i, (bg, labels, node_idx, docs_idx, global_terms_idx, terms_idx, labels_idx) in enumerate(data_loader):
            # model(G, docs_idx, terms_idx, global_terms_idx, labels_idx)
            outputs, L_hiddens, GT_hiddens, doc_dists = model(bg, docs_idx, terms_idx, global_terms_idx, labels_idx)
            
            probs_Y = torch.softmax(outputs, 1)
            sampled_Y = torch.argmax(probs_Y, 1).reshape(-1)
            
            total += labels.size(0)
            correct += (sampled_Y == labels).sum().item()
            
            # NN backprop phase
            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.detach().item()
            
            L_hiddens = L_hiddens.detach().cpu().numpy()
            for l in sampler.label_vectors.keys():
                sampler.label_vectors[l] = L_hiddens[l]
                
            terms = [ graph_builder.vocab_idx[node_idx[gt_id][1]] for gt_id in global_terms_idx ]
            GT_hiddens = GT_hiddens.detach().cpu().numpy()
            for i,t in enumerate(terms):
                sampler.terms_vectors[t] = GT_hiddens[i]
            
            pbar.update( len(labels) )
            pbar.set_description_str(f'iter {epoch} Acc train: {correct/total:.3}')
            
            del bg, labels, node_idx, docs_idx, global_terms_idx, terms_idx, labels_idx
            del outputs, L_hiddens, GT_hiddens, doc_dists
            del probs_Y, sampled_Y
            del loss, terms
            #break
        #break

In [None]:
laxg srgs grdaels

In [None]:
GT_hiddens

In [None]:
L_hidden_flatten = L_hiddens.flatten()

In [None]:
from collections import Counter

In [None]:
bg

In [None]:
bg.edges()

In [None]:
bg.nodes()

In [None]:
doc_dists

In [None]:

#bg, labels, node_idx, docs_idx, global_terms_idx, terms_idx, labels_idx