# Test of GCN
- Graph classification with a GCN, implemented with DGL (Deep Graph Library)

In [1]:
import dgl
import json
import torch
import subprocess
import torch as th
from tqdm import tqdm
from tqdm.notebook import tqdm
import torch.nn as nn
from dgl.nn import GraphConv, GATConv
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup

- Check the available GPUs and use the one with the most free memory (fall back to the CPU when CUDA is unavailable)

In [2]:
def get_free_gpu():
    """Return the index of the GPU that currently has the most free memory.

    Runs ``nvidia-smi --query-gpu=memory.free`` (one line per GPU) and picks
    the GPU with the largest value.

    Returns:
        int: index of the GPU with the most free memory, or 0 when
        ``nvidia-smi`` is missing, exits with an error, or prints output that
        cannot be parsed (including empty output).

    NOTE: nvidia-smi numbers GPUs by PCI bus ID. CUDA's default device order
    (and any CUDA_VISIBLE_DEVICES remapping) can differ from that, so set
    CUDA_DEVICE_ORDER=PCI_BUS_ID if the chosen index looks wrong.
    """
    command = ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,nounits,noheader"]
    try:
        raw_output = subprocess.check_output(command)
        # One value per line; ignore empty lines such as the trailing newline.
        memory_free_values = [
            int(line) for line in raw_output.decode("ascii").splitlines() if line.strip()
        ]
        # max() of an empty list raises ValueError, which is handled below.
        return memory_free_values.index(max(memory_free_values))
    except (OSError, subprocess.SubprocessError, ValueError):
        # OSError: nvidia-smi not installed (FileNotFoundError).
        # SubprocessError: non-zero exit status (CalledProcessError).
        # ValueError: undecodable / unparsable / empty output.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return 0

# Choose where to run: the GPU with the most free memory when CUDA is usable,
# otherwise the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    best_gpu_id = get_free_gpu()
    device = torch.device("cuda", best_gpu_id)

print(device)


cuda:2


In [30]:
import numpy as np
import torch
import random

#fix seed
def same_seeds(seed = 8787):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every visible CUDA device). It also puts cuDNN in deterministic mode so
    repeated runs give the same results.

    Args:
        seed: integer seed shared by all generators (default 8787).
    """
    # Previously commented out, which left Python's RNG unseeded even though
    # `random` is imported alongside this function.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU, including the current one.
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

## Data Loader

In [15]:
class GraphDataset(Dataset):
    """Map-style dataset that turns JSON-decoded records into DGL graphs.

    Each record must provide ``edge_index`` (a pair of source / destination
    node lists), ``num_nodes``, ``node_feat``, ``edge_attr`` and ``label``.
    Graphs and labels are built on access, and every tensor is moved to
    ``device``.
    """

    def __init__(self, data_list, device):
        # One raw record per graph, plus the target device for all tensors.
        self.data_list = data_list
        self.device = device

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Build and return ``(graph, label)`` for record ``idx``."""
        record = self.data_list[idx]
        src_nodes, dst_nodes = record["edge_index"][0], record["edge_index"][1]
        graph = dgl.graph(
            (th.tensor(src_nodes), th.tensor(dst_nodes)),
            num_nodes=record["num_nodes"],
        ).to(self.device)

        node_features = th.tensor(record["node_feat"]).to(self.device)
        edge_features = th.tensor(record["edge_attr"]).to(self.device)
        graph.ndata['feat'] = node_features
        graph.edata['feat'] = edge_features  # edge features ride along with the graph

        label = th.tensor(record["label"]).to(self.device)
        return graph, label


def collate(samples):
    """Collate ``(graph, label)`` pairs into one batched graph and a label tensor.

    Args:
        samples: list of ``(DGLGraph, label)`` pairs, e.g. from
            ``GraphDataset.__getitem__``.

    Returns:
        tuple: ``(batched_graph, labels)``, with one label entry per graph.
    """
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    if labels and all(torch.is_tensor(label) for label in labels):
        # stack keeps the labels on the device the dataset already put them on.
        # torch.tensor(list_of_tensors) would copy each one back to the host.
        return batched_graph, torch.stack(labels)
    # Plain Python numbers (or an empty batch): build the tensor directly.
    return batched_graph, torch.tensor(labels)


