In [1]:
import torch, dgl
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from dgl.data import DGLDataset
import dgl.function as fn
import re
import numpy as np
import pandas as pd
import os.path as osp
from glob import glob
from copy import deepcopy
from colorama import Fore
from random import shuffle
from typing import Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

from graphSAGE import GraphSAGE
# from adabelief_pytorch import AdaBelief
from focal_loss import FocalLoss, focal_loss
from dkkd_create_graph import classes, class_2_idx, class_to_idx

Some weights of the model checkpoint at vinai/phobert-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


##### 2. Graph Dataset

In [3]:
class DkkdGraphDataset(DGLDataset):
    def __init__(self, root:str='/home/agent/Documents/graph/GNN/dataset/DKKD_graph'):
        super().__init__(name='dataset/DKKD_graph')
        self.root = root
        self.edges = glob(osp.join(root, '*.edges.csv'))
        # self.nodes_feat = glob(osp.join(root, '*.nfeat.npy'))
        # self.nodes_label = glob(osp.join(root, '*.idx.csv'))
        self.shuffle()
    
    @staticmethod
    def _get_n_nodes(nodes_label:pd.DataFrame) -> int:
        r"""
        tính và kiểm tra số thứ tự của node
        """
        n_nodes = nodes_label['Id'].to_list()
        for i, idx in enumerate(n_nodes):
            assert i == idx, 'i != idx'
        return len(n_nodes)
    
    def __len__(self): return len(self.edges)
    
    def __getitem__(self, i) -> dgl.DGLGraph:
        edgep = self.edges[i]
        nodes_feat = np.load(re.sub('.edges.csv$', '.nfeat.npy', edgep))
        nodes_label = pd.read_csv(
            re.sub('.edges.csv$', '.idx.csv', edgep), encoding='utf-8')
        n_nodes = self._get_n_nodes(nodes_label)
        
        nodes_label = nodes_label['label'].astype('category').cat.codes.to_list()
        edge = pd.read_csv(edgep, encoding='utf-8')
        
        g = dgl.graph((edge['src'], edge['dst']), num_nodes=n_nodes)
        g = dgl.to_bidirected(g)
        g = dgl.remove_self_loop(g)
        g = dgl.add_self_loop(g)
        g.ndata['feat' ] = torch.from_numpy(nodes_feat )
        g.ndata['label'] = torch.tensor    (nodes_label)
        # g.ndata['train_mask'] = torch.ones (n_nodes, dtype=torch.bool)
        # g.ndata['val_mask'  ] = torch.zeros(n_nodes, dtype=torch.bool)
        # g.ndata['test_mask' ] = torch.zeros(n_nodes, dtype=torch.bool)
        
        return g
    
    def shuffle(self): shuffle(self.edges)
        
    def process(self): ...
    
train_data = DkkdGraphDataset(root='dataset/DKKD_graph')
val_data = DkkdGraphDataset(root='dataset/DKKD_graph_test')


##### 3. Training function

In [4]:
def cacu_alpha() -> torch.Tensor:
    df = pd.read_csv('dataset/DKKD/classes_unbalance.csv', encoding='utf-8')
    classes = df['class_name']
    num = df['num']
    stat_cls = []
    for cls, n_node in zip(classes, num):
        idx = class_2_idx(cls)
        stat_cls.append((idx, n_node))
        
    stat_cls = np.array(sorted(stat_cls, key=lambda x:x[0]))
    tong = 1.*stat_cls[:,1].sum()
    freq = stat_cls[:,1]/tong
    # alpha = 1/freq
    alpha = 1.0 - freq
    return torch.from_numpy(alpha)
alpha = cacu_alpha()
print(alpha)

tensor([0.9888, 0.9505, 0.9949, 0.9913, 0.9944, 0.9973, 0.9901, 0.9851, 0.9859,
        0.9840, 0.9851, 0.9543, 0.9950, 0.9968, 0.9975, 0.9988, 0.9945, 0.9973,
        0.9916, 0.9913, 0.9809, 0.9885, 0.9901, 0.9975, 0.9942, 0.9974, 0.9735,
        0.9631, 0.9944, 0.9816, 0.9802, 0.9525, 0.9943, 0.9944, 0.9944, 0.9944,
        0.9972, 0.9836, 0.9972, 0.9801, 0.9854, 0.9876, 0.9934, 0.9727, 0.9777,
        0.9861, 0.9941, 0.6093], dtype=torch.float64)


In [5]:
@torch.no_grad()
def _val(val_dataset:DGLDataset, model:nn.Module) -> float:
    model.eval()
    acc = 0.0
    for g in val_dataset:
        features = g.ndata['feat'].float()
        labels = g.ndata['label']
        logits = model(g, features)
        pred = logits.argmax(1)
        acc += (pred == labels).float().mean()
    return acc/len(val_dataset)

