# Test of GAT
- use DGL

In [47]:
import dgl
import json
import torch
import torch as th
from tqdm import tqdm
import torch.nn as nn
from dgl.nn import GraphConv, GATConv
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup

- check the GPU and assign the GPU by the best memory usage

In [48]:
import subprocess
import torch

def get_free_gpu():
    try:
        # Run nvidia-smi command to get GPU details
        _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
        command = "nvidia-smi --query-gpu=memory.free --format=csv,nounits,noheader"
        memory_free_info = _output_to_list(subprocess.check_output(command.split())) 
        memory_free_values = [int(x) for i, x in enumerate(memory_free_info)]
        
        # Get the GPU with the maximum free memory
        best_gpu_id = memory_free_values.index(max(memory_free_values))
        return best_gpu_id
    except:
        # If any exception occurs, default to GPU 0 (this handles cases where nvidia-smi isn't installed)
        return 0

if torch.cuda.is_available():
    # Get the best GPU ID based on free memory and set it
    best_gpu_id = get_free_gpu()
    device = torch.device(f"cuda:{best_gpu_id}")
else:
    device = torch.device("cpu")

print(device)


cuda:3


## Data Loader

In [76]:
datasets = ['train', 'valid', 'test']
dataloaders = {}

for dataset_name in tqdm(datasets):
    file_path = f"../data/training_data/repeated_{dataset_name}.jsonl"
    
    print(file_path)
    with open(file_path) as f:
        data_list = [json.loads(line) for line in f]
    
    dataset = GraphDataset(data_list, device)
    dataloaders[dataset_name] = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate)
    
print("Done!")

  0%|          | 0/3 [00:00<?, ?it/s]

../data/training_data/repeated_train.jsonl


 33%|███▎      | 1/3 [00:12<00:24, 12.12s/it]

../data/training_data/repeated_valid.jsonl


 67%|██████▋   | 2/3 [00:16<00:07,  7.42s/it]

../data/training_data/repeated_test.jsonl


100%|██████████| 3/3 [00:21<00:00,  7.04s/it]

Done!





In [77]:
class GraphDataset(Dataset):
    def __init__(self, data_list, device):
        self.data_list = data_list
        self.device = device

    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        data = self.data_list[idx]

        g = dgl.graph((th.tensor(data["edge_index"][0]), th.tensor(data["edge_index"][1])), num_nodes=data["num_nodes"]).to(self.device)
#         g = dgl.add_self_loop(g) 

        g.ndata['feat'] = th.tensor(data["node_feat"]).to(self.device)
        g.edata['feat'] = th.tensor(data["edge_attr"]).to(self.device)  # Add edge features to graph
                
#         edge_attrs = th.tensor(data["edge_attr"]).to(self.device)
#         self_embedding = [[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,]]
#         self_loop_attrs = th.tensor(self_embedding).repeat(g.number_of_nodes()).to(self.device)
#         g.edata['feat'] = th.cat([edge_attrs, self_loop_attrs], dim=0)

        return g, th.tensor(data["label"]).to(self.device)


def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)


### Model

In [94]:
class GAT(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout_prob=0.2):
        super(GAT, self).__init__()
        
        # do not check the zero in_degree since we have all the complete graph
        self.layer1 = GATConv(in_dim, hidden_dim, num_heads=num_heads, activation=F.relu, allow_zero_in_degree=True)
        self.layer2 = GATConv(hidden_dim * num_heads, out_dim, num_heads=num_heads, allow_zero_in_degree=True)
        
        # Adding Batch Normalization after each GAT layer
        self.batchnorm1 = nn.BatchNorm1d(hidden_dim * num_heads)
        self.batchnorm2 = nn.BatchNorm1d(out_dim)
        
        # Adding Dropout for regularization
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, g, h):
        # Apply GAT layers
        h = self.layer1(g, h)
        h = h.view(h.shape[0], -1)
        
#         h = self.batchnorm1(h)
        h = F.relu(h)
#         h = self.dropout(h)
        h = self.layer2(g, h).squeeze(1)
        
        # Apply Batch Normalization after second GAT layer
#         h = self.batchnorm2(h)
        
        # Store the output as a new node feature
        g.ndata['h_out'] = h

        # Use mean pooling to aggregate this new node feature
        h_agg = dgl.mean_nodes(g, feat='h_out')
        return h_agg

    