In [16]:
# Load every split from its JSON-lines file into a GraphDataset.
datasets = ['train', 'valid', 'test']
dataset_data = {}

# Directory holding {split}.jsonl for the current experiment. Earlier runs used e.g.
#   ../../data_processing/dgl/data/test_graph/repeated_{split}.jsonl
#   ../../data_processing/dgl/data/test_triplet/repeated_test_{split}.jsonl
DATA_DIR = "../../data_processing/dgl/data_new/exp1-2/training_data/exp_2/transH_50"

for dataset_name in tqdm(datasets):
    file_path = f"{DATA_DIR}/{dataset_name}.jsonl"

    print(file_path)
    with open(file_path, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        data_list = [json.loads(line) for line in tqdm(f, position=0, leave=True) if line.strip()]

    dataset_data[dataset_name] = GraphDataset(data_list, device)

print("Datasets loaded!")

  0%|          | 0/3 [00:00<?, ?it/s]

../../data_processing/dgl/data_new/exp1-2/training_data/exp_2/transH_50/train.jsonl


0it [00:00, ?it/s]

../../data_processing/dgl/data_new/exp1-2/training_data/exp_2/transH_50/valid.jsonl


0it [00:00, ?it/s]

../../data_processing/dgl/data_new/exp1-2/training_data/exp_2/transH_50/test.jsonl


0it [00:00, ?it/s]

Datasets loaded!


In [17]:
def create_dataloaders(batch_size, shuffle=True, dataset_map=None):
    """Build one DataLoader per split.

    Args:
        batch_size: number of graphs per batch.
        shuffle: whether to shuffle each split except ``"test"``, which is
            never shuffled so test outputs follow file order.
        dataset_map: optional ``{split_name: Dataset}`` mapping. Defaults to
            the notebook-level ``dataset_data``.

    Returns:
        dict: split name -> DataLoader whose batches are built by ``collate``.
    """
    if dataset_map is None:
        dataset_map = dataset_data
    return {
        dataset_name: DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False if dataset_name == "test" else shuffle,
            collate_fn=collate,
        )
        for dataset_name, dataset in dataset_map.items()
    }

# dataloaders = create_dataloaders(4)
dataloaders = create_dataloaders(16)

- Write every printed message to a timestamped log file as well

In [18]:
import datetime

import os

now = datetime.datetime.now()

# NOTE: ':' in the file name is fine on Linux/macOS but is not allowed on Windows.
formatted_time = now.strftime("%m%d_%H:%M")

log_file_path = f"../log_message/{formatted_time}_GCN.log"

