# Test of GraphSAGE
- use DGL
- predict `graphs`
- valid, test data are in the training dataset

In [2]:
import os
import dgl
import csv
import json
import torch
import random
import subprocess
import torch as th
import numpy as np
import pandas as pd
import torch.nn as nn
import dgl.nn as dglnn
import torch.nn.functional as F

from tqdm.notebook import tqdm
from sklearn.decomposition import PCA
from torch.optim import AdamW, lr_scheduler
from dgl.nn import GraphConv, GATConv, SAGEConv
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from transformers import get_linear_schedule_with_warmup

os.environ['CUDA_VISIBLE_DEVICES'] = "1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

- check the GPU and assign the GPU by the best memory usage

In [3]:
# def get_free_gpu():
#     try:
#         # Run nvidia-smi command to get GPU details
#         _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
#         command = "nvidia-smi --query-gpu=memory.free --format=csv,nounits,noheader"
#         memory_free_info = _output_to_list(subprocess.check_output(command.split())) 
#         memory_free_values = [int(x) for i, x in enumerate(memory_free_info)]
        
#         # Get the GPU with the maximum free memory
#         best_gpu_id = memory_free_values.index(max(memory_free_values))
#         return best_gpu_id
#     except:
#         # If any exception occurs, default to GPU 0 (this handles cases where nvidia-smi isn't installed)
#         return 0

# if torch.cuda.is_available():
#     # Get the best GPU ID based on free memory and set it
#     best_gpu_id = get_free_gpu()
#     device = torch.device(f"cuda:{best_gpu_id}")
# else:
#     device = torch.device("cpu")
#     print("there's no available GPU")

# # device = torch.device(f"cuda:{1}")
# print(device)


## Fix the seed

In [4]:
#fix seed
def same_seeds(seed = 8787):
    torch.manual_seed(seed)
    # random.seed(seed) 
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  
    np.random.seed(seed)  
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

## Load the embedding

In [5]:
DIM = 256
embedding = 'transH'

embedding = f'{embedding}_{DIM}'
with open(f"../../data/4_embedding/{embedding}.vec.json", "r") as f:
    tmp = json.load(f)

index2entemb = {idx:emb for idx, emb in enumerate(tmp["ent_embeddings.weight"])}
index2relemb = {idx:emb for idx, emb in enumerate(tmp["rel_embeddings.weight"])}

In [6]:
len(index2entemb)

824642

In [7]:
len(index2relemb)

27

- this file is 55 GB -> takes about 1 min to load it

In [8]:
with open("../../data/all_graph_data.jsonl", "r") as f:
    print("Loading the data...")
#     input_data = list(f)
#     input_data = [json.loads(line) for idx, line in tqdm(f, desc="Loading")]

    input_data = []
    for idx, line in tqdm(enumerate(f), total=16900, desc="Loading"):
        input_data.append(json.loads(line))
        
    print("FINISH...")

Loading the data...


Loading:   0%|          | 0/16900 [00:00<?, ?it/s]

FINISH...


In [9]:
len(input_data)

16900

- Convert the 'node_feat' and 'edge_attr' from int to embedding
    - takes about 45 min to transform the embedding
    - if use original method -> takes about 60 hours

In [10]:
type(input_data)

list

In [11]:
type(input_data[0])

dict

- need to get the new graph.jsonl -> the id is not corresponding

In [12]:
# ============ If type(input_data[0] == dict) ============
for data_point in tqdm(input_data):
    data_point['node_feat'] = [index2entemb[node_id] for node_id in data_point['node_feat']]
    data_point['edge_attr'] = [index2relemb[edge_id] for edge_id in data_point['edge_attr']]

  0%|          | 0/16900 [00:00<?, ?it/s]

In [13]:
type(input_data[0])

dict

In [14]:
len(input_data[1]['node_feat'][0])

256

## Data Loader