- Model Forward  

In [95]:
def model_fn(data, model, criterion, device, count=1):
    """Forward a batch through the model."""
    batched_g, labels = data
#     print(batch_g)
    batched_g = batched_g.to(device)
    
    labels = labels.to(device)
    logits = model(batched_g, batched_g.ndata['feat'].float()) # for GAT
    logits = logits.mean(dim=1)
#     print(logits)
    
    loss = criterion(logits, labels)
#     print(batched_g.ndata['feat'].dtype)
#     print("Logits shape:", logits.shape)  # Expected: (batch_size, 168)
#     print("Labels shape:", labels.shape)  # Expected: (batch_size)

    # Get the class id with the highest probability.
    preds = logits.argmax(1)
#     if count%100==0: print(f"This is {count}-th prediction metrix: {preds}")
#     print(preds)
    
    # Compute accuracy.
    accuracy = torch.mean((preds == labels).float())

    return loss, accuracy

### Training

In [97]:
model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(1 dim, since the design of our dataset), if a node has multiple feature -> in_dim > 1
# out_dim means the # of the categories -> 168 for out tasks

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 100


for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in dataloaders['train']:
#         print(data)
        count += 1
#         if count%300==0: print(data)
        
#         print(data[0])
        loss, accuracy = model_fn(data, model, criterion, device, count)
#         print('hi')
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
    scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validate
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# After all epochs
model.eval()
total = 0
correct = 0
with torch.no_grad():
#     for batched_g, labels in test_dataloader:  
    for batched_g, labels in dataloaders['test']:
        batched_g = batched_g.to(device)
        labels = labels.to(device)
        logits = model(batched_g, batched_g.ndata['feat'])
        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/100 [00:00<?, ?it/s]

total count: 750
Epoch 0 | Train Loss: 5.1239 | Train Accuracy: 0.1797


  1%|          | 1/100 [00:30<51:07, 30.99s/it]

Validation Loss: 5.1239 | Validation Accuracy: 0.1815
total count: 750
Epoch 1 | Train Loss: 5.1237 | Train Accuracy: 0.1797


  2%|▏         | 2/100 [01:01<50:33, 30.95s/it]

Validation Loss: 5.1234 | Validation Accuracy: 0.1815
total count: 750
Epoch 2 | Train Loss: 5.1228 | Train Accuracy: 0.1797


  3%|▎         | 3/100 [01:32<49:27, 30.59s/it]

Validation Loss: 5.1221 | Validation Accuracy: 0.1815
total count: 750
Epoch 3 | Train Loss: 5.1210 | Train Accuracy: 0.1754


  4%|▍         | 4/100 [02:02<48:40, 30.42s/it]

Validation Loss: 5.1196 | Validation Accuracy: 0.1764
total count: 750
Epoch 4 | Train Loss: 5.1173 | Train Accuracy: 0.1749


  5%|▌         | 5/100 [02:33<48:33, 30.66s/it]

Validation Loss: 5.1148 | Validation Accuracy: 0.1764
total count: 750
Epoch 5 | Train Loss: 5.1108 | Train Accuracy: 0.1749


  6%|▌         | 6/100 [03:02<47:29, 30.31s/it]

Validation Loss: 5.1065 | Validation Accuracy: 0.1764
total count: 750
Epoch 6 | Train Loss: 5.1003 | Train Accuracy: 0.1749


  7%|▋         | 7/100 [03:33<47:12, 30.46s/it]

Validation Loss: 5.0935 | Validation Accuracy: 0.1764
total count: 750
Epoch 7 | Train Loss: 5.0842 | Train Accuracy: 0.1749


  8%|▊         | 8/100 [04:03<46:34, 30.37s/it]

Validation Loss: 5.0743 | Validation Accuracy: 0.1764
total count: 750
Epoch 8 | Train Loss: 5.0610 | Train Accuracy: 0.1749


  9%|▉         | 9/100 [04:34<46:02, 30.35s/it]

Validation Loss: 5.0471 | Validation Accuracy: 0.1764
total count: 750
Epoch 9 | Train Loss: 5.0291 | Train Accuracy: 0.1749


 10%|█         | 10/100 [05:05<45:44, 30.50s/it]

