# Test of GCN
- Uses DGL (Deep Graph Library) to build the graphs and the GCN model

In [13]:
import json
import os

import dgl
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv, GATConv
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

- Check whether a GPU is available and select the one with the most free memory

In [16]:
import subprocess
import torch

def get_free_gpu():
    """Return the index of the GPU that currently reports the most free memory.

    Runs ``nvidia-smi --query-gpu=memory.free`` and picks the position of the
    largest value in its output (ties resolve to the lowest index).

    Returns:
        int: Zero-based position of the best GPU in the ``nvidia-smi`` listing,
        or ``0`` when the query cannot be run, fails, or yields no parsable
        values (for example when ``nvidia-smi`` is not installed).

    Note:
        NOTE(review): the index follows the ordering printed by nvidia-smi.
        If ``CUDA_VISIBLE_DEVICES`` hides or reorders GPUs it may not match
        PyTorch's ``cuda:N`` numbering -- confirm for your environment.
    """
    command = "nvidia-smi --query-gpu=memory.free --format=csv,nounits,noheader"
    try:
        raw_output = subprocess.check_output(command.split())
        # splitlines() copes with output that lacks a trailing newline, and
        # blank lines are skipped so they cannot break int().
        memory_free_values = [
            int(line) for line in raw_output.decode("ascii").splitlines() if line.strip()
        ]
    except (OSError, subprocess.SubprocessError, ValueError):
        # Only expected failures fall back to GPU 0: missing binary, non-zero
        # exit status, undecodable or non-numeric output. KeyboardInterrupt and
        # SystemExit are deliberately left to propagate.
        return 0

    if not memory_free_values:
        return 0

    # Get the GPU with the maximum free memory
    return memory_free_values.index(max(memory_free_values))

# Choose the compute device: fall back to the CPU when CUDA is unavailable,
# otherwise use the GPU index reported by get_free_gpu().
if not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    best_gpu_id = get_free_gpu()
    device = torch.device(f"cuda:{best_gpu_id}")

print(device)


cuda:2


## Data Loader

In [5]:
class GraphDataset(Dataset):
    """Map-style dataset that builds one DGL graph per record on demand.

    Each record is read for the keys ``edge_index``, ``num_nodes``,
    ``node_feat`` and ``y``. The graph, its node features and the label are
    all placed on the device given at construction time.
    """

    def __init__(self, data_list, device):
        # Only the raw records are stored; graphs are created in __getitem__.
        self.device = device
        self.data_list = data_list

    def __len__(self):
        """Return the number of records (one graph per record)."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Build and return ``(graph, label)`` for the record at ``idx``."""
        record = self.data_list[idx]

        src_nodes = th.tensor(record["edge_index"][0])
        dst_nodes = th.tensor(record["edge_index"][1])
        graph = dgl.graph((src_nodes, dst_nodes), num_nodes=record["num_nodes"])
        graph = graph.to(self.device)

        # Self-loops add one edge per node; per the original note, the sample
        # sub-graphs end up with 2 + 3 edges after this step.
        graph = dgl.add_self_loop(graph)
        graph.ndata['feat'] = th.tensor(record["node_feat"]).to(self.device)

        label = th.tensor(record["y"]).to(self.device)
        return graph, label


def collate(samples):
    """Merge a list of ``(graph, label)`` pairs into a single mini-batch.

    Returns:
        tuple: ``(batched_graph, labels)`` where ``batched_graph`` is the
        result of ``dgl.batch`` over the graphs and ``labels`` is a new tensor
        built from the labels, in the same order as ``samples``.
    """
    # Transpose the list of pairs into one list of graphs and one of labels.
    graph_list, label_list = (list(column) for column in zip(*samples))
    return dgl.batch(graph_list), torch.tensor(label_list)

In [6]:
# Build one DataLoader per split from the JSON-lines files on disk.
datasets = ['train', 'valid', 'test']
dataloaders = {}

for dataset_name in tqdm(datasets):
    # Alternative (smaller) data set:
    #     file_path = f"../data/final_small_version/remaining_{dataset_name}.jsonl"
    file_path = f"../data/training_data/repeated_{dataset_name}.jsonl"
    print(file_path)
    with open(file_path, encoding="utf-8") as f:
        # One JSON record per line; blank lines (e.g. a trailing newline at the
        # end of the file) are skipped instead of crashing json.loads.
        data_list = [json.loads(line) for line in f if line.strip()]

    dataset = GraphDataset(data_list, device)
    dataloaders[dataset_name] = DataLoader(
        dataset,
        batch_size=32,
        # Reshuffle the training split every epoch so mini-batches are not
        # always drawn in file order; evaluation splits keep a fixed order.
        shuffle=(dataset_name == 'train'),
        collate_fn=collate,
    )

print("Done!")

  0%|          | 0/3 [00:00<?, ?it/s]

../data/final_small_version/remaining_train.jsonl


100%|██████████| 3/3 [00:00<00:00,  4.42it/s]

../data/final_small_version/remaining_valid.jsonl
../data/final_small_version/remaining_test.jsonl


100%|██████████| 3/3 [00:00<00:00,  3.98it/s]

Done!





### Model