In [15]:
class GraphDataset(Dataset):
    def __init__(self, data_list, device):
        self.data_list = data_list
        self.device = device

    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        data = self.data_list[idx]
        return data

def collate(samples):
    data_list = samples
    batched_graphs = []
    for data in data_list:
        g = dgl.graph((th.tensor(data["edge_index"][0]), th.tensor(data["edge_index"][1])), num_nodes=data["num_nodes"])

        g.ndata['feat'] = th.tensor(data["node_feat"])
        g.edata['feat'] = th.tensor(data["edge_attr"])
        g.edata['label'] = th.tensor(data["labels"])  # Add edge labels to graph

        batched_graphs.append(g)
    
    return dgl.batch(batched_graphs)

In [16]:
# split 8:1:1 (train, valid, test)
train_data, test_data = train_test_split(input_data, test_size=0.2, random_state=42)
valid_data, test_data = train_test_split(test_data, test_size=0.5, random_state=42)


dataset_data = {
    'train': GraphDataset(train_data, device),
    'valid': GraphDataset(valid_data, device),
    'test': GraphDataset(test_data, device)
}

print("Datasets loaded and ready for training!")

Datasets loaded and ready for training!


- choose batch size

In [18]:
def create_dataloaders(batch_size, shuffle=True):
    dataloaders = {}
    for dataset_name, dataset in dataset_data.items():
        # do not shuffle the testing dataset
        if dataset_name == "test":
            dataloaders[dataset_name] = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)    
        else:
            dataloaders[dataset_name] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate)
    return dataloaders

dataloaders = create_dataloaders(16)

- Turn the print message to a log file

In [19]:
import datetime

now = datetime.datetime.now()

formatted_time = now.strftime("%m%d_%H:%M")

log_file_path = f"./log_message/{formatted_time}_GraphSAGE_{embedding}.log"

def add_log_msg(msg, log_file_path=log_file_path):
    with open(log_file_path, 'a') as f:
        f.write(f'{datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")}# {msg}\n')
    print(f'{datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")}# {msg}')

print(log_file_path)

./log_message/0204_07:08_GraphSAGE_transH_256_256.log


### Model

In [20]:
class GraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(GraphSAGE, self).__init__()
        self.layer1 = dglnn.SAGEConv(in_dim, hidden_dim, 'pool')
        self.layer2 = dglnn.SAGEConv(hidden_dim, out_dim, 'pool')
        self.dropout = nn.Dropout(0.25)

    def forward(self, g, inputs):
        h = self.layer1(g, inputs)
        h = torch.relu(h)
        h = self.dropout(h)
        h = self.layer2(g, h)
        return h

In [21]:
class MLPPredictor(nn.Module):
    def __init__(self, out_feats, out_classes):
        super().__init__()
        self.W = nn.Linear(out_feats*2, out_classes)

    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        return {'score': score}

    def forward(self, graph, h):
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

In [22]:
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_classes):
        super().__init__()
        self.sage = GraphSAGE(in_features, hidden_features, out_features)
        self.pred = MLPPredictor(out_features, num_classes)
      
    def forward(self, g, node_feat, return_logits=False):
        h = self.sage(g, node_feat)
        logits = self.pred(g, h)
        
        return logits

- Model Forward  

In [23]:
def model_fn(batched_g, model, criterion, device, count=1, which_type='train'):
    """Forward a batch through the model."""
#     batched_g, labels = data
    batched_g = batched_g.to(device)
    
    labels = batched_g.edata['label'].to(device)
    
    logits = model(batched_g, batched_g.ndata['feat'].float())

    loss = criterion(logits, labels)

    output = torch.softmax(logits, dim=1)
    preds = output.argmax(1)
    
    # Compute accuracy
    accuracy = torch.mean((preds == labels).float())
        
    return loss, accuracy, preds

In [24]:
import re