Validation Loss: 5.0105 | Validation Accuracy: 0.1764
total count: 750
Epoch 10 | Train Loss: 4.9874 | Train Accuracy: 0.1749


 11%|█         | 11/100 [05:35<45:24, 30.61s/it]

Validation Loss: 4.9638 | Validation Accuracy: 0.1764
total count: 750
Epoch 11 | Train Loss: 4.9353 | Train Accuracy: 0.1749


 12%|█▏        | 12/100 [06:05<44:31, 30.36s/it]

Validation Loss: 4.9067 | Validation Accuracy: 0.1764
total count: 750
Epoch 12 | Train Loss: 4.8734 | Train Accuracy: 0.1749


 13%|█▎        | 13/100 [06:35<43:52, 30.26s/it]

Validation Loss: 4.8404 | Validation Accuracy: 0.1764
total count: 750
Epoch 13 | Train Loss: 4.8036 | Train Accuracy: 0.1749


 14%|█▍        | 14/100 [07:07<43:53, 30.63s/it]

Validation Loss: 4.7679 | Validation Accuracy: 0.1764
total count: 750
Epoch 14 | Train Loss: 4.7311 | Train Accuracy: 0.1749


 15%|█▌        | 15/100 [07:37<43:04, 30.41s/it]

Validation Loss: 4.6964 | Validation Accuracy: 0.1764
total count: 750
Epoch 15 | Train Loss: 4.6650 | Train Accuracy: 0.1749


 16%|█▌        | 16/100 [08:07<42:27, 30.33s/it]

Validation Loss: 4.6366 | Validation Accuracy: 0.1764
total count: 750
Epoch 16 | Train Loss: 4.6156 | Train Accuracy: 0.1749


 17%|█▋        | 17/100 [08:37<41:55, 30.31s/it]

Validation Loss: 4.5967 | Validation Accuracy: 0.1764
total count: 750
Epoch 17 | Train Loss: 4.5850 | Train Accuracy: 0.1749


 18%|█▊        | 18/100 [09:08<41:36, 30.44s/it]

Validation Loss: 4.5724 | Validation Accuracy: 0.1764
total count: 750
Epoch 18 | Train Loss: 4.5643 | Train Accuracy: 0.1749


 19%|█▉        | 19/100 [09:39<41:14, 30.54s/it]

Validation Loss: 4.5531 | Validation Accuracy: 0.1764
total count: 750
Epoch 19 | Train Loss: 4.5455 | Train Accuracy: 0.1749


 20%|██        | 20/100 [10:09<40:29, 30.37s/it]

Validation Loss: 4.5339 | Validation Accuracy: 0.1764
total count: 750
Epoch 20 | Train Loss: 4.5259 | Train Accuracy: 0.1749


 21%|██        | 21/100 [10:37<39:16, 29.83s/it]

Validation Loss: 4.5136 | Validation Accuracy: 0.1764
total count: 750
Epoch 21 | Train Loss: 4.5050 | Train Accuracy: 0.1749


 22%|██▏       | 22/100 [11:07<38:52, 29.90s/it]

Validation Loss: 4.4918 | Validation Accuracy: 0.1764
total count: 750
Epoch 22 | Train Loss: 4.4827 | Train Accuracy: 0.1749


 23%|██▎       | 23/100 [11:40<39:36, 30.86s/it]

Validation Loss: 4.4685 | Validation Accuracy: 0.1764
total count: 750
Epoch 23 | Train Loss: 4.4589 | Train Accuracy: 0.1749


 24%|██▍       | 24/100 [12:12<39:34, 31.24s/it]

Validation Loss: 4.4439 | Validation Accuracy: 0.1764
total count: 750
Epoch 24 | Train Loss: 4.4337 | Train Accuracy: 0.1762


 25%|██▌       | 25/100 [12:40<37:52, 30.30s/it]

Validation Loss: 4.4180 | Validation Accuracy: 0.1812
total count: 750
Epoch 25 | Train Loss: 4.4074 | Train Accuracy: 0.1803


 26%|██▌       | 26/100 [13:10<37:13, 30.18s/it]

Validation Loss: 4.3910 | Validation Accuracy: 0.1812
total count: 750
Epoch 26 | Train Loss: 4.3800 | Train Accuracy: 0.1803


 27%|██▋       | 27/100 [13:44<37:51, 31.11s/it]