def add_log_msg(msg, log_file_path=log_file_path):
    """Append a timestamped message to the log file and echo it to stdout.

    Args:
        msg: message to record (formatted with ``str`` via the f-string).
        log_file_path: file to append to. Defaults to this session's log file.
            Its parent directory is created if missing.
    """
    # Take the timestamp once so the file line and console line always match.
    line = f'{datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")}# {msg}'
    log_dir = os.path.dirname(log_file_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(log_file_path, 'a', encoding='utf-8') as f:
        f.write(f'{line}\n')
    print(line)

print(log_file_path)

../log_message/0910_14:50_GCN.log


### Model

In [25]:
class GCN(nn.Module):
    """Three-layer graph convolutional network for graph-level classification.

    Layer widths are ``in_feats -> hidden_size -> hidden_size * hidden_multiplier
    -> num_classes``. The last layer's node outputs are mean-pooled per graph,
    so ``forward`` returns one row of class logits per graph in the batch.

    Args:
        in_feats: node feature dimension.
        hidden_size: width of the first hidden layer.
        num_classes: number of output classes (logits per graph).
        hidden_multiplier: width of the second hidden layer relative to
            ``hidden_size`` (default 4, the original architecture).
    """

    def __init__(self, in_feats, hidden_size, num_classes, hidden_multiplier=4):
        super(GCN, self).__init__()
        wide_size = hidden_size * hidden_multiplier
        # GraphConv raises on nodes with zero in-degree unless this flag is set.
        self.conv1 = GraphConv(in_feats, hidden_size, allow_zero_in_degree=True)
        self.conv2 = GraphConv(hidden_size, wide_size, allow_zero_in_degree=True)
        self.conv3 = GraphConv(wide_size, num_classes, allow_zero_in_degree=True)

    def forward(self, g, inputs):
        """Return per-graph logits for the (batched) graph ``g``.

        Args:
            g: a DGLGraph, typically a batch produced by ``dgl.batch``.
            inputs: node feature tensor with one row per node.
        """
        h = torch.relu(self.conv1(g, inputs))
        h = torch.relu(self.conv2(g, h))
        h = self.conv3(g, h)

        # local_scope keeps the temporary 'h' feature out of the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            return dgl.mean_nodes(g, 'h')
    

    
# class GCN(nn.Module):
#     def __init__(self, in_feats, hidden_size, num_classes):
#         super(GCN, self).__init__()
#         self.conv1 = GraphConv(in_feats, hidden_size)
#         self.conv2 = GraphConv(hidden_size, num_classes)

#     def forward(self, g, inputs):
#         h = self.conv1(g, inputs)
#         h = torch.relu(h)
#         h = self.conv2(g, h)
        
#         g.ndata['h'] = h
#         hg = dgl.mean_nodes(g, 'h')
#         return hg

- Model forward pass: run one batch and compute the loss, accuracy and predictions

In [37]:
def model_fn(data, model, criterion, device, count, which_type="train"):
    """Run one batch through ``model`` and score it.

    Args:
        data: ``(batched_graph, labels)`` pair produced by the DataLoader.
        model: module called as ``model(graph, node_features)`` that returns logits.
        criterion: loss function applied to ``(logits, labels)``.
        device: device the graph and labels are moved to.
        count: caller's batch counter (currently unused).
        which_type: split name such as "train", "validation" or "test"
            (currently unused).

    Returns:
        tuple: ``(loss, accuracy, preds)``:
            - the criterion output;
            - the fraction of correct predictions, as a 0-d float tensor;
            - the argmax class index per row of logits.
    """
    graph, targets = data
    graph = graph.to(device)
    targets = targets.to(device)

    # Node features are passed as stored (a GAT variant cast them with .float()).
    logits = model(graph, graph.ndata['feat'])

    loss = criterion(logits, targets)
    preds = torch.argmax(logits, dim=1)
    accuracy = (preds == targets).float().mean()
    return loss, accuracy, preds

# Example of the inputs model_fn receives, recorded from an earlier data version.
# NOTE(review): it shows 1-dim int64 node features, while the training cell builds
# the model with in_feats=50 (50-dim embeddings), so this example is likely stale.
'''
batched_g is like: 
Graph(num_nodes=96, num_edges=160, ndata_schemes={'feat': Scheme(shape=(1,), dtype=torch.int64)}, edata_schemes={})
num_nodes = 3*batch_size, num_edges = 5*batch_size

labels is like: tensor([ 76,   0,   0,   0,   0,   0,   0,   0,   0,  76,   0,  76,   0,   0,
                          0,   0,  76,   0,  30,  92,   0,   0,  76,   0,   0,   0,   0,   0,
                        116,   0,  76,  76])
'''

"\nbatched_g is like: \nGraph(num_nodes=96, num_edges=160, ndata_schemes={'feat': Scheme(shape=(1,), dtype=torch.int64)}, edata_schemes={})\nnum_nodes = 3*batch_size, num_edges = 5*batch_size\n\nlabels is like: tensor([ 76,   0,   0,   0,   0,   0,   0,   0,   0,  76,   0,  76,   0,   0,\n                          0,   0,  76,   0,  30,  92,   0,   0,  76,   0,   0,   0,   0,   0,\n                        116,   0,  76,  76])\n"

### Training

In [43]:
import os
import csv
import pandas as pd
from sklearn.metrics import classification_report
from torch.optim import AdamW, lr_scheduler

seed = 8787
same_seeds(seed)

# in_feats: dimension of node_feat (50, from the 50-dim embedding).
# num_classes: number of output categories for this run (167).
# NOTE(review): an older note said 168 categories; confirm against new_mapping.txt.
model = GCN(50, 16, 167)
# model.load_state_dict(torch.load('model1_initial/initial_weight.pth'))
best_model_path = "../checkpoint_GCN/best_model_GCN_transH_50.pt"
# Make sure the checkpoint directory exists before the first torch.save.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

model = model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-4)