In [7]:
class GCN(nn.Module):
    """Three-layer graph convolutional network with mean-node readout.

    Layer widths are ``in_feats -> hidden_size -> hidden_size * 4 ->
    num_classes``, with ReLU after the first two layers.

    Args:
        in_feats: Size of each node's input feature vector.
        hidden_size: Output width of the first layer.
        num_classes: Width of the final layer, i.e. of each graph's output.
    """

    def __init__(self, in_feats, hidden_size, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, hidden_size)
        self.conv2 = GraphConv(hidden_size, hidden_size*4)
        self.conv3 = GraphConv(hidden_size*4, num_classes)

    def forward(self, g, inputs):
        """Compute one output vector per graph in ``g``.

        Args:
            g: A (possibly batched) DGL graph.
            inputs: Node feature tensor passed to the first layer.

        Returns:
            Tensor obtained by averaging the last layer's node outputs with
            ``dgl.mean_nodes``.
        """
        # Match the layer's parameter dtype so integer-typed node features
        # (this notebook's sample scheme lists 'feat' as int64) do not hit a
        # dtype mismatch. This is a no-op when the dtypes already agree.
        h = inputs.to(self.conv1.weight.dtype)
        h = torch.relu(self.conv1(g, h))
        h = torch.relu(self.conv2(g, h))
        h = self.conv3(g, h)

        # local_scope() keeps the temporary 'h' field from leaking into the
        # caller's graph once the readout has been computed.
        with g.local_scope():
            g.ndata['h'] = h
            return dgl.mean_nodes(g, 'h')
    
# class GCN(nn.Module):
#     def __init__(self, in_feats, hidden_size, num_classes):
#         super(GCN, self).__init__()
#         self.conv1 = GraphConv(in_feats, hidden_size)
#         self.conv2 = GraphConv(hidden_size, num_classes)

#     def forward(self, g, inputs):
#         h = self.conv1(g, inputs)
#         h = torch.relu(h)
#         h = self.conv2(g, h)
        
#         g.ndata['h'] = h
#         hg = dgl.mean_nodes(g, 'h')
#         return hg

- Model forward pass (computes the loss and accuracy for one batch)

In [8]:
def model_fn(data, model, criterion, device):
    """Run one forward pass on a batch and report its loss and accuracy.

    Args:
        data: Pair ``(batched_graph, labels)``.
        model: Callable invoked as ``model(graph, node_features)``.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        device: Device the graph and labels are moved to before the pass.

    Returns:
        tuple: ``(loss, accuracy)``, where accuracy is the fraction of samples
        whose highest-scoring class equals the label (a 0-dim tensor).
    """
    graph, targets = data
    graph, targets = graph.to(device), targets.to(device)

    # The GCN consumes the stored node features as-is; the GAT variant needed
    # them cast with .float() first.
    node_features = graph.ndata['feat']
    logits = model(graph, node_features)

    loss = criterion(logits, targets)

    # A sample counts as correct when its arg-max class matches the label.
    hits = torch.argmax(logits, dim=1) == targets
    accuracy = hits.float().mean()

    return loss, accuracy

'''
batched_g is like: 
Graph(num_nodes=96, num_edges=160, ndata_schemes={'feat': Scheme(shape=(1,), dtype=torch.int64)}, edata_schemes={})
num_nodes = 3*batch_size, num_edges = 5*batch_size

labels is like: tensor([ 76,   0,   0,   0,   0,   0,   0,   0,   0,  76,   0,  76,   0,   0,
                          0,   0,  76,   0,  30,  92,   0,   0,  76,   0,   0,   0,   0,   0,
                        116,   0,  76,  76])
'''

"\nbatched_g is like: \nGraph(num_nodes=96, num_edges=160, ndata_schemes={'feat': Scheme(shape=(1,), dtype=torch.int64)}, edata_schemes={})\nnum_nodes = 3*batch_size, num_edges = 5*batch_size\n\nlabels is like: tensor([ 76,   0,   0,   0,   0,   0,   0,   0,   0,  76,   0,  76,   0,   0,\n                          0,   0,  76,   0,  30,  92,   0,   0,  76,   0,   0,   0,   0,   0,\n                        116,   0,  76,  76])\n"

### Training

In [12]:
# Directory that receives one checkpoint file per epoch.
CHECKPOINT_DIR = "../checkpoint_GCN"

model = GCN(1, 16, 168) # 1 = input feature dimension, 16 = hidden size, 168 = number of classes

# model = GAT(in_dim=1, hidden_dim=16, out_dim=168, num_heads=4)
# in_dim means the dimension of the node_feat(1 dim, since the design of our dataset), if a node has multiple feature -> in_dim > 1
# out_dim means the # of the categories -> 168 for out tasks


model = model.to(device)
print(f"Using: {device} now")

criterion = nn.CrossEntropyLoss()
total_steps = 120  # used as the number of epochs by the loop below

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=30, num_training_steps=total_steps)

# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists before the first epoch finishes.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0

    for data in dataloaders['train']:
        loss, accuracy = model_fn(data, model, criterion, device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1

#     scheduler.step()
    # Epoch metrics are plain means of the per-batch values.
    train_loss = total_loss / num_batches
    train_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}')

    # Validate
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for data in dataloaders['valid']:
            loss, accuracy = model_fn(data, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    valid_accuracy = total_accuracy / num_batches
    valid_loss = total_loss / num_batches
    print(f'Validation Loss: {valid_loss:.4f} | Validation Accuracy: {valid_accuracy:.4f}')


    # Save checkpoint
    # 'loss' holds this epoch's mean training loss as a float. Previously it
    # held whatever tensor the last validation mini-batch left in `loss`.
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
            'valid_loss': valid_loss,
            'valid_accuracy': valid_accuracy,
            }, os.path.join(CHECKPOINT_DIR, f"checkpoint_{epoch}.pt"))