Validation Loss: 4.3632 | Validation Accuracy: 0.1812
total count: 750
Epoch 27 | Train Loss: 4.3518 | Train Accuracy: 0.1803


 28%|██▊       | 28/100 [14:14<37:11, 30.99s/it]

Validation Loss: 4.3346 | Validation Accuracy: 0.1812
total count: 750
Epoch 28 | Train Loss: 4.3230 | Train Accuracy: 0.1843


 29%|██▉       | 29/100 [14:43<35:57, 30.39s/it]

Validation Loss: 4.3054 | Validation Accuracy: 0.1859
total count: 750
Epoch 29 | Train Loss: 4.2936 | Train Accuracy: 0.1881


 30%|███       | 30/100 [15:12<34:58, 29.98s/it]

Validation Loss: 4.2759 | Validation Accuracy: 0.1910
total count: 750
Epoch 30 | Train Loss: 4.2639 | Train Accuracy: 0.1931


 31%|███       | 31/100 [15:45<35:23, 30.78s/it]

Validation Loss: 4.2462 | Validation Accuracy: 0.1961
total count: 750
Epoch 31 | Train Loss: 4.2341 | Train Accuracy: 0.1997


 32%|███▏      | 32/100 [16:17<35:17, 31.14s/it]

Validation Loss: 4.2164 | Validation Accuracy: 0.2019
total count: 750
Epoch 32 | Train Loss: 4.2042 | Train Accuracy: 0.2001


 33%|███▎      | 33/100 [16:49<35:03, 31.39s/it]

Validation Loss: 4.1867 | Validation Accuracy: 0.2019
total count: 750
Epoch 33 | Train Loss: 4.1744 | Train Accuracy: 0.2001


 34%|███▍      | 34/100 [17:18<33:53, 30.81s/it]

Validation Loss: 4.1571 | Validation Accuracy: 0.2019
total count: 750
Epoch 34 | Train Loss: 4.1447 | Train Accuracy: 0.2036


 35%|███▌      | 35/100 [17:49<33:09, 30.61s/it]

Validation Loss: 4.1277 | Validation Accuracy: 0.2069
total count: 750
Epoch 35 | Train Loss: 4.1152 | Train Accuracy: 0.2065


 36%|███▌      | 36/100 [18:21<33:16, 31.20s/it]

Validation Loss: 4.0985 | Validation Accuracy: 0.2114
total count: 750
Epoch 36 | Train Loss: 4.0860 | Train Accuracy: 0.2127


 37%|███▋      | 37/100 [18:54<33:08, 31.57s/it]

Validation Loss: 4.0696 | Validation Accuracy: 0.2162
total count: 750
Epoch 37 | Train Loss: 4.0571 | Train Accuracy: 0.2223


 38%|███▊      | 38/100 [19:26<32:56, 31.88s/it]

Validation Loss: 4.0410 | Validation Accuracy: 0.2220
total count: 750
Epoch 38 | Train Loss: 4.0287 | Train Accuracy: 0.2281


 39%|███▉      | 39/100 [19:59<32:43, 32.19s/it]

Validation Loss: 4.0128 | Validation Accuracy: 0.2264
total count: 750
Epoch 39 | Train Loss: 4.0006 | Train Accuracy: 0.2322


 40%|████      | 40/100 [20:30<31:47, 31.79s/it]

Validation Loss: 3.9850 | Validation Accuracy: 0.2367
total count: 750
Epoch 40 | Train Loss: 3.9730 | Train Accuracy: 0.2362


 41%|████      | 41/100 [21:00<30:50, 31.37s/it]

Validation Loss: 3.9576 | Validation Accuracy: 0.2367
total count: 750
Epoch 41 | Train Loss: 3.9458 | Train Accuracy: 0.2379


 42%|████▏     | 42/100 [21:30<29:44, 30.76s/it]

Validation Loss: 3.9306 | Validation Accuracy: 0.2367
total count: 750
Epoch 42 | Train Loss: 3.9192 | Train Accuracy: 0.2417


 43%|████▎     | 43/100 [21:59<28:43, 30.23s/it]