def build_dictionary(file_path):
    with open(file_path, 'r') as file:
        next(file)
        # 使用正则表达式去除行末的数字
        dictionary = {re.sub(r'\s\d+$', '', line.strip()): index for index, line in enumerate(file)}
    return dictionary
    
file_path = '../../data/3_openKE/label2id.txt'  # 替換為您檔案的路徑
label2index = build_dictionary(file_path)
index2label = {v: k for k, v in label2index.items()}

index2label

{0: 'T1059.001_702bfdd2-9947-4eda-b551-c3a1ea9a59a2_B',
 1: 'T1078.001_d0ca00832890baa1d42322cf70fcab1a_B',
 2: 'T1074.001_e6dfc7e89359ac6fa6de84b0e1d5762e_B',
 3: 'T1491_68235976-2404-42a8-9105-68230cfef562_B',
 4: 'T1016_14a21534-350f-4d83-9dd7-3c56b93a0c17_B',
 5: 'T1491_47d08617-5ce1-424a-8cc5-c9c978ce6bf9_I',
 6: 'T1074.001_4e97e699-93d7-4040-b5a3-2e906a58199e_I',
 7: 'T1040_6881a4589710d53f0c146e91db513f01_B',
 8: 'T1547.009_b6e5c895c6709fe289352ee23f062229_B',
 9: 'T1564.001_66a5fd5f244819181f074dd082a28905_B',
 10: 'T1047_f4b0b4129560ea66f9751275e82f6bab_B',
 11: 'T1112_257313a3c93e3bb7dfb60d6753b09e34_I',
 12: 'T1047_ac2764f7a67a9ce92b54e8e59b361838_B',
 13: 'T1518.001_33a24ff44719e6ac0614b58f8c9a7c72_B',
 14: 'T1204.002_522f3f35cd013e63830fa555495a0081_I',
 15: 'T1059.001_ccdb8caf-c69e-424b-b930-551969450c57_B',
 16: 'T1105_0856c235a1d26113d4f2d92e39c9a9f8_B',
 17: 'T1547_fe9eeee9a7b339089e5fa634b08522c1_I',
 18: 'T1574.001_63bbedafba2f541552ac3579e9e3737b_B',
 19: 'T1137.002

In [25]:
type(index2entemb)

dict

### Main Training Loop

- For release the GPU memory
    - no need to restart the kernel

In [26]:
# # For release the GPU memory
# # No need to restart the kernel

# import gc
# gc.collect()
# torch.cuda.empty_cache()

In [29]:
seed = 5269
in_dim = DIM # dimension of the node feature
hidden_dim = 512
out_dim = 1024
num_classes = len(label2index)

lr = 5e-4

total_steps = 100
patience = 10
waiting = 0

In [30]:
model = Model(in_dim, hidden_dim, out_dim, num_classes)
best_model_path = f"./checkpoint_graphSAGE/best_model_GraphSAGE_{embedding}-small_batchsize-bigdim.pt"

optimizer = AdamW(model.parameters(), lr)

scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0, last_epoch=-1, verbose=False)

criterion = nn.CrossEntropyLoss()
# criterion = torch.nn.BCEWithLogitsLoss()

In [31]:
same_seeds(seed)
model = model.to(device)
best_val_loss = float('inf')

# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    for batched_g in tqdm(dataloaders['train'], desc="Training", position=0, leave=True):
        num_batches += 1
        loss, accuracy, _ = model_fn(batched_g, model, criterion, device, num_batches, which_type='train')
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()

    # scheduler affect the performance a lot
    # scheduler.step()
    add_log_msg(f"total batches: {num_batches}")

    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    add_log_msg(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0


    with torch.no_grad():
        for batched_g in tqdm(dataloaders['valid'], desc="Validation", position=0, leave=True):
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device, num_batches, which_type='validation')
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    current_loss = total_loss / num_batches
    
    add_log_msg(f'Validation Loss: {current_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}\n')
    
            
    if current_loss < best_val_loss:
        best_val_loss = current_loss
        waiting = 0
        
        if os.path.exists(best_model_path):
            os.remove(best_model_path)
            add_log_msg("Find a better model!!")

        torch.save(model.state_dict(), best_model_path)
 
    else:
        waiting += 1
        if waiting >= patience:
            add_log_msg("============================== Early stopping ==================================")
            break

  0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:11:12# total batches: 845