# Stepped once per epoch: the LR follows a cosine from 5e-4 down to eta_min over
# T_max epochs and then climbs back. With T_max=36 it is ~0 around epoch 36.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=36, eta_min=0)

criterion = nn.CrossEntropyLoss()
total_steps = 400  # maximum number of epochs (early stopping may end sooner)

# Early stopping on validation loss.
best_val_loss = float('inf')
patience = 10  # epochs without improvement before training stops
waiting = 0    # epochs without improvement so far


def _run_epoch(loader, desc, train):
    """Run one pass over ``loader``.

    In training mode the optimizer steps after every batch. Otherwise the
    model is in eval mode and gradients are disabled.

    Returns:
        tuple: ``(mean loss, mean accuracy, number of batches)``, averaged per batch.
    """
    model.train(train)
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    with torch.set_grad_enabled(train):
        for data in tqdm(loader, desc=desc, position=0, leave=True):
            loss, accuracy, _ = model_fn(data, model, criterion, device, num_batches,
                                         which_type='train' if train else 'validation')
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError(f"{desc}: the dataloader produced no batches")
    return total_loss / num_batches, total_accuracy / num_batches, num_batches


for epoch in tqdm(range(total_steps)):
    # Train
    train_loss, train_accuracy, num_batches = _run_epoch(dataloaders['train'], "Training", train=True)
    scheduler.step()
    add_log_msg(f"total batches: {num_batches}")
    add_log_msg(f'Epoch {epoch} | Train Loss: {train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}')

    # Validation
    current_loss, val_accuracy, _ = _run_epoch(dataloaders['valid'], "Validation", train=False)
    add_log_msg(f'Validation Loss: {current_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}\n')

    if current_loss < best_val_loss:
        best_val_loss = current_loss
        waiting = 0
        # torch.save overwrites any existing file, so there is no need to delete it.
        # Log every improvement (previously only logged when an old checkpoint existed).
        torch.save(model.state_dict(), best_model_path)
        add_log_msg("Find a better model!!")
    else:
        waiting += 1
        if waiting >= patience:
            add_log_msg("============================== Early stopping ==================================")
            break

  0%|          | 0/400 [00:00<?, ?it/s]

Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 20:40:51# total batches: 8300
09/10/2023, 20:40:51# Epoch 0 | Train Loss: 4.6298 | Train Accuracy: 0.0134


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 20:41:07# Validation Loss: 4.4741 | Validation Accuracy: 0.0301

09/10/2023, 20:41:07# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 20:43:43# total batches: 8300
09/10/2023, 20:43:43# Epoch 1 | Train Loss: 4.3843 | Train Accuracy: 0.0327


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 20:43:59# Validation Loss: 4.3011 | Validation Accuracy: 0.0361

09/10/2023, 20:43:59# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 20:46:45# total batches: 8300
09/10/2023, 20:46:45# Epoch 2 | Train Loss: 4.2321 | Train Accuracy: 0.0463


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 20:47:03# Validation Loss: 4.1668 | Validation Accuracy: 0.0603

09/10/2023, 20:47:03# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 20:49:46# total batches: 8300
09/10/2023, 20:49:46# Epoch 3 | Train Loss: 4.1242 | Train Accuracy: 0.0613


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 20:50:02# Validation Loss: 4.0830 | Validation Accuracy: 0.0663

09/10/2023, 20:50:02# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 20:52:42# total batches: 8300
09/10/2023, 20:52:42# Epoch 4 | Train Loss: 4.0540 | Train Accuracy: 0.0826


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 20:52:58# Validation Loss: 4.0250 | Validation Accuracy: 0.0904

09/10/2023, 20:52:58# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 20:55:30# total batches: 8300
09/10/2023, 20:55:30# Epoch 5 | Train Loss: 4.0097 | Train Accuracy: 0.0935


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 20:55:46# Validation Loss: 3.9917 | Validation Accuracy: 0.1073

09/10/2023, 20:55:46# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 20:58:21# total batches: 8300
09/10/2023, 20:58:21# Epoch 6 | Train Loss: 3.9798 | Train Accuracy: 0.0996


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 20:58:37# Validation Loss: 3.9650 | Validation Accuracy: 0.1072