Validation Loss: 3.9041 | Validation Accuracy: 0.2466
total count: 750
Epoch 43 | Train Loss: 3.8931 | Train Accuracy: 0.2456


 44%|████▍     | 44/100 [22:28<28:02, 30.05s/it]

Validation Loss: 3.8780 | Validation Accuracy: 0.2516
total count: 750
Epoch 44 | Train Loss: 3.8675 | Train Accuracy: 0.2532


 45%|████▌     | 45/100 [22:57<27:06, 29.58s/it]

Validation Loss: 3.8524 | Validation Accuracy: 0.2667
total count: 750
Epoch 45 | Train Loss: 3.8424 | Train Accuracy: 0.2599


 46%|████▌     | 46/100 [23:27<26:53, 29.88s/it]

Validation Loss: 3.8273 | Validation Accuracy: 0.2716
total count: 750
Epoch 46 | Train Loss: 3.8178 | Train Accuracy: 0.2642


 47%|████▋     | 47/100 [23:59<26:48, 30.35s/it]

Validation Loss: 3.8027 | Validation Accuracy: 0.2716
total count: 750
Epoch 47 | Train Loss: 3.7938 | Train Accuracy: 0.2695


 48%|████▊     | 48/100 [24:31<26:50, 30.98s/it]

Validation Loss: 3.7786 | Validation Accuracy: 0.2716
total count: 750
Epoch 48 | Train Loss: 3.7703 | Train Accuracy: 0.2757


 49%|████▉     | 49/100 [25:03<26:38, 31.34s/it]

Validation Loss: 3.7550 | Validation Accuracy: 0.2874
total count: 750
Epoch 49 | Train Loss: 3.7473 | Train Accuracy: 0.2866


 50%|█████     | 50/100 [25:37<26:32, 31.86s/it]

Validation Loss: 3.7318 | Validation Accuracy: 0.2983
total count: 750
Epoch 50 | Train Loss: 3.7247 | Train Accuracy: 0.3007


 51%|█████     | 51/100 [26:06<25:21, 31.06s/it]

Validation Loss: 3.7091 | Validation Accuracy: 0.3240
total count: 750
Epoch 51 | Train Loss: 3.7025 | Train Accuracy: 0.3138


 52%|█████▏    | 52/100 [26:38<25:10, 31.46s/it]

Validation Loss: 3.6867 | Validation Accuracy: 0.3240
total count: 750
Epoch 52 | Train Loss: 3.6806 | Train Accuracy: 0.3188


 53%|█████▎    | 53/100 [27:09<24:27, 31.21s/it]

Validation Loss: 3.6647 | Validation Accuracy: 0.3275
total count: 750
Epoch 53 | Train Loss: 3.6590 | Train Accuracy: 0.3297


 54%|█████▍    | 54/100 [27:38<23:35, 30.77s/it]

Validation Loss: 3.6430 | Validation Accuracy: 0.3459
total count: 750
Epoch 54 | Train Loss: 3.6377 | Train Accuracy: 0.3410


 55%|█████▌    | 55/100 [28:11<23:23, 31.20s/it]

Validation Loss: 3.6215 | Validation Accuracy: 0.3520
total count: 750
Epoch 55 | Train Loss: 3.6165 | Train Accuracy: 0.3518


 56%|█████▌    | 56/100 [28:42<22:50, 31.15s/it]

Validation Loss: 3.6002 | Validation Accuracy: 0.3611
total count: 750
Epoch 56 | Train Loss: 3.5955 | Train Accuracy: 0.3635


 57%|█████▋    | 57/100 [29:13<22:14, 31.05s/it]

Validation Loss: 3.5791 | Validation Accuracy: 0.3668
total count: 750
Epoch 57 | Train Loss: 3.5746 | Train Accuracy: 0.3738


 58%|█████▊    | 58/100 [29:44<21:45, 31.09s/it]

Validation Loss: 3.5581 | Validation Accuracy: 0.3757
total count: 750
Epoch 58 | Train Loss: 3.5538 | Train Accuracy: 0.3805


 59%|█████▉    | 59/100 [30:14<21:05, 30.87s/it]

Validation Loss: 3.5373 | Validation Accuracy: 0.3757
total count: 750
Epoch 59 | Train Loss: 3.5330 | Train Accuracy: 0.3874


 60%|██████    | 60/100 [30:42<19:58, 29.95s/it]