# After all epochs: accuracy over the whole test split (sample-weighted).
model.eval()
total = 0
correct = 0
with torch.no_grad():
    for batched_g, labels in dataloaders['test']:
        batched_g = batched_g.to(device)
        labels = labels.to(device)
        logits = model(batched_g, batched_g.ndata['feat'])
        predicted = logits.argmax(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: %d %%' % (100 * correct / total))

Using: cuda:0 now


  0%|          | 0/120 [00:00<?, ?it/s]

Epoch 0 | Train Loss: 2625.5943 | Train Accuracy: 0.4763


  1%|          | 1/120 [00:56<1:52:21, 56.65s/it]

Validation Loss: 494.0275 | Validation Accuracy: 0.6642
Epoch 1 | Train Loss: 600.2995 | Train Accuracy: 0.4953


  2%|▏         | 2/120 [01:51<1:49:07, 55.49s/it]

Validation Loss: 606.5858 | Validation Accuracy: 0.6642
Epoch 2 | Train Loss: 488.8606 | Train Accuracy: 0.4949


  2%|▎         | 3/120 [02:47<1:48:36, 55.69s/it]

Validation Loss: 616.4322 | Validation Accuracy: 0.0002
Epoch 3 | Train Loss: 427.2716 | Train Accuracy: 0.4907


  3%|▎         | 4/120 [03:43<1:47:46, 55.74s/it]

Validation Loss: 430.2292 | Validation Accuracy: 0.6642
Epoch 4 | Train Loss: 383.9496 | Train Accuracy: 0.4926


  4%|▍         | 5/120 [04:39<1:47:00, 55.83s/it]

Validation Loss: 315.8183 | Validation Accuracy: 0.6642
Epoch 5 | Train Loss: 344.8707 | Train Accuracy: 0.4948


  5%|▌         | 6/120 [05:34<1:45:31, 55.54s/it]

Validation Loss: 435.5971 | Validation Accuracy: 0.0002
Epoch 6 | Train Loss: 323.5571 | Train Accuracy: 0.4915


  6%|▌         | 7/120 [06:31<1:45:41, 56.12s/it]

Validation Loss: 353.6193 | Validation Accuracy: 0.6642
Epoch 7 | Train Loss: 286.5976 | Train Accuracy: 0.4899


  7%|▋         | 8/120 [07:27<1:44:53, 56.19s/it]

Validation Loss: 342.9844 | Validation Accuracy: 0.6642
Epoch 8 | Train Loss: 244.9461 | Train Accuracy: 0.4956


  8%|▊         | 9/120 [08:22<1:42:58, 55.66s/it]

Validation Loss: 248.3183 | Validation Accuracy: 0.6642
Epoch 9 | Train Loss: 248.4652 | Train Accuracy: 0.4942


  8%|▊         | 10/120 [09:19<1:43:09, 56.26s/it]

Validation Loss: 226.0744 | Validation Accuracy: 0.6642
Epoch 10 | Train Loss: 211.6099 | Train Accuracy: 0.4956


  9%|▉         | 11/120 [10:16<1:42:43, 56.54s/it]

Validation Loss: 165.2042 | Validation Accuracy: 0.6642
Epoch 11 | Train Loss: 180.7896 | Train Accuracy: 0.4968


 10%|█         | 12/120 [11:11<1:40:55, 56.07s/it]

Validation Loss: 204.9502 | Validation Accuracy: 0.6642
Epoch 12 | Train Loss: 175.7802 | Train Accuracy: 0.4944


 11%|█         | 13/120 [12:12<1:42:21, 57.39s/it]

Validation Loss: 126.9200 | Validation Accuracy: 0.3611
Epoch 13 | Train Loss: 144.8497 | Train Accuracy: 0.4918


 12%|█▏        | 14/120 [13:08<1:40:33, 56.92s/it]

Validation Loss: 207.7641 | Validation Accuracy: 0.2769
Epoch 14 | Train Loss: 129.5772 | Train Accuracy: 0.4924


 12%|█▎        | 15/120 [14:06<1:40:15, 57.29s/it]

Validation Loss: 140.3221 | Validation Accuracy: 0.6642
Epoch 15 | Train Loss: 123.7912 | Train Accuracy: 0.4959


 13%|█▎        | 16/120 [15:01<1:38:11, 56.65s/it]

Validation Loss: 130.3518 | Validation Accuracy: 0.2770
Epoch 16 | Train Loss: 100.0037 | Train Accuracy: 0.4941


 14%|█▍        | 17/120 [15:56<1:36:27, 56.19s/it]

Validation Loss: 96.4111 | Validation Accuracy: 0.6642
Epoch 17 | Train Loss: 96.9231 | Train Accuracy: 0.4951


 15%|█▌        | 18/120 [16:50<1:34:10, 55.40s/it]

Validation Loss: 79.8914 | Validation Accuracy: 0.6487
Epoch 18 | Train Loss: 86.4256 | Train Accuracy: 0.4992


 16%|█▌        | 19/120 [17:45<1:33:16, 55.41s/it]

Validation Loss: 114.5397 | Validation Accuracy: 0.6642
Epoch 19 | Train Loss: 84.5219 | Train Accuracy: 0.4969


 17%|█▋        | 20/120 [18:40<1:31:52, 55.12s/it]

Validation Loss: 107.8029 | Validation Accuracy: 0.6642
Epoch 20 | Train Loss: 66.3887 | Train Accuracy: 0.5054


 18%|█▊        | 21/120 [19:36<1:31:49, 55.65s/it]

Validation Loss: 73.4430 | Validation Accuracy: 0.6642
Epoch 21 | Train Loss: 58.2606 | Train Accuracy: 0.5012


 18%|█▊        | 22/120 [20:31<1:30:08, 55.18s/it]

Validation Loss: 46.4625 | Validation Accuracy: 0.6642
Epoch 22 | Train Loss: 46.7799 | Train Accuracy: 0.5066


 19%|█▉        | 23/120 [21:25<1:28:53, 54.99s/it]

Validation Loss: 46.3706 | Validation Accuracy: 0.6642
Epoch 23 | Train Loss: 38.5775 | Train Accuracy: 0.5117


 20%|██        | 24/120 [22:21<1:28:15, 55.16s/it]

Validation Loss: 63.5831 | Validation Accuracy: 0.2064
Epoch 24 | Train Loss: 30.2775 | Train Accuracy: 0.5125


 21%|██        | 25/120 [23:15<1:27:08, 55.03s/it]

Validation Loss: 26.2004 | Validation Accuracy: 0.6642
Epoch 25 | Train Loss: 26.5093 | Train Accuracy: 0.5144


 22%|██▏       | 26/120 [24:15<1:28:25, 56.44s/it]

Validation Loss: 33.4090 | Validation Accuracy: 0.2826
Epoch 26 | Train Loss: 22.4150 | Train Accuracy: 0.5117


 22%|██▎       | 27/120 [25:09<1:26:30, 55.81s/it]

Validation Loss: 12.8802 | Validation Accuracy: 0.5664
Epoch 27 | Train Loss: 15.2368 | Train Accuracy: 0.5288


 23%|██▎       | 28/120 [26:06<1:25:47, 55.95s/it]

Validation Loss: 11.6570 | Validation Accuracy: 0.6263
Epoch 28 | Train Loss: 11.2065 | Train Accuracy: 0.5371


 24%|██▍       | 29/120 [27:02<1:25:02, 56.07s/it]

Validation Loss: 14.6924 | Validation Accuracy: 0.6642
Epoch 29 | Train Loss: 7.1230 | Train Accuracy: 0.5535


 25%|██▌       | 30/120 [27:57<1:23:25, 55.62s/it]

Validation Loss: 3.1240 | Validation Accuracy: 0.6655
Epoch 30 | Train Loss: 3.5164 | Train Accuracy: 0.5874


 26%|██▌       | 31/120 [28:52<1:22:30, 55.62s/it]

Validation Loss: 4.9438 | Validation Accuracy: 0.6642
Epoch 31 | Train Loss: 3.4609 | Train Accuracy: 0.5925


 27%|██▋       | 32/120 [29:50<1:22:31, 56.27s/it]

Validation Loss: 7.9241 | Validation Accuracy: 0.4156
Epoch 32 | Train Loss: 2.7029 | Train Accuracy: 0.6088


 28%|██▊       | 33/120 [30:46<1:21:15, 56.04s/it]

Validation Loss: 12.2092 | Validation Accuracy: 0.0133
Epoch 33 | Train Loss: 2.9257 | Train Accuracy: 0.6006


 28%|██▊       | 34/120 [31:40<1:19:49, 55.69s/it]

Validation Loss: 6.2106 | Validation Accuracy: 0.6642
Epoch 34 | Train Loss: 1.8995 | Train Accuracy: 0.6295


 29%|██▉       | 35/120 [32:38<1:19:47, 56.32s/it]

Validation Loss: 2.3873 | Validation Accuracy: 0.6280
Epoch 35 | Train Loss: 2.0635 | Train Accuracy: 0.6334


 30%|███       | 36/120 [33:37<1:20:02, 57.17s/it]

Validation Loss: 2.8139 | Validation Accuracy: 0.6642
Epoch 36 | Train Loss: 1.9470 | Train Accuracy: 0.6313


 31%|███       | 37/120 [34:33<1:18:26, 56.70s/it]

Validation Loss: 1.3942 | Validation Accuracy: 0.6070
Epoch 37 | Train Loss: 1.8056 | Train Accuracy: 0.6307


 32%|███▏      | 38/120 [35:29<1:17:03, 56.38s/it]

Validation Loss: 2.8793 | Validation Accuracy: 0.6642
Epoch 38 | Train Loss: 1.8070 | Train Accuracy: 0.6271


 32%|███▎      | 39/120 [36:27<1:16:50, 56.92s/it]

Validation Loss: 3.8466 | Validation Accuracy: 0.6642
Epoch 39 | Train Loss: 1.3964 | Train Accuracy: 0.6321


 33%|███▎      | 40/120 [37:23<1:15:36, 56.70s/it]

Validation Loss: 1.3615 | Validation Accuracy: 0.5942
Epoch 40 | Train Loss: 1.7262 | Train Accuracy: 0.6272


 34%|███▍      | 41/120 [38:17<1:13:40, 55.96s/it]

Validation Loss: 7.1247 | Validation Accuracy: 0.6642
Epoch 41 | Train Loss: 1.7596 | Train Accuracy: 0.6442


 35%|███▌      | 42/120 [39:12<1:12:27, 55.74s/it]

Validation Loss: 1.4279 | Validation Accuracy: 0.6407
Epoch 42 | Train Loss: 1.4412 | Train Accuracy: 0.6327


 36%|███▌      | 43/120 [40:08<1:11:23, 55.63s/it]

Validation Loss: 1.3492 | Validation Accuracy: 0.5949
Epoch 43 | Train Loss: 1.2300 | Train Accuracy: 0.6318


 37%|███▋      | 44/120 [41:02<1:10:03, 55.31s/it]

Validation Loss: 1.3519 | Validation Accuracy: 0.6089
Epoch 44 | Train Loss: 1.7431 | Train Accuracy: 0.6307


 38%|███▊      | 45/120 [41:58<1:09:11, 55.36s/it]

Validation Loss: 1.3101 | Validation Accuracy: 0.6118
Epoch 45 | Train Loss: 1.3916 | Train Accuracy: 0.6407


 38%|███▊      | 46/120 [42:52<1:07:56, 55.08s/it]

Validation Loss: 1.3406 | Validation Accuracy: 0.6262
Epoch 46 | Train Loss: 1.5616 | Train Accuracy: 0.6305


 39%|███▉      | 47/120 [43:47<1:06:42, 54.82s/it]

Validation Loss: 1.3529 | Validation Accuracy: 0.6172
Epoch 47 | Train Loss: 2.1877 | Train Accuracy: 0.6394


 40%|████      | 48/120 [44:41<1:05:47, 54.82s/it]

Validation Loss: 2.8905 | Validation Accuracy: 0.6642
Epoch 48 | Train Loss: 2.5819 | Train Accuracy: 0.6644


 41%|████      | 49/120 [45:39<1:06:02, 55.81s/it]

Validation Loss: 2.3479 | Validation Accuracy: 0.6642
Epoch 49 | Train Loss: 2.1874 | Train Accuracy: 0.6644


 42%|████▏     | 50/120 [46:37<1:05:39, 56.28s/it]

Validation Loss: 2.0554 | Validation Accuracy: 0.6642
Epoch 50 | Train Loss: 1.9409 | Train Accuracy: 0.6644


 42%|████▎     | 51/120 [47:32<1:04:15, 55.88s/it]

Validation Loss: 1.8446 | Validation Accuracy: 0.6642
Epoch 51 | Train Loss: 1.7438 | Train Accuracy: 0.6644


 43%|████▎     | 52/120 [48:26<1:02:50, 55.44s/it]

Validation Loss: 1.6600 | Validation Accuracy: 0.6642
Epoch 52 | Train Loss: 1.5640 | Train Accuracy: 0.6644


 44%|████▍     | 53/120 [49:21<1:01:35, 55.16s/it]

Validation Loss: 1.4886 | Validation Accuracy: 0.6642
Epoch 53 | Train Loss: 1.3998 | Train Accuracy: 0.6644


 45%|████▌     | 54/120 [50:16<1:00:40, 55.17s/it]

Validation Loss: 1.3368 | Validation Accuracy: 0.6642
Epoch 54 | Train Loss: 1.2602 | Train Accuracy: 0.6644


 46%|████▌     | 55/120 [51:14<1:00:41, 56.02s/it]

Validation Loss: 1.2145 | Validation Accuracy: 0.6642
Epoch 55 | Train Loss: 1.1535 | Train Accuracy: 0.6644


 47%|████▋     | 56/120 [52:09<59:26, 55.72s/it]  

Validation Loss: 1.1269 | Validation Accuracy: 0.6642
Epoch 56 | Train Loss: 1.2918 | Train Accuracy: 0.6627


 48%|████▊     | 57/120 [53:04<58:09, 55.39s/it]

Validation Loss: 1.0709 | Validation Accuracy: 0.6642
Epoch 57 | Train Loss: 1.0361 | Train Accuracy: 0.6644


 48%|████▊     | 58/120 [53:57<56:44, 54.92s/it]

Validation Loss: 1.0384 | Validation Accuracy: 0.6642
Epoch 58 | Train Loss: 1.0101 | Train Accuracy: 0.6644


 49%|████▉     | 59/120 [54:56<57:03, 56.12s/it]

Validation Loss: 1.0197 | Validation Accuracy: 0.6642
Epoch 59 | Train Loss: 0.9949 | Train Accuracy: 0.6644


 50%|█████     | 60/120 [55:55<56:50, 56.84s/it]

Validation Loss: 1.0088 | Validation Accuracy: 0.6642
Epoch 60 | Train Loss: 0.9857 | Train Accuracy: 0.6644


 51%|█████     | 61/120 [56:49<55:01, 55.96s/it]

Validation Loss: 1.0023 | Validation Accuracy: 0.6642
Epoch 61 | Train Loss: 0.9798 | Train Accuracy: 0.6644


 52%|█████▏    | 62/120 [57:44<53:51, 55.72s/it]

Validation Loss: 0.9982 | Validation Accuracy: 0.6642
Epoch 62 | Train Loss: 0.9759 | Train Accuracy: 0.6644


 52%|█████▎    | 63/120 [58:42<53:28, 56.30s/it]

Validation Loss: 0.9956 | Validation Accuracy: 0.6642
Epoch 63 | Train Loss: 0.9732 | Train Accuracy: 0.6644


 53%|█████▎    | 64/120 [59:39<52:46, 56.54s/it]

Validation Loss: 0.9937 | Validation Accuracy: 0.6642
Epoch 64 | Train Loss: 0.9712 | Train Accuracy: 0.6644


 54%|█████▍    | 65/120 [1:00:33<51:21, 56.02s/it]

Validation Loss: 0.9925 | Validation Accuracy: 0.6642
Epoch 65 | Train Loss: 0.9697 | Train Accuracy: 0.6644


 55%|█████▌    | 66/120 [1:01:30<50:35, 56.21s/it]

Validation Loss: 0.9916 | Validation Accuracy: 0.6642
Epoch 66 | Train Loss: 0.9686 | Train Accuracy: 0.6644


 56%|█████▌    | 67/120 [1:02:27<49:54, 56.50s/it]

Validation Loss: 0.9910 | Validation Accuracy: 0.6642
Epoch 67 | Train Loss: 0.9677 | Train Accuracy: 0.6644


 57%|█████▋    | 68/120 [1:03:23<48:39, 56.15s/it]

Validation Loss: 0.9905 | Validation Accuracy: 0.6642
Epoch 68 | Train Loss: 0.9670 | Train Accuracy: 0.6644


 57%|█████▊    | 69/120 [1:04:18<47:32, 55.94s/it]

Validation Loss: 0.9902 | Validation Accuracy: 0.6642
Epoch 69 | Train Loss: 0.9664 | Train Accuracy: 0.6644


 58%|█████▊    | 70/120 [1:05:13<46:17, 55.55s/it]

Validation Loss: 0.9900 | Validation Accuracy: 0.6642
Epoch 70 | Train Loss: 0.9659 | Train Accuracy: 0.6644


 59%|█████▉    | 71/120 [1:06:07<45:06, 55.24s/it]

Validation Loss: 0.9900 | Validation Accuracy: 0.6642
Epoch 71 | Train Loss: 0.9655 | Train Accuracy: 0.6644


 60%|██████    | 72/120 [1:07:09<45:40, 57.09s/it]

Validation Loss: 0.9899 | Validation Accuracy: 0.6642
Epoch 72 | Train Loss: 0.9651 | Train Accuracy: 0.6644


 61%|██████    | 73/120 [1:08:08<45:11, 57.70s/it]

Validation Loss: 0.9899 | Validation Accuracy: 0.6642
Epoch 73 | Train Loss: 0.9648 | Train Accuracy: 0.6644


 62%|██████▏   | 74/120 [1:09:03<43:46, 57.11s/it]

Validation Loss: 0.9900 | Validation Accuracy: 0.6642
Epoch 74 | Train Loss: 0.9646 | Train Accuracy: 0.6644


 62%|██████▎   | 75/120 [1:09:58<42:12, 56.28s/it]

Validation Loss: 0.9901 | Validation Accuracy: 0.6642
Epoch 75 | Train Loss: 0.9643 | Train Accuracy: 0.6644


 63%|██████▎   | 76/120 [1:10:52<40:53, 55.77s/it]

Validation Loss: 0.9903 | Validation Accuracy: 0.6642
Epoch 76 | Train Loss: 0.9641 | Train Accuracy: 0.6644


 64%|██████▍   | 77/120 [1:11:48<39:58, 55.78s/it]

Validation Loss: 0.9904 | Validation Accuracy: 0.6642
Epoch 77 | Train Loss: 0.9640 | Train Accuracy: 0.6644


 65%|██████▌   | 78/120 [1:12:56<41:37, 59.47s/it]

Validation Loss: 0.9906 | Validation Accuracy: 0.6642
Epoch 78 | Train Loss: 0.9638 | Train Accuracy: 0.6644


 66%|██████▌   | 79/120 [1:14:13<44:06, 64.56s/it]

Validation Loss: 0.9908 | Validation Accuracy: 0.6642
Epoch 79 | Train Loss: 0.9637 | Train Accuracy: 0.6644


 67%|██████▋   | 80/120 [1:15:07<40:55, 61.38s/it]

Validation Loss: 0.9910 | Validation Accuracy: 0.6642
Epoch 80 | Train Loss: 0.9636 | Train Accuracy: 0.6644


 68%|██████▊   | 81/120 [1:16:00<38:22, 59.05s/it]

Validation Loss: 0.9912 | Validation Accuracy: 0.6642
Epoch 81 | Train Loss: 0.9635 | Train Accuracy: 0.6644


 68%|██████▊   | 82/120 [1:16:55<36:36, 57.79s/it]

Validation Loss: 0.9915 | Validation Accuracy: 0.6642
Epoch 82 | Train Loss: 0.9634 | Train Accuracy: 0.6644


 69%|██████▉   | 83/120 [1:17:50<35:10, 57.05s/it]

Validation Loss: 0.9917 | Validation Accuracy: 0.6642
Epoch 83 | Train Loss: 0.9633 | Train Accuracy: 0.6644


 70%|███████   | 84/120 [1:18:46<33:56, 56.58s/it]

Validation Loss: 0.9919 | Validation Accuracy: 0.6642
Epoch 84 | Train Loss: 0.9632 | Train Accuracy: 0.6644


 71%|███████   | 85/120 [1:19:46<33:31, 57.48s/it]

Validation Loss: 0.9922 | Validation Accuracy: 0.6642
Epoch 85 | Train Loss: 0.9632 | Train Accuracy: 0.6644


 72%|███████▏  | 86/120 [1:20:41<32:09, 56.75s/it]

Validation Loss: 0.9925 | Validation Accuracy: 0.6642
Epoch 86 | Train Loss: 0.9631 | Train Accuracy: 0.6644


 72%|███████▎  | 87/120 [1:21:35<30:50, 56.09s/it]

Validation Loss: 0.9927 | Validation Accuracy: 0.6642
Epoch 87 | Train Loss: 0.9631 | Train Accuracy: 0.6644


 73%|███████▎  | 88/120 [1:22:29<29:35, 55.47s/it]

Validation Loss: 0.9930 | Validation Accuracy: 0.6642
Epoch 88 | Train Loss: 0.9630 | Train Accuracy: 0.6644


 74%|███████▍  | 89/120 [1:23:26<28:52, 55.90s/it]

Validation Loss: 0.9932 | Validation Accuracy: 0.6642
Epoch 89 | Train Loss: 0.9630 | Train Accuracy: 0.6644


 75%|███████▌  | 90/120 [1:24:21<27:46, 55.55s/it]

Validation Loss: 0.9934 | Validation Accuracy: 0.6642
Epoch 90 | Train Loss: 0.9629 | Train Accuracy: 0.6644


 76%|███████▌  | 91/120 [1:25:16<26:51, 55.58s/it]

Validation Loss: 0.9937 | Validation Accuracy: 0.6642
Epoch 91 | Train Loss: 0.9629 | Train Accuracy: 0.6644


 77%|███████▋  | 92/120 [1:26:12<25:55, 55.57s/it]

Validation Loss: 0.9939 | Validation Accuracy: 0.6642
Epoch 92 | Train Loss: 0.9629 | Train Accuracy: 0.6644


 78%|███████▊  | 93/120 [1:27:07<24:55, 55.37s/it]

Validation Loss: 0.9941 | Validation Accuracy: 0.6642
Epoch 93 | Train Loss: 0.9628 | Train Accuracy: 0.6644


 78%|███████▊  | 94/120 [1:28:02<23:54, 55.19s/it]

Validation Loss: 0.9943 | Validation Accuracy: 0.6642
Epoch 94 | Train Loss: 0.9628 | Train Accuracy: 0.6644


 79%|███████▉  | 95/120 [1:28:57<23:04, 55.38s/it]

Validation Loss: 0.9944 | Validation Accuracy: 0.6642
Epoch 95 | Train Loss: 0.9628 | Train Accuracy: 0.6644


 80%|████████  | 96/120 [1:29:53<22:10, 55.43s/it]

Validation Loss: 0.9946 | Validation Accuracy: 0.6642
Epoch 96 | Train Loss: 0.9628 | Train Accuracy: 0.6644


 81%|████████  | 97/120 [1:30:49<21:16, 55.50s/it]

Validation Loss: 0.9948 | Validation Accuracy: 0.6642
Epoch 97 | Train Loss: 0.9627 | Train Accuracy: 0.6644


 82%|████████▏ | 98/120 [1:31:46<20:32, 56.00s/it]

Validation Loss: 0.9949 | Validation Accuracy: 0.6642
Epoch 98 | Train Loss: 0.9627 | Train Accuracy: 0.6644


 82%|████████▎ | 99/120 [1:32:42<19:36, 56.02s/it]

Validation Loss: 0.9950 | Validation Accuracy: 0.6642
Epoch 99 | Train Loss: 0.9627 | Train Accuracy: 0.6644


 83%|████████▎ | 100/120 [1:33:42<19:02, 57.11s/it]

Validation Loss: 0.9951 | Validation Accuracy: 0.6642
Epoch 100 | Train Loss: 0.9627 | Train Accuracy: 0.6644


 84%|████████▍ | 101/120 [1:34:37<17:54, 56.55s/it]

Validation Loss: 0.9952 | Validation Accuracy: 0.6642
Epoch 101 | Train Loss: 0.9627 | Train Accuracy: 0.6644


 85%|████████▌ | 102/120 [1:35:32<16:51, 56.18s/it]

Validation Loss: 0.9953 | Validation Accuracy: 0.6642
Epoch 102 | Train Loss: 0.9627 | Train Accuracy: 0.6644


 86%|████████▌ | 103/120 [1:36:27<15:49, 55.83s/it]

Validation Loss: 0.9954 | Validation Accuracy: 0.6642
Epoch 103 | Train Loss: 0.9626 | Train Accuracy: 0.6644


 87%|████████▋ | 104/120 [1:37:22<14:49, 55.62s/it]

Validation Loss: 0.9955 | Validation Accuracy: 0.6642
Epoch 104 | Train Loss: 0.9626 | Train Accuracy: 0.6644


 88%|████████▊ | 105/120 [1:38:17<13:50, 55.39s/it]

Validation Loss: 0.9956 | Validation Accuracy: 0.6642
Epoch 105 | Train Loss: 0.9626 | Train Accuracy: 0.6644


 88%|████████▊ | 106/120 [1:39:12<12:53, 55.28s/it]

Validation Loss: 0.9957 | Validation Accuracy: 0.6642
Epoch 106 | Train Loss: 0.9626 | Train Accuracy: 0.6644


 89%|████████▉ | 107/120 [1:40:06<11:53, 54.88s/it]

Validation Loss: 0.9958 | Validation Accuracy: 0.6642
Epoch 107 | Train Loss: 0.9626 | Train Accuracy: 0.6644


 90%|█████████ | 108/120 [1:41:02<11:01, 55.10s/it]

Validation Loss: 0.9958 | Validation Accuracy: 0.6642
Epoch 108 | Train Loss: 0.9626 | Train Accuracy: 0.6644


 91%|█████████ | 109/120 [1:41:56<10:04, 54.92s/it]

Validation Loss: 0.9959 | Validation Accuracy: 0.6642
Epoch 109 | Train Loss: 0.9626 | Train Accuracy: 0.6644


 92%|█████████▏| 110/120 [1:42:51<09:07, 54.76s/it]

Validation Loss: 0.9960 | Validation Accuracy: 0.6642
Epoch 110 | Train Loss: 0.9626 | Train Accuracy: 0.6644


 92%|█████████▎| 111/120 [1:43:46<08:13, 54.83s/it]

Validation Loss: 0.9960 | Validation Accuracy: 0.6642
Epoch 111 | Train Loss: 0.9626 | Train Accuracy: 0.6644


 93%|█████████▎| 112/120 [1:44:40<07:17, 54.63s/it]

Validation Loss: 0.9961 | Validation Accuracy: 0.6642
Epoch 112 | Train Loss: 0.9626 | Train Accuracy: 0.6644


 94%|█████████▍| 113/120 [1:45:35<06:23, 54.76s/it]

Validation Loss: 0.9961 | Validation Accuracy: 0.6642
Epoch 113 | Train Loss: 0.9625 | Train Accuracy: 0.6644


 95%|█████████▌| 114/120 [1:46:30<05:29, 54.85s/it]

Validation Loss: 0.9962 | Validation Accuracy: 0.6642
Epoch 114 | Train Loss: 0.9625 | Train Accuracy: 0.6644


 96%|█████████▌| 115/120 [1:47:25<04:34, 54.90s/it]

Validation Loss: 0.9963 | Validation Accuracy: 0.6642
Epoch 115 | Train Loss: 0.9625 | Train Accuracy: 0.6644


 97%|█████████▋| 116/120 [1:48:21<03:40, 55.12s/it]

Validation Loss: 0.9963 | Validation Accuracy: 0.6642
Epoch 116 | Train Loss: 0.9625 | Train Accuracy: 0.6644


 98%|█████████▊| 117/120 [1:49:16<02:45, 55.27s/it]

Validation Loss: 0.9964 | Validation Accuracy: 0.6642
Epoch 117 | Train Loss: 0.9625 | Train Accuracy: 0.6644


 98%|█████████▊| 118/120 [1:50:11<01:50, 55.21s/it]

Validation Loss: 0.9964 | Validation Accuracy: 0.6642
Epoch 118 | Train Loss: 0.9625 | Train Accuracy: 0.6644


 99%|█████████▉| 119/120 [1:51:06<00:55, 55.13s/it]

Validation Loss: 0.9964 | Validation Accuracy: 0.6642
Epoch 119 | Train Loss: 0.9625 | Train Accuracy: 0.6644


100%|██████████| 120/120 [1:52:03<00:00, 56.03s/it]

Validation Loss: 0.9965 | Validation Accuracy: 0.6642





Test Accuracy: 66 %


In [None]:
# 1. Define the graph neural network model
class GCN(nn.Module):
    """Two-layer graph convolutional network with mean-node readout.

    Layer widths are ``in_feats -> hidden_size -> num_classes`` with a ReLU
    between the two layers.

    Args:
        in_feats: Size of each node's input feature vector.
        hidden_size: Output width of the first layer.
        num_classes: Width of the final layer, i.e. of each graph's output.
    """

    def __init__(self, in_feats, hidden_size, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, hidden_size)
        self.conv2 = GraphConv(hidden_size, num_classes)

    def forward(self, g, inputs):
        """Compute one output vector per graph in ``g``.

        Args:
            g: A (possibly batched) DGL graph.
            inputs: Node feature tensor passed to the first layer.

        Returns:
            Tensor obtained by averaging the last layer's node outputs with
            ``dgl.mean_nodes``.
        """
        # Match the layer's parameter dtype so integer-typed node features do
        # not hit a dtype mismatch. This is a no-op when the dtypes agree.
        h = inputs.to(self.conv1.weight.dtype)
        h = torch.relu(self.conv1(g, h))
        h = self.conv2(g, h)

        # Average the node features of every graph. local_scope() keeps the
        # temporary 'h' field from leaking into the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            return dgl.mean_nodes(g, 'h')

# 2. Define the dataset
class GraphDataset(Dataset):
    """Map-style dataset that builds one DGL graph per record on demand.

    Each record is read for the keys ``edge_index``, ``num_nodes``,
    ``node_feat`` and ``y``. No device placement is done here.
    """

    def __init__(self, data_list):
        # Only the raw records are stored; graphs are created in __getitem__.
        self.data_list = data_list

    def __len__(self):
        """Return the number of records (one graph per record)."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Build and return ``(graph, label)`` for the record at ``idx``."""
        record = self.data_list[idx]
        endpoints = (record["edge_index"][0], record["edge_index"][1])
        graph = dgl.graph(endpoints, num_nodes=record["num_nodes"])
        # Give every node an edge to itself.
        graph = dgl.add_self_loop(graph)
        graph.ndata['feat'] = torch.tensor(record["node_feat"])
        # Edge features are currently unused:
        # graph.edata['feat'] = torch.tensor(record["edge_attr"])
        label = torch.tensor(record["y"])
        return graph, label


def collate(samples):
    """Turn a list of ``(graph, label)`` pairs into one batched graph and a label tensor.

    The graphs are combined with ``dgl.batch``; the labels are packed, in the
    same order, into a freshly created tensor.
    """
    # zip(*samples) transposes the pairs: first all graphs, then all labels.
    graph_list, label_list = [list(group) for group in zip(*samples)]
    merged_graph = dgl.batch(graph_list)
    label_tensor = torch.tensor(label_list)
    return merged_graph, label_tensor


# 3. Load the data and build the data loaders
def _load_jsonl(path):
    """Read a JSON-lines file into a list of parsed records, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


train_data_list = _load_jsonl("../data/final_small_version/remaining_train.jsonl")
train_dataset = GraphDataset(train_data_list)

val_data_list = _load_jsonl("../data/final_small_version/remaining_valid.jsonl")
val_dataset = GraphDataset(val_data_list)

test_data_list = _load_jsonl("../data/final_small_version/remaining_test.jsonl")
test_dataset = GraphDataset(test_data_list)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True, collate_fn=collate)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True, collate_fn=collate)



# 4. Create the model and train it
model = GCN(1, 16, 168) # 1 = input feature dimension, 16 = hidden size, 168 = number of classes
# model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# torch.save does not create missing parent directories.
os.makedirs("./checkpoint", exist_ok=True)

for epoch in tqdm(range(100)):
    model.train()
    for batched_g, labels in train_dataloader:
        # batched_g, labels = batched_g.to(device), labels.to(device)
        logits = model(batched_g, batched_g.ndata['feat'])
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Reports the loss of the last mini-batch of the epoch.
    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))

    # Save a checkpoint; 'loss' is stored as a plain float rather than a
    # tensor that is still attached to the autograd graph.
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            }, f"./checkpoint/checkpoint_{epoch}.pt")

# 5. Validate the model
model.eval()
total = 0
correct = 0
with torch.no_grad():
    for batched_g, labels in val_dataloader:
        logits = model(batched_g, batched_g.ndata['feat'])
        predicted = logits.argmax(1)
        # Accumulate over all batches so a single accuracy figure is reported
        # for the whole validation set instead of one line per mini-batch.
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy: %d %%' % (100 * correct / total))