02/04/2024, 07:11:12# Epoch 0 | Train Loss: 1.1018 | Train Accuracy: 0.7962


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:11:27# Validation Loss: 0.8183 | Validation Accuracy: 0.8287



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:13:56# total batches: 845
02/04/2024, 07:13:56# Epoch 1 | Train Loss: 0.5481 | Train Accuracy: 0.8928


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:14:11# Validation Loss: 0.4665 | Validation Accuracy: 0.9072

02/04/2024, 07:14:11# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:16:40# total batches: 845
02/04/2024, 07:16:40# Epoch 2 | Train Loss: 0.3756 | Train Accuracy: 0.9232


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:16:56# Validation Loss: 0.3539 | Validation Accuracy: 0.9273

02/04/2024, 07:16:56# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:19:24# total batches: 845
02/04/2024, 07:19:24# Epoch 3 | Train Loss: 0.2943 | Train Accuracy: 0.9375


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:19:39# Validation Loss: 0.2990 | Validation Accuracy: 0.9330

02/04/2024, 07:19:39# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:22:08# total batches: 845
02/04/2024, 07:22:08# Epoch 4 | Train Loss: 0.2485 | Train Accuracy: 0.9477


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:22:23# Validation Loss: 0.2180 | Validation Accuracy: 0.9546

02/04/2024, 07:22:23# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:24:51# total batches: 845
02/04/2024, 07:24:51# Epoch 5 | Train Loss: 0.1971 | Train Accuracy: 0.9573


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:25:06# Validation Loss: 0.1988 | Validation Accuracy: 0.9574

02/04/2024, 07:25:06# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:27:33# total batches: 845
02/04/2024, 07:27:33# Epoch 6 | Train Loss: 0.1828 | Train Accuracy: 0.9593


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:27:48# Validation Loss: 0.1861 | Validation Accuracy: 0.9578

02/04/2024, 07:27:48# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:30:16# total batches: 845
02/04/2024, 07:30:16# Epoch 7 | Train Loss: 0.1587 | Train Accuracy: 0.9637


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:30:32# Validation Loss: 0.1693 | Validation Accuracy: 0.9593

02/04/2024, 07:30:32# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:33:00# total batches: 845
02/04/2024, 07:33:00# Epoch 8 | Train Loss: 0.1495 | Train Accuracy: 0.9650


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:33:15# Validation Loss: 0.1553 | Validation Accuracy: 0.9636

02/04/2024, 07:33:15# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:35:42# total batches: 845
02/04/2024, 07:35:42# Epoch 9 | Train Loss: 0.1465 | Train Accuracy: 0.9647


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:35:57# Validation Loss: 0.1461 | Validation Accuracy: 0.9642

02/04/2024, 07:35:57# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:38:24# total batches: 845
02/04/2024, 07:38:24# Epoch 10 | Train Loss: 0.1271 | Train Accuracy: 0.9691


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:38:39# Validation Loss: 0.1306 | Validation Accuracy: 0.9673

02/04/2024, 07:38:39# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:41:07# total batches: 845
02/04/2024, 07:41:07# Epoch 11 | Train Loss: 0.1213 | Train Accuracy: 0.9700


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:41:22# Validation Loss: 0.1317 | Validation Accuracy: 0.9660



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:43:48# total batches: 845
02/04/2024, 07:43:48# Epoch 12 | Train Loss: 0.1190 | Train Accuracy: 0.9703


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:44:03# Validation Loss: 0.1278 | Validation Accuracy: 0.9685