Validation Loss: 3.5165 | Validation Accuracy: 0.3807
total count: 750
Epoch 60 | Train Loss: 3.5123 | Train Accuracy: 0.3948


 61%|██████    | 61/100 [31:12<19:34, 30.11s/it]

Validation Loss: 3.4957 | Validation Accuracy: 0.3944
total count: 750
Epoch 61 | Train Loss: 3.4916 | Train Accuracy: 0.4023


 62%|██████▏   | 62/100 [31:42<18:57, 29.94s/it]

Validation Loss: 3.4750 | Validation Accuracy: 0.4080
total count: 750
Epoch 62 | Train Loss: 3.4709 | Train Accuracy: 0.4111


 63%|██████▎   | 63/100 [32:14<18:49, 30.52s/it]

Validation Loss: 3.4543 | Validation Accuracy: 0.4230
total count: 750
Epoch 63 | Train Loss: 3.4501 | Train Accuracy: 0.4195


 64%|██████▍   | 64/100 [32:45<18:26, 30.74s/it]

Validation Loss: 3.4336 | Validation Accuracy: 0.4230
total count: 750
Epoch 64 | Train Loss: 3.4294 | Train Accuracy: 0.4246


 65%|██████▌   | 65/100 [33:13<17:25, 29.87s/it]

Validation Loss: 3.4129 | Validation Accuracy: 0.4230
total count: 750
Epoch 65 | Train Loss: 3.4086 | Train Accuracy: 0.4290


 66%|██████▌   | 66/100 [33:42<16:45, 29.57s/it]

Validation Loss: 3.3922 | Validation Accuracy: 0.4230
total count: 750
Epoch 66 | Train Loss: 3.3877 | Train Accuracy: 0.4328


 67%|██████▋   | 67/100 [34:13<16:31, 30.06s/it]

Validation Loss: 3.3715 | Validation Accuracy: 0.4329
total count: 750
Epoch 67 | Train Loss: 3.3669 | Train Accuracy: 0.4381


 68%|██████▊   | 68/100 [34:44<16:16, 30.50s/it]

Validation Loss: 3.3509 | Validation Accuracy: 0.4329
total count: 750
Epoch 68 | Train Loss: 3.3460 | Train Accuracy: 0.4430


 69%|██████▉   | 69/100 [35:16<15:51, 30.68s/it]

Validation Loss: 3.3302 | Validation Accuracy: 0.4329
total count: 750
Epoch 69 | Train Loss: 3.3252 | Train Accuracy: 0.4476


 70%|███████   | 70/100 [35:48<15:40, 31.35s/it]

Validation Loss: 3.3095 | Validation Accuracy: 0.4439
total count: 750
Epoch 70 | Train Loss: 3.3042 | Train Accuracy: 0.4501


 71%|███████   | 71/100 [36:20<15:14, 31.53s/it]

Validation Loss: 3.2887 | Validation Accuracy: 0.4439
total count: 750
Epoch 71 | Train Loss: 3.2833 | Train Accuracy: 0.4523


 72%|███████▏  | 72/100 [36:53<14:53, 31.89s/it]

Validation Loss: 3.2680 | Validation Accuracy: 0.4439
total count: 750
Epoch 72 | Train Loss: 3.2624 | Train Accuracy: 0.4545


 73%|███████▎  | 73/100 [37:27<14:36, 32.45s/it]

Validation Loss: 3.2473 | Validation Accuracy: 0.4439
total count: 750
Epoch 73 | Train Loss: 3.2414 | Train Accuracy: 0.4558


 74%|███████▍  | 74/100 [37:58<13:54, 32.11s/it]

Validation Loss: 3.2267 | Validation Accuracy: 0.4439
total count: 750
Epoch 74 | Train Loss: 3.2206 | Train Accuracy: 0.4574


 75%|███████▌  | 75/100 [38:27<13:00, 31.24s/it]

Validation Loss: 3.2062 | Validation Accuracy: 0.4439
total count: 750
Epoch 75 | Train Loss: 3.1998 | Train Accuracy: 0.4584


 76%|███████▌  | 76/100 [38:59<12:30, 31.27s/it]

Validation Loss: 3.1857 | Validation Accuracy: 0.4439
total count: 750
Epoch 76 | Train Loss: 3.1791 | Train Accuracy: 0.4611


 77%|███████▋  | 77/100 [39:33<12:16, 32.00s/it]