09/10/2023, 20:58:37# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:01:17# total batches: 8300
09/10/2023, 21:01:17# Epoch 7 | Train Loss: 3.9582 | Train Accuracy: 0.1080


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:01:35# Validation Loss: 3.9548 | Validation Accuracy: 0.0959

09/10/2023, 21:01:35# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:04:18# total batches: 8300
09/10/2023, 21:04:18# Epoch 8 | Train Loss: 3.9420 | Train Accuracy: 0.1109


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:04:36# Validation Loss: 3.9307 | Validation Accuracy: 0.1084

09/10/2023, 21:04:36# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:07:05# total batches: 8300
09/10/2023, 21:07:05# Epoch 9 | Train Loss: 3.9281 | Train Accuracy: 0.1149


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:07:21# Validation Loss: 3.9229 | Validation Accuracy: 0.1205

09/10/2023, 21:07:21# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:10:13# total batches: 8300
09/10/2023, 21:10:13# Epoch 10 | Train Loss: 3.9180 | Train Accuracy: 0.1176


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:10:31# Validation Loss: 3.9075 | Validation Accuracy: 0.1205

09/10/2023, 21:10:31# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:13:18# total batches: 8300
09/10/2023, 21:13:18# Epoch 11 | Train Loss: 3.9093 | Train Accuracy: 0.1194


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:13:35# Validation Loss: 3.9082 | Validation Accuracy: 0.1264



Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:16:27# total batches: 8300
09/10/2023, 21:16:27# Epoch 12 | Train Loss: 3.9017 | Train Accuracy: 0.1217


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:16:45# Validation Loss: 3.8987 | Validation Accuracy: 0.1445

09/10/2023, 21:16:45# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:19:42# total batches: 8300
09/10/2023, 21:19:42# Epoch 13 | Train Loss: 3.8947 | Train Accuracy: 0.1237


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:20:02# Validation Loss: 3.8990 | Validation Accuracy: 0.1145



Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:22:51# total batches: 8300
09/10/2023, 21:22:51# Epoch 14 | Train Loss: 3.8892 | Train Accuracy: 0.1239


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:23:09# Validation Loss: 3.8848 | Validation Accuracy: 0.1326

09/10/2023, 21:23:09# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:26:25# total batches: 8300
09/10/2023, 21:26:25# Epoch 15 | Train Loss: 3.8830 | Train Accuracy: 0.1293


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:26:50# Validation Loss: 3.8839 | Validation Accuracy: 0.1325

09/10/2023, 21:26:50# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:30:43# total batches: 8300
09/10/2023, 21:30:43# Epoch 16 | Train Loss: 3.8757 | Train Accuracy: 0.1358


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:31:09# Validation Loss: 3.8728 | Validation Accuracy: 0.1325

09/10/2023, 21:31:09# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:35:02# total batches: 8300
09/10/2023, 21:35:02# Epoch 17 | Train Loss: 3.8688 | Train Accuracy: 0.1338


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:35:28# Validation Loss: 3.8618 | Validation Accuracy: 0.1325

09/10/2023, 21:35:28# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:39:23# total batches: 8300
09/10/2023, 21:39:23# Epoch 18 | Train Loss: 3.8621 | Train Accuracy: 0.1342


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:39:49# Validation Loss: 3.8568 | Validation Accuracy: 0.1325

09/10/2023, 21:39:49# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:44:47# total batches: 8300
09/10/2023, 21:44:47# Epoch 19 | Train Loss: 3.8554 | Train Accuracy: 0.1346


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:45:12# Validation Loss: 3.8532 | Validation Accuracy: 0.1447

09/10/2023, 21:45:12# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:49:03# total batches: 8300
09/10/2023, 21:49:03# Epoch 20 | Train Loss: 3.8496 | Train Accuracy: 0.1346


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:49:19# Validation Loss: 3.8440 | Validation Accuracy: 0.1325

09/10/2023, 21:49:19# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:52:12# total batches: 8300
09/10/2023, 21:52:12# Epoch 21 | Train Loss: 3.8438 | Train Accuracy: 0.1349


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:52:29# Validation Loss: 3.8362 | Validation Accuracy: 0.1567