02/04/2024, 07:44:03# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:46:32# total batches: 845
02/04/2024, 07:46:32# Epoch 13 | Train Loss: 0.1092 | Train Accuracy: 0.9721


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:46:48# Validation Loss: 0.1259 | Validation Accuracy: 0.9676

02/04/2024, 07:46:48# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:49:17# total batches: 845
02/04/2024, 07:49:17# Epoch 14 | Train Loss: 0.1172 | Train Accuracy: 0.9707


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:49:32# Validation Loss: 0.1209 | Validation Accuracy: 0.9686

02/04/2024, 07:49:32# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:52:01# total batches: 845
02/04/2024, 07:52:01# Epoch 15 | Train Loss: 0.1058 | Train Accuracy: 0.9725


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:52:16# Validation Loss: 0.1266 | Validation Accuracy: 0.9687



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:54:48# total batches: 845
02/04/2024, 07:54:48# Epoch 16 | Train Loss: 0.1090 | Train Accuracy: 0.9724


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:55:03# Validation Loss: 0.1187 | Validation Accuracy: 0.9701

02/04/2024, 07:55:03# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 07:57:32# total batches: 845
02/04/2024, 07:57:32# Epoch 17 | Train Loss: 0.1002 | Train Accuracy: 0.9739


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 07:57:47# Validation Loss: 0.0999 | Validation Accuracy: 0.9736

02/04/2024, 07:57:47# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:00:18# total batches: 845
02/04/2024, 08:00:18# Epoch 18 | Train Loss: 0.0928 | Train Accuracy: 0.9753


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:00:33# Validation Loss: 0.1262 | Validation Accuracy: 0.9673



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:03:06# total batches: 845
02/04/2024, 08:03:06# Epoch 19 | Train Loss: 0.0924 | Train Accuracy: 0.9753


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:03:21# Validation Loss: 0.1131 | Validation Accuracy: 0.9703



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:05:55# total batches: 845
02/04/2024, 08:05:55# Epoch 20 | Train Loss: 0.0842 | Train Accuracy: 0.9775


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:06:10# Validation Loss: 0.1114 | Validation Accuracy: 0.9689



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:08:38# total batches: 845
02/04/2024, 08:08:38# Epoch 21 | Train Loss: 0.0827 | Train Accuracy: 0.9779


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:08:53# Validation Loss: 0.1115 | Validation Accuracy: 0.9714



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:11:21# total batches: 845
02/04/2024, 08:11:21# Epoch 22 | Train Loss: 0.0797 | Train Accuracy: 0.9786


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:11:36# Validation Loss: 0.1080 | Validation Accuracy: 0.9718



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:14:05# total batches: 845
02/04/2024, 08:14:05# Epoch 23 | Train Loss: 0.0777 | Train Accuracy: 0.9791


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:14:20# Validation Loss: 0.0950 | Validation Accuracy: 0.9747

02/04/2024, 08:14:20# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:16:49# total batches: 845
02/04/2024, 08:16:49# Epoch 24 | Train Loss: 0.0768 | Train Accuracy: 0.9795


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:17:05# Validation Loss: 0.0998 | Validation Accuracy: 0.9736



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:19:33# total batches: 845
02/04/2024, 08:19:33# Epoch 25 | Train Loss: 0.0780 | Train Accuracy: 0.9788


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:19:49# Validation Loss: 0.0877 | Validation Accuracy: 0.9762