Validation Loss: 3.1654 | Validation Accuracy: 0.4579
total count: 750
Epoch 77 | Train Loss: 3.1585 | Train Accuracy: 0.4670


 78%|███████▊  | 78/100 [40:06<11:50, 32.31s/it]

Validation Loss: 3.1452 | Validation Accuracy: 0.4579
total count: 750
Epoch 78 | Train Loss: 3.1381 | Train Accuracy: 0.4691


 79%|███████▉  | 79/100 [40:36<11:09, 31.87s/it]

Validation Loss: 3.1252 | Validation Accuracy: 0.4579
total count: 750
Epoch 79 | Train Loss: 3.1179 | Train Accuracy: 0.4702


 80%|████████  | 80/100 [41:09<10:42, 32.14s/it]

Validation Loss: 3.1055 | Validation Accuracy: 0.4642
total count: 750
Epoch 80 | Train Loss: 3.0980 | Train Accuracy: 0.4715


 81%|████████  | 81/100 [41:41<10:11, 32.18s/it]

Validation Loss: 3.0860 | Validation Accuracy: 0.4690
total count: 750
Epoch 81 | Train Loss: 3.0783 | Train Accuracy: 0.4727


 82%|████████▏ | 82/100 [42:12<09:31, 31.74s/it]

Validation Loss: 3.0668 | Validation Accuracy: 0.4690
total count: 750
Epoch 82 | Train Loss: 3.0589 | Train Accuracy: 0.4733


 83%|████████▎ | 83/100 [42:45<09:03, 31.97s/it]

Validation Loss: 3.0479 | Validation Accuracy: 0.4690
total count: 750
Epoch 83 | Train Loss: 3.0398 | Train Accuracy: 0.4743


 84%|████████▍ | 84/100 [43:17<08:34, 32.15s/it]

Validation Loss: 3.0293 | Validation Accuracy: 0.4690
total count: 750
Epoch 84 | Train Loss: 3.0211 | Train Accuracy: 0.4751


 85%|████████▌ | 85/100 [43:49<07:59, 31.98s/it]

Validation Loss: 3.0111 | Validation Accuracy: 0.4690
total count: 750
Epoch 85 | Train Loss: 3.0028 | Train Accuracy: 0.4759


 86%|████████▌ | 86/100 [44:19<07:22, 31.58s/it]

Validation Loss: 2.9933 | Validation Accuracy: 0.4690
total count: 750
Epoch 86 | Train Loss: 2.9848 | Train Accuracy: 0.4779


 87%|████████▋ | 87/100 [44:50<06:46, 31.24s/it]

Validation Loss: 2.9759 | Validation Accuracy: 0.4743
total count: 750
Epoch 87 | Train Loss: 2.9673 | Train Accuracy: 0.4792


 88%|████████▊ | 88/100 [45:19<06:08, 30.71s/it]

Validation Loss: 2.9590 | Validation Accuracy: 0.4796
total count: 750
Epoch 88 | Train Loss: 2.9503 | Train Accuracy: 0.4802


 89%|████████▉ | 89/100 [45:51<05:40, 30.98s/it]

Validation Loss: 2.9425 | Validation Accuracy: 0.4796
total count: 750
Epoch 89 | Train Loss: 2.9337 | Train Accuracy: 0.4808


 90%|█████████ | 90/100 [46:23<05:14, 31.41s/it]

Validation Loss: 2.9264 | Validation Accuracy: 0.4796
total count: 750
Epoch 90 | Train Loss: 2.9176 | Train Accuracy: 0.4810


 91%|█████████ | 91/100 [46:58<04:50, 32.31s/it]

Validation Loss: 2.9108 | Validation Accuracy: 0.4796
total count: 750
Epoch 91 | Train Loss: 2.9019 | Train Accuracy: 0.4812


 92%|█████████▏| 92/100 [47:30<04:18, 32.35s/it]

Validation Loss: 2.8958 | Validation Accuracy: 0.4796
total count: 750
Epoch 92 | Train Loss: 2.8868 | Train Accuracy: 0.4814


 93%|█████████▎| 93/100 [48:00<03:42, 31.71s/it]