@torch.no_grad()
def val(val_dataset:DGLDataset, model:nn.Module, 
                        ignore_class:Optional[Tuple[int, ...]]=None) -> float:
    model.eval()
    if ignore_class is None: 
        return _val(val_dataset, model)
    
    acc = 0.0
    for g in val_dataset:
        features = g.ndata['feat'].float()
        labels = g.ndata['label']
        logits = model(g, features)
        pred = logits.argmax(1)
        
        filter_cls = [y for y, lbl in zip(pred == labels, labels) 
                      if lbl not in ignore_class]
        acc += torch.tensor(filter_cls).float().mean()
    return acc / len(val_dataset)
    
    
def train_n_valid(train_dataset:DGLDataset, val_dataset:DGLDataset, 
                  model:nn.Module, epochs:int, lr=0.001, path_save:str=...):
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-8)
    # optimizer = AdaBelief(model.parameters(), lr=lr, betas=(0.9,0.999), eps=1e-8, 
                        #   rectify = False, print_change_log=False)
    criteron:FocalLoss = focal_loss(alpha= alpha, gamma=4.2)
    best_acc = -1.0

    for epoch in range(1, epochs+1):
        model.train()
        acc =0.
        for g in train_dataset:
            features = g.ndata['feat'].float()
            labels = g.ndata['label']
            
            logits = model(g, features)
            pred = logits.argmax(1)
            
            loss = criteron(logits, labels)
            acc += (pred == labels).float().mean()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        train_dataset.shuffle()
        
        train_acc = (1.*acc)/len(train_dataset)
        # val_acc = val(val_dataset, model)
        val_acc = val(val_dataset, model, None)
        if val_acc >= best_acc:
            best_acc = val_acc
            best_model = deepcopy(model)
            torch.save(model.state_dict(), path_save)
        

        print('Epoch {:<3d}: loss: {:.4f}, best {:.4f}, train_acc: {:.4f}, val_acc: {:.4f}'.
              format(epoch, loss, best_acc, train_acc, val_acc))

    torch.save(best_model.state_dict(), path_save)
    print('Last epoch {:<3d}: loss: {:.4f}, best {:.4f}, train_acc: {:.4f}, val_acc: {:.4f}'.
          format(epoch, loss, best_acc, train_acc, val_acc))
    
    return best_model, best_acc

In [6]:
num_classes = len(classes)  #48
print('num classes = ', num_classes)
model = GraphSAGE(772, 196, 64, num_classes) #
model.load_state_dict(torch.load('weights/graphSAGE_best40.pth',
                                 map_location=torch.device('cpu')))

num classes =  48


<All keys matched successfully>

In [7]:
best_model, best_acc = train_n_valid(train_data, val_data, model, 2000, lr=0.001, path_save='weights/graphSAGE_best40.pth')
torch.save(best_model.state_dict(), 'weights/graphSAGE_best40.pth')
print(best_acc)

Epoch 1  : loss: 0.2231, best 0.4267, train_acc: 0.7701, val_acc: 0.4267
Epoch 2  : loss: 0.0751, best 0.4267, train_acc: 0.7911, val_acc: 0.4207
Epoch 3  : loss: 0.2152, best 0.4267, train_acc: 0.8045, val_acc: 0.4180
Epoch 4  : loss: 0.1915, best 0.4267, train_acc: 0.8030, val_acc: 0.3899
Epoch 5  : loss: 0.2559, best 0.4267, train_acc: 0.7965, val_acc: 0.3787
Epoch 6  : loss: 0.2298, best 0.4267, train_acc: 0.8080, val_acc: 0.4139
Epoch 7  : loss: 0.2250, best 0.4267, train_acc: 0.7915, val_acc: 0.4045
Epoch 8  : loss: 0.0000, best 0.4267, train_acc: 0.7861, val_acc: 0.3747
Epoch 9  : loss: 0.1515, best 0.4267, train_acc: 0.7864, val_acc: 0.3805
Epoch 10 : loss: 0.1181, best 0.4267, train_acc: 0.7912, val_acc: 0.4018
Epoch 11 : loss: 0.2254, best 0.4267, train_acc: 0.8022, val_acc: 0.3943
Epoch 12 : loss: 0.3840, best 0.4267, train_acc: 0.7853, val_acc: 0.4153
Epoch 13 : loss: 0.3788, best 0.4267, train_acc: 0.7943, val_acc: 0.4034
Epoch 14 : loss: 0.1896, best 0.4267, train_acc: 0.