09/10/2023, 21:52:29# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:55:28# total batches: 8300
09/10/2023, 21:55:28# Epoch 22 | Train Loss: 3.8388 | Train Accuracy: 0.1379


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:55:46# Validation Loss: 3.8344 | Validation Accuracy: 0.1507

09/10/2023, 21:55:46# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 21:58:51# total batches: 8300
09/10/2023, 21:58:51# Epoch 23 | Train Loss: 3.8340 | Train Accuracy: 0.1373


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 21:59:08# Validation Loss: 3.8326 | Validation Accuracy: 0.1385

09/10/2023, 21:59:08# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:02:14# total batches: 8300
09/10/2023, 22:02:14# Epoch 24 | Train Loss: 3.8300 | Train Accuracy: 0.1368


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:02:32# Validation Loss: 3.8286 | Validation Accuracy: 0.1325

09/10/2023, 22:02:32# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:05:40# total batches: 8300
09/10/2023, 22:05:40# Epoch 25 | Train Loss: 3.8264 | Train Accuracy: 0.1395


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:06:01# Validation Loss: 3.8225 | Validation Accuracy: 0.1446

09/10/2023, 22:06:01# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:09:05# total batches: 8300
09/10/2023, 22:09:05# Epoch 26 | Train Loss: 3.8232 | Train Accuracy: 0.1398


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:09:23# Validation Loss: 3.8200 | Validation Accuracy: 0.1505

09/10/2023, 22:09:23# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:12:31# total batches: 8300
09/10/2023, 22:12:31# Epoch 27 | Train Loss: 3.8206 | Train Accuracy: 0.1398


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:12:49# Validation Loss: 3.8192 | Validation Accuracy: 0.1385

09/10/2023, 22:12:49# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:15:50# total batches: 8300
09/10/2023, 22:15:50# Epoch 28 | Train Loss: 3.8180 | Train Accuracy: 0.1427


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:16:08# Validation Loss: 3.8156 | Validation Accuracy: 0.1445

09/10/2023, 22:16:08# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:19:08# total batches: 8300
09/10/2023, 22:19:08# Epoch 29 | Train Loss: 3.8163 | Train Accuracy: 0.1471


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:19:25# Validation Loss: 3.8132 | Validation Accuracy: 0.1627

09/10/2023, 22:19:25# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:22:31# total batches: 8300
09/10/2023, 22:22:31# Epoch 30 | Train Loss: 3.8147 | Train Accuracy: 0.1480


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:22:50# Validation Loss: 3.8127 | Validation Accuracy: 0.1506

09/10/2023, 22:22:50# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:25:52# total batches: 8300
09/10/2023, 22:25:52# Epoch 31 | Train Loss: 3.8136 | Train Accuracy: 0.1521


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:26:10# Validation Loss: 3.8116 | Validation Accuracy: 0.1626

09/10/2023, 22:26:10# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:29:13# total batches: 8300
09/10/2023, 22:29:13# Epoch 32 | Train Loss: 3.8127 | Train Accuracy: 0.1566


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:29:31# Validation Loss: 3.8112 | Validation Accuracy: 0.1626

09/10/2023, 22:29:31# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:32:40# total batches: 8300
09/10/2023, 22:32:40# Epoch 33 | Train Loss: 3.8121 | Train Accuracy: 0.1619


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:32:58# Validation Loss: 3.8103 | Validation Accuracy: 0.1626

09/10/2023, 22:32:58# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:35:40# total batches: 8300
09/10/2023, 22:35:40# Epoch 34 | Train Loss: 3.8118 | Train Accuracy: 0.1634


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:35:56# Validation Loss: 3.8101 | Validation Accuracy: 0.1627

09/10/2023, 22:35:56# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:38:32# total batches: 8300
09/10/2023, 22:38:32# Epoch 35 | Train Loss: 3.8115 | Train Accuracy: 0.1578


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:38:49# Validation Loss: 3.8101 | Validation Accuracy: 0.1626

09/10/2023, 22:38:49# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:41:31# total batches: 8300
09/10/2023, 22:41:31# Epoch 36 | Train Loss: 3.8114 | Train Accuracy: 0.1626


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:41:47# Validation Loss: 3.8100 | Validation Accuracy: 0.1626