02/04/2024, 08:19:49# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:22:18# total batches: 845
02/04/2024, 08:22:18# Epoch 26 | Train Loss: 0.0736 | Train Accuracy: 0.9802


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:22:33# Validation Loss: 0.1077 | Validation Accuracy: 0.9728



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:25:03# total batches: 845
02/04/2024, 08:25:03# Epoch 27 | Train Loss: 0.0692 | Train Accuracy: 0.9812


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:25:18# Validation Loss: 0.0959 | Validation Accuracy: 0.9745



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:27:49# total batches: 845
02/04/2024, 08:27:49# Epoch 28 | Train Loss: 0.0671 | Train Accuracy: 0.9817


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:28:04# Validation Loss: 0.1010 | Validation Accuracy: 0.9743



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:30:32# total batches: 845
02/04/2024, 08:30:32# Epoch 29 | Train Loss: 0.0659 | Train Accuracy: 0.9817


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:30:47# Validation Loss: 0.0985 | Validation Accuracy: 0.9739



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:33:16# total batches: 845
02/04/2024, 08:33:16# Epoch 30 | Train Loss: 0.0699 | Train Accuracy: 0.9813


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:33:31# Validation Loss: 0.1052 | Validation Accuracy: 0.9727



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:35:58# total batches: 845
02/04/2024, 08:35:58# Epoch 31 | Train Loss: 0.0679 | Train Accuracy: 0.9816


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:36:13# Validation Loss: 0.1049 | Validation Accuracy: 0.9732



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:38:41# total batches: 845
02/04/2024, 08:38:41# Epoch 32 | Train Loss: 0.0625 | Train Accuracy: 0.9832


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:38:56# Validation Loss: 0.0943 | Validation Accuracy: 0.9753



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:41:27# total batches: 845
02/04/2024, 08:41:27# Epoch 33 | Train Loss: 0.0585 | Train Accuracy: 0.9838


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:41:42# Validation Loss: 0.0899 | Validation Accuracy: 0.9770



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:44:12# total batches: 845
02/04/2024, 08:44:12# Epoch 34 | Train Loss: 0.0569 | Train Accuracy: 0.9847


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:44:27# Validation Loss: 0.0844 | Validation Accuracy: 0.9784

02/04/2024, 08:44:27# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:46:56# total batches: 845
02/04/2024, 08:46:56# Epoch 35 | Train Loss: 0.0550 | Train Accuracy: 0.9849


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:47:12# Validation Loss: 0.0876 | Validation Accuracy: 0.9774



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:49:42# total batches: 845
02/04/2024, 08:49:42# Epoch 36 | Train Loss: 0.0612 | Train Accuracy: 0.9838


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:49:57# Validation Loss: 0.0892 | Validation Accuracy: 0.9770



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:52:26# total batches: 845
02/04/2024, 08:52:26# Epoch 37 | Train Loss: 0.0569 | Train Accuracy: 0.9845


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:52:41# Validation Loss: 0.1079 | Validation Accuracy: 0.9724



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:55:10# total batches: 845
02/04/2024, 08:55:10# Epoch 38 | Train Loss: 0.0597 | Train Accuracy: 0.9845


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:55:26# Validation Loss: 0.0944 | Validation Accuracy: 0.9753



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 08:57:56# total batches: 845
02/04/2024, 08:57:56# Epoch 39 | Train Loss: 0.0497 | Train Accuracy: 0.9865


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 08:58:12# Validation Loss: 0.0914 | Validation Accuracy: 0.9772



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:00:40# total batches: 845
02/04/2024, 09:00:40# Epoch 40 | Train Loss: 0.0481 | Train Accuracy: 0.9868


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:00:55# Validation Loss: 0.1009 | Validation Accuracy: 0.9739



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:03:24# total batches: 845
02/04/2024, 09:03:24# Epoch 41 | Train Loss: 0.0492 | Train Accuracy: 0.9864


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:03:39# Validation Loss: 0.0831 | Validation Accuracy: 0.9798