Validation Loss: 2.8812 | Validation Accuracy: 0.4796
total count: 750
Epoch 93 | Train Loss: 2.8722 | Train Accuracy: 0.4815


 94%|█████████▍| 94/100 [48:32<03:10, 31.71s/it]

Validation Loss: 2.8671 | Validation Accuracy: 0.4796
total count: 750
Epoch 94 | Train Loss: 2.8582 | Train Accuracy: 0.4817


 95%|█████████▌| 95/100 [49:03<02:37, 31.43s/it]

Validation Loss: 2.8535 | Validation Accuracy: 0.4796
total count: 750
Epoch 95 | Train Loss: 2.8446 | Train Accuracy: 0.4818


 96%|█████████▌| 96/100 [49:32<02:03, 30.85s/it]

Validation Loss: 2.8404 | Validation Accuracy: 0.4796
total count: 750
Epoch 96 | Train Loss: 2.8316 | Train Accuracy: 0.4819


 97%|█████████▋| 97/100 [50:03<01:32, 30.87s/it]

Validation Loss: 2.8278 | Validation Accuracy: 0.4796
total count: 750
Epoch 97 | Train Loss: 2.8191 | Train Accuracy: 0.4820


 98%|█████████▊| 98/100 [50:34<01:01, 30.67s/it]

Validation Loss: 2.8157 | Validation Accuracy: 0.4796
total count: 750
Epoch 98 | Train Loss: 2.8071 | Train Accuracy: 0.4820


 99%|█████████▉| 99/100 [51:06<00:31, 31.25s/it]

Validation Loss: 2.8042 | Validation Accuracy: 0.4796
total count: 750
Epoch 99 | Train Loss: 2.7956 | Train Accuracy: 0.4820


100%|██████████| 100/100 [51:37<00:00, 30.97s/it]

Validation Loss: 2.7931 | Validation Accuracy: 0.4796





RuntimeError: The size of tensor a (168) must match the size of tensor b (32) at non-singleton dimension 1

In [None]:
# 1. 定義圖神經網路模型
class GCN(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, hidden_size)
        self.conv2 = GraphConv(hidden_size, num_classes)

    def forward(self, g, inputs):
        h = self.conv1(g, inputs)
        h = torch.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        # 將每個圖的所有節點特徵進行平均
        hg = dgl.mean_nodes(g, 'h')
        return hg

# 2. 定義數據集
class GraphDataset(Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data = self.data_list[idx]
        g = dgl.graph((data["edge_index"][0], data["edge_index"][1]), num_nodes=data["num_nodes"])
        g = dgl.add_self_loop(g)  # add self loop to each node
        g.ndata['feat'] = torch.tensor(data["node_feat"])
        # g.edata['feat'] = torch.tensor(data["edge_attr"])
        return g, torch.tensor(data["y"])


def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)


with open("../data/final_small_version/remaining_train.jsonl") as f:
    train_data_list = [json.loads(line) for line in f]
train_dataset = GraphDataset(train_data_list)
# train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=dgl.batch)

with open("../data/final_small_version/remaining_valid.jsonl") as f:
    val_data_list = [json.loads(line) for line in f]
val_dataset = GraphDataset(val_data_list)
# val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True, collate_fn=dgl.batch)

with open("../data/final_small_version/remaining_test.jsonl") as f:
    test_data_list = [json.loads(line) for line in f]
test_dataset = GraphDataset(test_data_list)
# test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True, collate_fn=dgl.batch)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True, collate_fn=collate)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True, collate_fn=collate)



# 4. 創建模型並訓練
model = GCN(1, 16, 168) # 1是輸入特徵的維度，16是隱藏層大小，168是類別數量
# model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in tqdm(range(100)):
    for batched_g, labels in train_dataloader:
        # batched_g, labels = batched_g.to(device), labels.to(device)
        logits = model(batched_g, batched_g.ndata['feat'])
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))

    # 儲存 checkpoint
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, f"./checkpoint_GAT/checkpoint_{epoch}.pt")

# 5. 驗證模型
model.eval()
with torch.no_grad():
    for batched_g, labels in val_dataloader:
        logits = model(batched_g, batched_g.ndata['feat'])
        _, predicted = torch.max(logits.data, 1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
        print('Accuracy: %d %%' % (100 * correct / total))