09/10/2023, 22:41:47# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

09/10/2023, 22:44:31# total batches: 8300
09/10/2023, 22:44:31# Epoch 37 | Train Loss: 3.8114 | Train Accuracy: 0.1627


Validation:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:44:49# Validation Loss: 3.8098 | Validation Accuracy: 0.1627

09/10/2023, 22:44:49# Find a better model!!


Training:   0%|          | 0/8300 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Load the best checkpoint saved during training.
pretrained_model_path = '../checkpoint_GCN/best_model_GCN_transH_50.pt'
# map_location lets a checkpoint saved on one GPU (e.g. cuda:2) load on the current `device`.
# NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
model.load_state_dict(torch.load(pretrained_model_path, map_location=device))

model.to(device)
model.eval()

total = 0
correct = 0

true_labels = []
predicted_labels = []

with torch.no_grad():
    for batch_idx, data in enumerate(tqdm(dataloaders['test'], desc="Testing", position=0, leave=True)):
        _, _, predicted = model_fn(data, model, criterion, device, batch_idx, which_type='test')
        labels = data[1].to(device)

        true_labels.extend(labels.cpu().numpy())
        predicted_labels.extend(predicted.cpu().numpy())

        # Periodically log one batch as a sanity check.
        if batch_idx % 5000 == 0:
            add_log_msg(f"labels: {labels} {labels.shape}")
            add_log_msg(f"predicted: {predicted} {predicted.shape}")

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    raise ValueError("the test dataloader produced no samples")
add_log_msg(f'Test Accuracy: {100 * correct / total} %\n\n\n')


# ======================================== handling the output excel files ========================================
def _load_label_mapping(mapping_file):
    """Read ``name: id`` lines into an ``{id: name}`` dict, skipping blank lines."""
    mapping = {}
    with open(mapping_file, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            # Split on the LAST ': ' so label names containing ': ' still parse.
            name, label_id = line.rsplit(': ', 1)
            mapping[int(label_id)] = name
    return mapping


def _next_free_report_paths(folder, tag):
    """Return ``(report_path, labels_path)`` for the first run index not used by either file."""
    run_idx = 0
    while True:
        report_path = os.path.join(folder, f'classification_report-{tag}-{run_idx}.xlsx')
        labels_path = os.path.join(folder, f'mapped_true_predicted_labels-{tag}-{run_idx}.xlsx')
        if not os.path.exists(report_path) and not os.path.exists(labels_path):
            return report_path, labels_path
        run_idx += 1


mapping_file = './new_mapping.txt'
label_mapping = _load_label_mapping(mapping_file)

# Apply the id -> name mapping to the true and predicted label lists.
# NOTE(review): raises KeyError if the model predicts an id missing from the mapping file.
mapped_true_labels = [label_mapping[int(label)] for label in true_labels]
mapped_predicted_labels = [label_mapping[int(label)] for label in predicted_labels]

# Build a DataFrame from scikit-learn's per-class report.
# zero_division=0 scores never-predicted classes as 0 without a warning per class.
report_data = classification_report(mapped_true_labels, mapped_predicted_labels, output_dict=True, zero_division=0)
report_df = pd.DataFrame(report_data).transpose()

report_folder = 'classification_report'
os.makedirs(report_folder, exist_ok=True)

report_path, labels_path = _next_free_report_paths(report_folder, 'transH_50-GCN')

report_df.to_excel(report_path, index_label='Label')

mapped_labels_df = pd.DataFrame({'true_label': mapped_true_labels, 'predicted_label': mapped_predicted_labels})
mapped_labels_df.to_excel(labels_path, index=False)

add_log_msg(f"report path: {report_path}")
add_log_msg(f"label path: {labels_path}")

mapped_report = classification_report(mapped_true_labels, mapped_predicted_labels, zero_division=0)
add_log_msg(f"mapped_report:\n{mapped_report}")

Testing:   0%|          | 0/1038 [00:00<?, ?it/s]

09/10/2023, 22:45:45# labels: tensor([65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65],
       device='cuda:2') torch.Size([16])
09/10/2023, 22:45:45# predicted: tensor([92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92],
       device='cuda:2') torch.Size([16])