02/04/2024, 09:03:39# Find a better model!!


Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:06:06# total batches: 845
02/04/2024, 09:06:06# Epoch 42 | Train Loss: 0.0443 | Train Accuracy: 0.9878


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:06:21# Validation Loss: 0.0948 | Validation Accuracy: 0.9759



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:08:50# total batches: 845
02/04/2024, 09:08:50# Epoch 43 | Train Loss: 0.0457 | Train Accuracy: 0.9874


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:09:05# Validation Loss: 0.0960 | Validation Accuracy: 0.9763



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:11:35# total batches: 845
02/04/2024, 09:11:35# Epoch 44 | Train Loss: 0.0454 | Train Accuracy: 0.9874


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:11:50# Validation Loss: 0.0845 | Validation Accuracy: 0.9794



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:14:19# total batches: 845
02/04/2024, 09:14:19# Epoch 45 | Train Loss: 0.0452 | Train Accuracy: 0.9874


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:14:34# Validation Loss: 0.1173 | Validation Accuracy: 0.9735



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:17:03# total batches: 845
02/04/2024, 09:17:03# Epoch 46 | Train Loss: 0.0754 | Train Accuracy: 0.9834


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:17:20# Validation Loss: 0.0888 | Validation Accuracy: 0.9779



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:19:55# total batches: 845
02/04/2024, 09:19:55# Epoch 47 | Train Loss: 0.0483 | Train Accuracy: 0.9873


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:20:10# Validation Loss: 0.0951 | Validation Accuracy: 0.9780



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:22:45# total batches: 845
02/04/2024, 09:22:45# Epoch 48 | Train Loss: 0.0419 | Train Accuracy: 0.9885


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:23:00# Validation Loss: 0.0977 | Validation Accuracy: 0.9754



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:25:33# total batches: 845
02/04/2024, 09:25:33# Epoch 49 | Train Loss: 0.0408 | Train Accuracy: 0.9889


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:25:49# Validation Loss: 0.0869 | Validation Accuracy: 0.9795



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:28:26# total batches: 845
02/04/2024, 09:28:26# Epoch 50 | Train Loss: 0.0566 | Train Accuracy: 0.9859


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:28:42# Validation Loss: 0.1025 | Validation Accuracy: 0.9772



Training:   0%|          | 0/845 [00:00<?, ?it/s]

02/04/2024, 09:31:11# total batches: 845
02/04/2024, 09:31:11# Epoch 51 | Train Loss: 0.0404 | Train Accuracy: 0.9887


Validation:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:31:27# Validation Loss: 0.0858 | Validation Accuracy: 0.9799



### test of valid and test part is ``graph``

- 60 APs in training x 10000times
- 5 APs in validation x 4 times
- 3 APs in test x 4 times
- Batch size = 4

In [32]:
# load the pretrained model
# pretrained_model_path = './checkpoint_graphSAGE/best_model_GraphSAGE_transE_50.pt'
model.load_state_dict(torch.load(best_model_path))

model.to(device)
model.eval()

total = 0
correct = 0
count = 0

true_labels = []
predicted_labels = []

with torch.no_grad():
    for batched_g in tqdm(dataloaders['test'], desc="Testing", position=0, leave=True):
#         print(f"data:{data[1]}")
        loss, accuracy, predicted = model_fn(batched_g, model, criterion, device, count, which_type='test')
        labels = batched_g.edata['label'].to(device)
        
        true_labels.extend(labels.cpu().numpy())
        predicted_labels.extend(predicted.cpu().numpy())
        
        if count % 5000 == 0:
            add_log_msg(f"labels: {labels} {labels.shape}")
            add_log_msg(f"predicted: {predicted} {predicted.shape}")
            
        count += 1
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

add_log_msg(f'Test Accuracy: {100 * correct / total} %\n\n\n')

Testing:   0%|          | 0/106 [00:00<?, ?it/s]

02/04/2024, 09:31:28# labels: tensor([248, 248, 248,  ..., 248, 248, 248], device='cuda:0') torch.Size([6003])
02/04/2024, 09:31:28# predicted: tensor([248, 248, 248,  ..., 248, 248, 248], device='cuda:0') torch.Size([6003])
02/04/2024, 09:31:46# Test Accuracy: 98.1253708898038 %





In [33]:
report_data = classification_report(true_labels, predicted_labels, output_dict=True)
report_df = pd.DataFrame(report_data).transpose()

output_path = "./result"
if not os.path.isdir(output_path):
    os.makedirs(output_path)
    
report_df.reset_index(inplace=True, names='label')

label_list = []
for idx, row in report_df.iterrows():
    if row["label"].isdigit():
        row["label"] = index2label[int(row["label"])]
    label_list.append(row["label"])
report_df["label"] = label_list

report_df.to_csv(f'{output_path}/result_{embedding}-small_batchsize-bigdim.csv', index=False)
print("report output at: ", f'{output_path}/result_{embedding}.csv')

report_df

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


report output at:  ./result/result_transH_256.csv


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,label,precision,recall,f1-score,support
0,T1059.001_702bfdd2-9947-4eda-b551-c3a1ea9a59a2_B,0.111111,0.187500,0.139535,16.000000
1,T1078.001_d0ca00832890baa1d42322cf70fcab1a_B,1.000000,1.000000,1.000000,13.000000
2,T1074.001_e6dfc7e89359ac6fa6de84b0e1d5762e_B,0.250000,0.040000,0.068966,25.000000
3,T1491_68235976-2404-42a8-9105-68230cfef562_B,0.000000,0.000000,0.000000,18.000000
4,T1016_14a21534-350f-4d83-9dd7-3c56b93a0c17_B,1.000000,1.000000,1.000000,27.000000
...,...,...,...,...,...
276,T1003.003_9f73269695e54311dd61dc68940fb3e1_B,1.000000,1.000000,1.000000,13.000000
277,T1547.001_163b023f43aba758d36f524d146cb8ea_B,0.161290,0.200000,0.178571,25.000000
278,accuracy,0.981254,0.981254,0.981254,0.981254
279,macro avg,0.641538,0.640099,0.633599,813921.000000


### Training

- Fix the seed and save the model.state_dict that contains the initial weight

In [36]:
seed = 8787
same_seeds(seed)

model = Model(in_features=50, hidden_features=64, out_features=128, num_classes=167)
torch.save(model.state_dict(), 'model3_initial(graphsage)/initial_weight.pth')

In [69]:
# model.layer1.fc_self.weight
model.sage.layer1.fc_self.weight

Parameter containing:
tensor([[ 0.0181, -0.0857,  0.1973,  ...,  0.2417,  0.2702, -0.3041],
        [-0.0768, -0.2723, -0.2001,  ...,  0.2989, -0.1387, -0.1940],
        [ 0.2582, -0.0822,  0.3086,  ..., -0.0257, -0.1119, -0.0335],
        ...,
        [ 0.2274, -0.0411, -0.0334,  ..., -0.1679,  0.2455,  0.2424],
        [ 0.1375,  0.2813,  0.0775,  ...,  0.1337,  0.2065,  0.2618],
        [-0.0951,  0.1010, -0.2586,  ..., -0.1242, -0.0631,  0.0924]],
       requires_grad=True)

- Check if model really load the model_dict

In [70]:
model = Model(in_features=50, hidden_features=64, out_features=128, num_classes=167)
model.load_state_dict(torch.load('model3_initial(graphsage)/initial_weight.pth'))
model.sage.layer1.fc_self.weight

Parameter containing:
tensor([[ 0.0181, -0.0857,  0.1973,  ...,  0.2417,  0.2702, -0.3041],
        [-0.0768, -0.2723, -0.2001,  ...,  0.2989, -0.1387, -0.1940],
        [ 0.2582, -0.0822,  0.3086,  ..., -0.0257, -0.1119, -0.0335],
        ...,
        [ 0.2274, -0.0411, -0.0334,  ..., -0.1679,  0.2455,  0.2424],
        [ 0.1375,  0.2813,  0.0775,  ...,  0.1337,  0.2065,  0.2618],
        [-0.0951,  0.1010, -0.2586,  ..., -0.1242, -0.0631,  0.0924]],
       requires_grad=True)