# Test of GAT
- use DGL

In [1]:
import dgl
import json
import torch
import torch as th
from tqdm import tqdm
import torch.nn as nn
from dgl.nn import GraphConv, GATConv
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup

- check the GPU and assign the GPU by the best memory usage

In [2]:
import subprocess
import torch

def get_free_gpu():
    try:
        # Run nvidia-smi command to get GPU details
        _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
        command = "nvidia-smi --query-gpu=memory.free --format=csv,nounits,noheader"
        memory_free_info = _output_to_list(subprocess.check_output(command.split())) 
        memory_free_values = [int(x) for i, x in enumerate(memory_free_info)]
        
        # Get the GPU with the maximum free memory
        best_gpu_id = memory_free_values.index(max(memory_free_values))
        return best_gpu_id
    except:
        # If any exception occurs, default to GPU 0 (this handles cases where nvidia-smi isn't installed)
        return 0

if torch.cuda.is_available():
    # Get the best GPU ID based on free memory and set it
    best_gpu_id = get_free_gpu()
    device = torch.device(f"cuda:{best_gpu_id}")
else:
    device = torch.device("cpu")

print(device)


cuda:1


## Data Loader

In [38]:
class GraphDataset(Dataset):
    def __init__(self, data_list, device):
        self.data_list = data_list
        self.device = device

    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        data = self.data_list[idx]

        g = dgl.graph((th.tensor(data["edge_index"][0]), th.tensor(data["edge_index"][1])), num_nodes=data["num_nodes"]).to(self.device)

        g.ndata['feat'] = th.tensor(data["node_feat"]).to(self.device)
        g.edata['feat'] = th.tensor(data["edge_attr"]).to(self.device)  # Add edge features to graph

        return g, th.tensor(data["label"]).to(self.device)


def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)


In [37]:
datasets = ['train', 'valid', 'test']
dataloaders = {}

for dataset_name in tqdm(datasets):
#     file_path = f"../data/training_data/repeated_{dataset_name}.jsonl"
    file_path = f"../data/test_1/repeated_{dataset_name}.jsonl"
    
    print(file_path)
    with open(file_path) as f:
        data_list = [json.loads(line) for line in f]
    
    dataset = GraphDataset(data_list, device)
    dataloaders[dataset_name] = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate)
    
print("Done!")

 33%|███▎      | 1/3 [00:00<00:00,  5.33it/s]

../data/test_1/repeated_train.jsonl
../data/test_1/repeated_valid.jsonl


100%|██████████| 3/3 [00:00<00:00,  9.91it/s]

../data/test_1/repeated_test.jsonl
Done!





### Model

In [39]:
class GAT(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout_prob=0.2):
        super(GAT, self).__init__()
        
        # do not check the zero in_degree since we have all the complete graph
        self.layer1 = GATConv(in_dim, hidden_dim, num_heads=num_heads, activation=F.relu, allow_zero_in_degree=True)
        self.layer2 = GATConv(hidden_dim * num_heads, out_dim, num_heads=num_heads, allow_zero_in_degree=True)
        
        # Adding Batch Normalization after each GAT layer
        self.batchnorm1 = nn.BatchNorm1d(hidden_dim * num_heads)
        self.batchnorm2 = nn.BatchNorm1d(out_dim)
        
        # Adding Dropout for regularization
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, g, h):
        # Apply GAT layers
        h = self.layer1(g, h)
        h = h.view(h.shape[0], -1)
        h = F.relu(h)
        h = self.layer2(g, h).squeeze(1)
        
        # Store the output as a new node feature
        g.ndata['h_out'] = h

        # Use mean pooling to aggregate this new node feature
        h_agg = dgl.mean_nodes(g, feat='h_out')
        return h_agg

    

- Model Forward  

In [40]:
def model_fn(data, model, criterion, device, count=1):
    """Forward a batch through the model."""
    batched_g, labels = data
#     print(batch_g)
    batched_g = batched_g.to(device)
    
    labels = labels.to(device)
    logits = model(batched_g, batched_g.ndata['feat'].float()) # for GAT
    logits = logits.mean(dim=1)
#     print(logits)
    
    loss = criterion(logits, labels)
#     print(batched_g.ndata['feat'].dtype)
#     print("Logits shape:", logits.shape)  # Expected: (batch_size, 168)
#     print("Labels shape:", labels.shape)  # Expected: (batch_size)

    # Get the class id with the highest probability.
    preds = logits.argmax(1)
    
    # Compute accuracy.
    accuracy = torch.mean((preds == labels).float())

#     return loss, accuracy
    return loss, accuracy, preds

### Training

- 165 APs + benign x 50 times

In [36]:
model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/6000 [00:00<?, ?it/s][A
  0%|          | 6/6000 [00:00<01:42, 58.67it/s][A
  0%|          | 12/6000 [00:00<02:05, 47.74it/s][A
  0%|          | 17/6000 [00:00<02:08, 46.68it/s][A
  0%|          | 22/6000 [00:00<02:09, 46.16it/s][A
  0%|          | 27/6000 [00:00<02:10, 45.83it/s][A
  1%|          | 32/6000 [00:00<02:11, 45.54it/s][A
  1%|          | 37/6000 [00:00<02:11, 45.22it/s][A
  1%|          | 42/6000 [00:00<02:13, 44.49it/s][A
  1%|          | 47/6000 [00:01<02:10, 45.62it/s][A
  1%|          | 52/6000 [00:01<02:10, 45.75it/s][A
  1%|          | 57/6000 [00:01<02:10, 45.37it/s][A
  1%|          | 62/6000 [00:01<02:11, 45.24it/s][A
  1%|          | 67/6000 [00:01<02:09, 45.76it/s][A
  1%|          | 72/6000 [00:01<02:07, 46.41it/s][A
  1%|▏         | 77/6000 [00:01<02:08, 45.92it/s][A
  1%|▏         | 82/6000 [00:01<02:10, 45.38it/s][A
  1%|▏         | 87/6000 [00:01<02:09, 45.52it/s][A
  2%|▏         | 9

 25%|██▌       | 1514/6000 [00:33<01:48, 41.19it/s][A
 25%|██▌       | 1519/6000 [00:33<01:47, 41.63it/s][A
 25%|██▌       | 1524/6000 [00:33<01:47, 41.58it/s][A
 25%|██▌       | 1529/6000 [00:33<01:47, 41.54it/s][A
 26%|██▌       | 1534/6000 [00:33<01:47, 41.60it/s][A
 26%|██▌       | 1539/6000 [00:33<01:50, 40.38it/s][A
 26%|██▌       | 1544/6000 [00:33<01:49, 40.71it/s][A
 26%|██▌       | 1549/6000 [00:34<01:49, 40.62it/s][A
 26%|██▌       | 1554/6000 [00:34<01:47, 41.39it/s][A
 26%|██▌       | 1559/6000 [00:34<01:46, 41.60it/s][A
 26%|██▌       | 1564/6000 [00:34<01:46, 41.76it/s][A
 26%|██▌       | 1569/6000 [00:34<01:45, 42.06it/s][A
 26%|██▌       | 1574/6000 [00:34<01:44, 42.45it/s][A
 26%|██▋       | 1579/6000 [00:34<01:43, 42.83it/s][A
 26%|██▋       | 1584/6000 [00:34<01:45, 41.75it/s][A
 26%|██▋       | 1589/6000 [00:34<01:43, 42.54it/s][A
 27%|██▋       | 1594/6000 [00:35<01:45, 41.80it/s][A
 27%|██▋       | 1599/6000 [00:35<01:45, 41.56it/s][A
 27%|██▋  

 50%|████▉     | 2996/6000 [01:08<01:13, 41.02it/s][A
 50%|█████     | 3001/6000 [01:08<01:10, 42.24it/s][A
 50%|█████     | 3006/6000 [01:08<01:10, 42.28it/s][A
 50%|█████     | 3011/6000 [01:08<01:11, 41.88it/s][A
 50%|█████     | 3016/6000 [01:08<01:11, 41.58it/s][A
 50%|█████     | 3021/6000 [01:09<01:11, 41.46it/s][A
 50%|█████     | 3026/6000 [01:09<01:10, 42.47it/s][A
 51%|█████     | 3031/6000 [01:09<01:08, 43.61it/s][A
 51%|█████     | 3036/6000 [01:09<01:07, 43.83it/s][A
 51%|█████     | 3041/6000 [01:09<01:07, 43.62it/s][A
 51%|█████     | 3046/6000 [01:09<01:08, 43.34it/s][A
 51%|█████     | 3051/6000 [01:09<01:08, 43.21it/s][A
 51%|█████     | 3056/6000 [01:09<01:09, 42.23it/s][A
 51%|█████     | 3061/6000 [01:09<01:09, 42.54it/s][A
 51%|█████     | 3066/6000 [01:10<01:09, 42.32it/s][A
 51%|█████     | 3071/6000 [01:10<01:09, 41.93it/s][A
 51%|█████▏    | 3076/6000 [01:10<01:10, 41.60it/s][A
 51%|█████▏    | 3081/6000 [01:10<01:10, 41.54it/s][A
 51%|█████

 75%|███████▍  | 4478/6000 [01:43<00:36, 41.48it/s][A
 75%|███████▍  | 4483/6000 [01:43<00:36, 41.38it/s][A
 75%|███████▍  | 4488/6000 [01:43<00:36, 41.44it/s][A
 75%|███████▍  | 4493/6000 [01:44<00:36, 41.43it/s][A
 75%|███████▍  | 4498/6000 [01:44<00:36, 41.46it/s][A
 75%|███████▌  | 4503/6000 [01:44<00:36, 41.42it/s][A
 75%|███████▌  | 4508/6000 [01:44<00:35, 41.55it/s][A
 75%|███████▌  | 4513/6000 [01:44<00:35, 41.51it/s][A
 75%|███████▌  | 4518/6000 [01:44<00:35, 41.18it/s][A
 75%|███████▌  | 4523/6000 [01:44<00:35, 41.32it/s][A
 75%|███████▌  | 4528/6000 [01:44<00:35, 41.22it/s][A
 76%|███████▌  | 4533/6000 [01:44<00:35, 41.09it/s][A
 76%|███████▌  | 4538/6000 [01:45<00:36, 40.33it/s][A
 76%|███████▌  | 4543/6000 [01:45<00:35, 40.70it/s][A
 76%|███████▌  | 4548/6000 [01:45<00:35, 40.70it/s][A
 76%|███████▌  | 4553/6000 [01:45<00:34, 42.16it/s][A
 76%|███████▌  | 4558/6000 [01:45<00:33, 43.68it/s][A
 76%|███████▌  | 4563/6000 [01:45<00:33, 42.78it/s][A
 76%|█████

 99%|█████████▉| 5960/6000 [02:18<00:00, 41.60it/s][A
 99%|█████████▉| 5965/6000 [02:18<00:00, 41.25it/s][A
100%|█████████▉| 5970/6000 [02:19<00:00, 41.02it/s][A
100%|█████████▉| 5975/6000 [02:19<00:00, 41.43it/s][A
100%|█████████▉| 5980/6000 [02:19<00:00, 41.22it/s][A
100%|█████████▉| 5985/6000 [02:19<00:00, 42.12it/s][A
100%|█████████▉| 5990/6000 [02:19<00:00, 42.59it/s][A
100%|█████████▉| 5995/6000 [02:19<00:00, 41.95it/s][A
100%|██████████| 6000/6000 [02:19<00:00, 42.94it/s][A


total count: 6000
Epoch 0 | Train Loss: 4.7759 | Train Accuracy: 0.1667


 20%|██        | 1/5 [02:27<09:51, 147.97s/it]

Validation Loss: 4.5674 | Validation Accuracy: 0.1750



  0%|          | 0/6000 [00:00<?, ?it/s][A
  0%|          | 5/6000 [00:00<02:22, 42.05it/s][A
  0%|          | 10/6000 [00:00<02:22, 42.11it/s][A
  0%|          | 15/6000 [00:00<02:21, 42.28it/s][A
  0%|          | 20/6000 [00:00<02:23, 41.53it/s][A
  0%|          | 25/6000 [00:00<02:17, 43.31it/s][A
  0%|          | 30/6000 [00:00<02:13, 44.78it/s][A
  1%|          | 35/6000 [00:00<02:13, 44.53it/s][A
  1%|          | 40/6000 [00:00<02:16, 43.73it/s][A
  1%|          | 45/6000 [00:01<02:17, 43.21it/s][A
  1%|          | 50/6000 [00:01<02:17, 43.31it/s][A
  1%|          | 55/6000 [00:01<02:20, 42.32it/s][A
  1%|          | 60/6000 [00:01<02:22, 41.83it/s][A
  1%|          | 65/6000 [00:01<02:22, 41.63it/s][A
  1%|          | 70/6000 [00:01<02:21, 41.91it/s][A
  1%|▏         | 75/6000 [00:01<02:22, 41.64it/s][A
  1%|▏         | 80/6000 [00:01<02:22, 41.55it/s][A
  1%|▏         | 85/6000 [00:02<02:21, 41.69it/s][A
  2%|▏         | 90/6000 [00:02<02:21, 41.73it/s][A
  

 25%|██▌       | 1508/6000 [00:35<01:48, 41.33it/s][A
 25%|██▌       | 1513/6000 [00:36<01:47, 41.57it/s][A
 25%|██▌       | 1518/6000 [00:36<01:48, 41.31it/s][A
 25%|██▌       | 1523/6000 [00:36<01:48, 41.11it/s][A
 25%|██▌       | 1528/6000 [00:36<01:47, 41.55it/s][A
 26%|██▌       | 1533/6000 [00:36<01:47, 41.53it/s][A
 26%|██▌       | 1538/6000 [00:36<01:47, 41.51it/s][A
 26%|██▌       | 1543/6000 [00:36<01:47, 41.61it/s][A
 26%|██▌       | 1548/6000 [00:36<01:47, 41.43it/s][A
 26%|██▌       | 1553/6000 [00:36<01:47, 41.37it/s][A
 26%|██▌       | 1558/6000 [00:37<01:47, 41.25it/s][A
 26%|██▌       | 1563/6000 [00:37<01:48, 41.02it/s][A
 26%|██▌       | 1568/6000 [00:37<01:45, 41.93it/s][A
 26%|██▌       | 1573/6000 [00:37<01:46, 41.48it/s][A
 26%|██▋       | 1578/6000 [00:37<01:46, 41.56it/s][A
 26%|██▋       | 1583/6000 [00:37<01:46, 41.33it/s][A
 26%|██▋       | 1588/6000 [00:37<01:46, 41.51it/s][A
 27%|██▋       | 1593/6000 [00:37<01:47, 40.83it/s][A
 27%|██▋  

 50%|████▉     | 2989/6000 [01:11<01:13, 41.21it/s][A
 50%|████▉     | 2994/6000 [01:11<01:12, 41.38it/s][A
 50%|████▉     | 2999/6000 [01:11<01:11, 42.04it/s][A
 50%|█████     | 3004/6000 [01:11<01:11, 42.18it/s][A
 50%|█████     | 3009/6000 [01:11<01:11, 41.58it/s][A
 50%|█████     | 3014/6000 [01:11<01:11, 42.04it/s][A
 50%|█████     | 3019/6000 [01:12<01:11, 41.74it/s][A
 50%|█████     | 3024/6000 [01:12<01:12, 41.29it/s][A
 50%|█████     | 3029/6000 [01:12<01:11, 41.73it/s][A
 51%|█████     | 3034/6000 [01:12<01:11, 41.46it/s][A
 51%|█████     | 3039/6000 [01:12<01:10, 41.72it/s][A
 51%|█████     | 3044/6000 [01:12<01:11, 41.44it/s][A
 51%|█████     | 3049/6000 [01:12<01:11, 41.38it/s][A
 51%|█████     | 3054/6000 [01:12<01:10, 41.57it/s][A
 51%|█████     | 3059/6000 [01:13<01:11, 41.33it/s][A
 51%|█████     | 3064/6000 [01:13<01:10, 41.62it/s][A
 51%|█████     | 3069/6000 [01:13<01:10, 41.32it/s][A
 51%|█████     | 3074/6000 [01:13<01:10, 41.31it/s][A
 51%|█████

 74%|███████▍  | 4470/6000 [01:46<00:36, 41.45it/s][A
 75%|███████▍  | 4475/6000 [01:46<00:36, 41.44it/s][A
 75%|███████▍  | 4480/6000 [01:47<00:36, 41.61it/s][A
 75%|███████▍  | 4485/6000 [01:47<00:36, 41.23it/s][A
 75%|███████▍  | 4490/6000 [01:47<00:36, 41.63it/s][A
 75%|███████▍  | 4495/6000 [01:47<00:36, 41.38it/s][A
 75%|███████▌  | 4500/6000 [01:47<00:36, 41.60it/s][A
 75%|███████▌  | 4505/6000 [01:47<00:36, 41.42it/s][A
 75%|███████▌  | 4510/6000 [01:47<00:36, 41.36it/s][A
 75%|███████▌  | 4515/6000 [01:47<00:35, 41.57it/s][A
 75%|███████▌  | 4520/6000 [01:47<00:35, 41.38it/s][A
 75%|███████▌  | 4525/6000 [01:48<00:36, 40.82it/s][A
 76%|███████▌  | 4530/6000 [01:48<00:35, 41.12it/s][A
 76%|███████▌  | 4535/6000 [01:48<00:35, 41.16it/s][A
 76%|███████▌  | 4540/6000 [01:48<00:35, 41.18it/s][A
 76%|███████▌  | 4545/6000 [01:48<00:35, 41.45it/s][A
 76%|███████▌  | 4550/6000 [01:48<00:33, 43.01it/s][A
 76%|███████▌  | 4556/6000 [01:48<00:32, 44.89it/s][A
 76%|█████

 99%|█████████▉| 5950/6000 [02:22<00:01, 41.52it/s][A
 99%|█████████▉| 5955/6000 [02:22<00:01, 41.28it/s][A
 99%|█████████▉| 5960/6000 [02:22<00:00, 41.26it/s][A
 99%|█████████▉| 5965/6000 [02:22<00:00, 42.10it/s][A
100%|█████████▉| 5970/6000 [02:22<00:00, 42.70it/s][A
100%|█████████▉| 5975/6000 [02:22<00:00, 42.20it/s][A
100%|█████████▉| 5980/6000 [02:22<00:00, 41.84it/s][A
100%|█████████▉| 5985/6000 [02:23<00:00, 41.58it/s][A
100%|█████████▉| 5990/6000 [02:23<00:00, 41.44it/s][A
100%|█████████▉| 5995/6000 [02:23<00:00, 41.50it/s][A
100%|██████████| 6000/6000 [02:23<00:00, 41.84it/s][A


total count: 6000
Epoch 1 | Train Loss: 4.4908 | Train Accuracy: 0.1762


 40%|████      | 2/5 [04:59<07:30, 150.15s/it]

Validation Loss: 4.3519 | Validation Accuracy: 0.1800



  0%|          | 0/6000 [00:00<?, ?it/s][A
  0%|          | 6/6000 [00:00<01:59, 50.07it/s][A
  0%|          | 12/6000 [00:00<02:11, 45.42it/s][A
  0%|          | 17/6000 [00:00<02:11, 45.39it/s][A
  0%|          | 22/6000 [00:00<02:10, 45.70it/s][A
  0%|          | 27/6000 [00:00<02:17, 43.51it/s][A
  1%|          | 32/6000 [00:00<02:19, 42.76it/s][A
  1%|          | 37/6000 [00:00<02:20, 42.49it/s][A
  1%|          | 42/6000 [00:00<02:21, 42.05it/s][A
  1%|          | 47/6000 [00:01<02:22, 41.76it/s][A
  1%|          | 52/6000 [00:01<02:22, 41.85it/s][A
  1%|          | 57/6000 [00:01<02:22, 41.72it/s][A
  1%|          | 62/6000 [00:01<02:23, 41.27it/s][A
  1%|          | 67/6000 [00:01<02:21, 41.94it/s][A
  1%|          | 72/6000 [00:01<02:22, 41.60it/s][A
  1%|▏         | 77/6000 [00:01<02:23, 41.27it/s][A
  1%|▏         | 82/6000 [00:01<02:23, 41.31it/s][A
  1%|▏         | 87/6000 [00:02<02:22, 41.63it/s][A
  2%|▏         | 92/6000 [00:02<02:22, 41.46it/s][A
  

 25%|██▌       | 1508/6000 [00:35<01:48, 41.47it/s][A
 25%|██▌       | 1513/6000 [00:36<01:48, 41.31it/s][A
 25%|██▌       | 1518/6000 [00:36<01:45, 42.31it/s][A
 25%|██▌       | 1523/6000 [00:36<01:46, 42.03it/s][A
 25%|██▌       | 1528/6000 [00:36<01:43, 43.34it/s][A
 26%|██▌       | 1533/6000 [00:36<01:39, 44.67it/s][A
 26%|██▌       | 1538/6000 [00:36<01:39, 44.64it/s][A
 26%|██▌       | 1543/6000 [00:36<01:41, 43.70it/s][A
 26%|██▌       | 1548/6000 [00:36<01:42, 43.44it/s][A
 26%|██▌       | 1553/6000 [00:36<01:44, 42.63it/s][A
 26%|██▌       | 1558/6000 [00:37<01:45, 42.09it/s][A
 26%|██▌       | 1563/6000 [00:37<01:44, 42.55it/s][A
 26%|██▌       | 1568/6000 [00:37<01:45, 41.92it/s][A
 26%|██▌       | 1573/6000 [00:37<01:44, 42.23it/s][A
 26%|██▋       | 1578/6000 [00:37<01:44, 42.44it/s][A
 26%|██▋       | 1583/6000 [00:37<01:44, 42.11it/s][A
 26%|██▋       | 1588/6000 [00:37<01:44, 42.03it/s][A
 27%|██▋       | 1593/6000 [00:37<01:47, 41.05it/s][A
 27%|██▋  

 50%|████▉     | 2990/6000 [01:11<01:11, 42.32it/s][A
 50%|████▉     | 2995/6000 [01:11<01:11, 42.16it/s][A
 50%|█████     | 3000/6000 [01:11<01:11, 41.88it/s][A
 50%|█████     | 3005/6000 [01:11<01:12, 41.52it/s][A
 50%|█████     | 3010/6000 [01:11<01:11, 41.76it/s][A
 50%|█████     | 3015/6000 [01:11<01:12, 41.21it/s][A
 50%|█████     | 3020/6000 [01:11<01:11, 41.52it/s][A
 50%|█████     | 3025/6000 [01:11<01:12, 41.13it/s][A
 50%|█████     | 3030/6000 [01:12<01:10, 42.22it/s][A
 51%|█████     | 3035/6000 [01:12<01:11, 41.74it/s][A
 51%|█████     | 3040/6000 [01:12<01:10, 41.79it/s][A
 51%|█████     | 3045/6000 [01:12<01:10, 41.66it/s][A
 51%|█████     | 3050/6000 [01:12<01:10, 41.67it/s][A
 51%|█████     | 3055/6000 [01:12<01:08, 43.04it/s][A
 51%|█████     | 3061/6000 [01:12<01:05, 44.88it/s][A
 51%|█████     | 3066/6000 [01:12<01:06, 44.38it/s][A
 51%|█████     | 3071/6000 [01:12<01:07, 43.18it/s][A
 51%|█████▏    | 3076/6000 [01:13<01:08, 42.44it/s][A
 51%|█████

 75%|███████▍  | 4471/6000 [01:46<00:36, 41.41it/s][A
 75%|███████▍  | 4476/6000 [01:46<00:36, 41.62it/s][A
 75%|███████▍  | 4481/6000 [01:46<00:36, 41.36it/s][A
 75%|███████▍  | 4486/6000 [01:46<00:36, 41.57it/s][A
 75%|███████▍  | 4491/6000 [01:46<00:36, 41.50it/s][A
 75%|███████▍  | 4496/6000 [01:47<00:36, 41.05it/s][A
 75%|███████▌  | 4501/6000 [01:47<00:36, 41.62it/s][A
 75%|███████▌  | 4506/6000 [01:47<00:35, 41.56it/s][A
 75%|███████▌  | 4511/6000 [01:47<00:35, 41.63it/s][A
 75%|███████▌  | 4516/6000 [01:47<00:35, 41.40it/s][A
 75%|███████▌  | 4521/6000 [01:47<00:36, 41.08it/s][A
 75%|███████▌  | 4526/6000 [01:47<00:35, 42.01it/s][A
 76%|███████▌  | 4531/6000 [01:47<00:35, 41.73it/s][A
 76%|███████▌  | 4536/6000 [01:47<00:35, 41.78it/s][A
 76%|███████▌  | 4541/6000 [01:48<00:33, 43.39it/s][A
 76%|███████▌  | 4546/6000 [01:48<00:32, 44.23it/s][A
 76%|███████▌  | 4551/6000 [01:48<00:33, 43.48it/s][A
 76%|███████▌  | 4556/6000 [01:48<00:33, 42.83it/s][A
 76%|█████

 99%|█████████▉| 5956/6000 [02:21<00:01, 42.95it/s][A
 99%|█████████▉| 5961/6000 [02:21<00:00, 43.70it/s][A
 99%|█████████▉| 5966/6000 [02:22<00:00, 42.97it/s][A
100%|█████████▉| 5971/6000 [02:22<00:00, 42.00it/s][A
100%|█████████▉| 5976/6000 [02:22<00:00, 42.95it/s][A
100%|█████████▉| 5981/6000 [02:22<00:00, 42.85it/s][A
100%|█████████▉| 5986/6000 [02:22<00:00, 42.04it/s][A
100%|█████████▉| 5991/6000 [02:22<00:00, 42.75it/s][A
100%|██████████| 6000/6000 [02:22<00:00, 42.00it/s][A


total count: 6000
Epoch 2 | Train Loss: 4.2553 | Train Accuracy: 0.2168


 60%|██████    | 3/5 [07:30<05:01, 150.59s/it]

Validation Loss: 4.1236 | Validation Accuracy: 0.2550



  0%|          | 0/6000 [00:00<?, ?it/s][A
  0%|          | 6/6000 [00:00<01:51, 53.92it/s][A
  0%|          | 12/6000 [00:00<02:06, 47.34it/s][A
  0%|          | 17/6000 [00:00<02:07, 46.88it/s][A
  0%|          | 22/6000 [00:00<02:10, 45.64it/s][A
  0%|          | 27/6000 [00:00<02:12, 44.97it/s][A
  1%|          | 32/6000 [00:00<02:17, 43.52it/s][A
  1%|          | 37/6000 [00:00<02:18, 43.10it/s][A
  1%|          | 42/6000 [00:00<02:18, 42.96it/s][A
  1%|          | 47/6000 [00:01<02:19, 42.65it/s][A
  1%|          | 52/6000 [00:01<02:20, 42.26it/s][A
  1%|          | 57/6000 [00:01<02:21, 42.12it/s][A
  1%|          | 62/6000 [00:01<02:22, 41.77it/s][A
  1%|          | 67/6000 [00:01<02:23, 41.48it/s][A
  1%|          | 72/6000 [00:01<02:21, 41.76it/s][A
  1%|▏         | 77/6000 [00:01<02:22, 41.53it/s][A
  1%|▏         | 82/6000 [00:01<02:23, 41.30it/s][A
  1%|▏         | 87/6000 [00:02<02:23, 41.24it/s][A
  2%|▏         | 92/6000 [00:02<02:20, 42.04it/s][A
  

 25%|██▌       | 1509/6000 [00:35<01:44, 43.02it/s][A
 25%|██▌       | 1514/6000 [00:35<01:45, 42.41it/s][A
 25%|██▌       | 1519/6000 [00:35<01:46, 42.02it/s][A
 25%|██▌       | 1524/6000 [00:36<01:46, 42.02it/s][A
 25%|██▌       | 1529/6000 [00:36<01:47, 41.48it/s][A
 26%|██▌       | 1534/6000 [00:36<01:46, 41.95it/s][A
 26%|██▌       | 1539/6000 [00:36<01:45, 42.20it/s][A
 26%|██▌       | 1544/6000 [00:36<01:45, 42.35it/s][A
 26%|██▌       | 1549/6000 [00:36<01:46, 41.74it/s][A
 26%|██▌       | 1554/6000 [00:36<01:43, 43.16it/s][A
 26%|██▌       | 1559/6000 [00:36<01:39, 44.44it/s][A
 26%|██▌       | 1564/6000 [00:37<01:39, 44.53it/s][A
 26%|██▌       | 1569/6000 [00:37<01:38, 45.20it/s][A
 26%|██▋       | 1575/6000 [00:37<01:34, 46.97it/s][A
 26%|██▋       | 1580/6000 [00:37<01:38, 45.00it/s][A
 26%|██▋       | 1585/6000 [00:37<01:41, 43.55it/s][A
 26%|██▋       | 1590/6000 [00:37<01:41, 43.27it/s][A
 27%|██▋       | 1595/6000 [00:37<01:43, 42.51it/s][A
 27%|██▋  

 50%|████▉     | 2993/6000 [01:11<01:12, 41.44it/s][A
 50%|████▉     | 2998/6000 [01:11<01:11, 42.14it/s][A
 50%|█████     | 3003/6000 [01:11<01:10, 42.21it/s][A
 50%|█████     | 3008/6000 [01:11<01:11, 42.06it/s][A
 50%|█████     | 3013/6000 [01:11<01:08, 43.36it/s][A
 50%|█████     | 3018/6000 [01:11<01:08, 43.43it/s][A
 50%|█████     | 3023/6000 [01:11<01:09, 42.87it/s][A
 50%|█████     | 3028/6000 [01:11<01:07, 44.33it/s][A
 51%|█████     | 3033/6000 [01:11<01:04, 45.71it/s][A
 51%|█████     | 3038/6000 [01:12<01:06, 44.74it/s][A
 51%|█████     | 3043/6000 [01:12<01:06, 44.37it/s][A
 51%|█████     | 3048/6000 [01:12<01:07, 43.42it/s][A
 51%|█████     | 3053/6000 [01:12<01:08, 42.80it/s][A
 51%|█████     | 3058/6000 [01:12<01:09, 42.26it/s][A
 51%|█████     | 3063/6000 [01:12<01:10, 41.95it/s][A
 51%|█████     | 3068/6000 [01:12<01:10, 41.82it/s][A
 51%|█████     | 3073/6000 [01:12<01:09, 42.07it/s][A
 51%|█████▏    | 3078/6000 [01:12<01:08, 42.71it/s][A
 51%|█████

 75%|███████▍  | 4490/6000 [01:43<00:33, 45.18it/s][A
 75%|███████▍  | 4495/6000 [01:43<00:32, 45.79it/s][A
 75%|███████▌  | 4500/6000 [01:43<00:32, 46.08it/s][A
 75%|███████▌  | 4505/6000 [01:44<00:33, 45.21it/s][A
 75%|███████▌  | 4510/6000 [01:44<00:32, 45.78it/s][A
 75%|███████▌  | 4515/6000 [01:44<00:32, 45.73it/s][A
 75%|███████▌  | 4520/6000 [01:44<00:32, 45.75it/s][A
 75%|███████▌  | 4525/6000 [01:44<00:32, 45.56it/s][A
 76%|███████▌  | 4530/6000 [01:44<00:32, 45.38it/s][A
 76%|███████▌  | 4535/6000 [01:44<00:32, 44.81it/s][A
 76%|███████▌  | 4540/6000 [01:44<00:32, 44.85it/s][A
 76%|███████▌  | 4545/6000 [01:44<00:31, 45.62it/s][A
 76%|███████▌  | 4550/6000 [01:45<00:31, 45.50it/s][A
 76%|███████▌  | 4555/6000 [01:45<00:31, 45.40it/s][A
 76%|███████▌  | 4560/6000 [01:45<00:31, 45.00it/s][A
 76%|███████▌  | 4565/6000 [01:45<00:31, 45.10it/s][A
 76%|███████▌  | 4570/6000 [01:45<00:31, 45.26it/s][A
 76%|███████▋  | 4575/6000 [01:45<00:31, 45.25it/s][A
 76%|█████

100%|██████████| 6000/6000 [02:16<00:00, 43.98it/s][A


total count: 6000
Epoch 3 | Train Loss: 4.0550 | Train Accuracy: 0.2583


 80%|████████  | 4/5 [09:55<02:28, 148.27s/it]

Validation Loss: 3.9515 | Validation Accuracy: 0.2600



  0%|          | 0/6000 [00:00<?, ?it/s][A
  0%|          | 6/6000 [00:00<01:42, 58.21it/s][A
  0%|          | 12/6000 [00:00<01:55, 52.01it/s][A
  0%|          | 18/6000 [00:00<02:01, 49.22it/s][A
  0%|          | 23/6000 [00:00<02:05, 47.68it/s][A
  0%|          | 28/6000 [00:00<02:07, 46.86it/s][A
  1%|          | 33/6000 [00:00<02:08, 46.32it/s][A
  1%|          | 38/6000 [00:00<02:10, 45.83it/s][A
  1%|          | 43/6000 [00:00<02:10, 45.70it/s][A
  1%|          | 48/6000 [00:01<02:10, 45.50it/s][A
  1%|          | 53/6000 [00:01<02:10, 45.44it/s][A
  1%|          | 58/6000 [00:01<02:10, 45.47it/s][A
  1%|          | 63/6000 [00:01<02:11, 45.24it/s][A
  1%|          | 68/6000 [00:01<02:11, 45.18it/s][A
  1%|          | 73/6000 [00:01<02:11, 45.15it/s][A
  1%|▏         | 78/6000 [00:01<02:07, 46.49it/s][A
  1%|▏         | 83/6000 [00:01<02:04, 47.44it/s][A
  1%|▏         | 88/6000 [00:01<02:06, 46.63it/s][A
  2%|▏         | 93/6000 [00:01<02:08, 46.08it/s][A
  

 26%|██▌       | 1536/6000 [00:33<01:34, 47.06it/s][A
 26%|██▌       | 1541/6000 [00:33<01:35, 46.47it/s][A
 26%|██▌       | 1546/6000 [00:33<01:36, 46.08it/s][A
 26%|██▌       | 1551/6000 [00:33<01:37, 45.85it/s][A
 26%|██▌       | 1556/6000 [00:33<01:37, 45.65it/s][A
 26%|██▌       | 1561/6000 [00:33<01:37, 45.31it/s][A
 26%|██▌       | 1566/6000 [00:33<01:37, 45.35it/s][A
 26%|██▌       | 1571/6000 [00:34<01:38, 45.12it/s][A
 26%|██▋       | 1576/6000 [00:34<01:37, 45.41it/s][A
 26%|██▋       | 1581/6000 [00:34<01:37, 45.27it/s][A
 26%|██▋       | 1586/6000 [00:34<01:37, 45.31it/s][A
 27%|██▋       | 1591/6000 [00:34<01:37, 45.20it/s][A
 27%|██▋       | 1596/6000 [00:34<01:37, 45.14it/s][A
 27%|██▋       | 1601/6000 [00:34<01:35, 45.93it/s][A
 27%|██▋       | 1606/6000 [00:34<01:34, 46.40it/s][A
 27%|██▋       | 1611/6000 [00:34<01:35, 45.84it/s][A
 27%|██▋       | 1616/6000 [00:35<01:35, 45.68it/s][A
 27%|██▋       | 1621/6000 [00:35<01:33, 46.80it/s][A
 27%|██▋  

 51%|█████     | 3037/6000 [01:05<01:05, 45.35it/s][A
 51%|█████     | 3042/6000 [01:06<01:05, 45.33it/s][A
 51%|█████     | 3048/6000 [01:06<01:03, 46.78it/s][A
 51%|█████     | 3053/6000 [01:06<01:03, 46.35it/s][A
 51%|█████     | 3058/6000 [01:06<01:02, 47.32it/s][A
 51%|█████     | 3063/6000 [01:06<01:02, 46.64it/s][A
 51%|█████     | 3068/6000 [01:06<01:03, 46.21it/s][A
 51%|█████     | 3073/6000 [01:06<01:03, 45.88it/s][A
 51%|█████▏    | 3078/6000 [01:06<01:03, 45.68it/s][A
 51%|█████▏    | 3083/6000 [01:06<01:04, 45.26it/s][A
 51%|█████▏    | 3088/6000 [01:07<01:04, 45.06it/s][A
 52%|█████▏    | 3093/6000 [01:07<01:04, 45.34it/s][A
 52%|█████▏    | 3098/6000 [01:07<01:03, 45.80it/s][A
 52%|█████▏    | 3103/6000 [01:07<01:02, 46.10it/s][A
 52%|█████▏    | 3108/6000 [01:07<01:02, 46.11it/s][A
 52%|█████▏    | 3113/6000 [01:07<01:03, 45.15it/s][A
 52%|█████▏    | 3118/6000 [01:07<01:03, 45.54it/s][A
 52%|█████▏    | 3123/6000 [01:07<01:03, 45.39it/s][A
 52%|█████

 76%|███████▌  | 4537/6000 [01:38<00:30, 47.45it/s][A
 76%|███████▌  | 4542/6000 [01:38<00:30, 48.13it/s][A
 76%|███████▌  | 4547/6000 [01:38<00:29, 48.63it/s][A
 76%|███████▌  | 4552/6000 [01:38<00:30, 47.52it/s][A
 76%|███████▌  | 4557/6000 [01:38<00:30, 46.81it/s][A
 76%|███████▌  | 4562/6000 [01:39<00:31, 46.16it/s][A
 76%|███████▌  | 4567/6000 [01:39<00:31, 45.78it/s][A
 76%|███████▌  | 4572/6000 [01:39<00:31, 45.16it/s][A
 76%|███████▋  | 4577/6000 [01:39<00:31, 45.24it/s][A
 76%|███████▋  | 4582/6000 [01:39<00:31, 45.64it/s][A
 76%|███████▋  | 4587/6000 [01:39<00:31, 45.49it/s][A
 77%|███████▋  | 4592/6000 [01:39<00:30, 45.43it/s][A
 77%|███████▋  | 4597/6000 [01:39<00:30, 45.44it/s][A
 77%|███████▋  | 4602/6000 [01:39<00:30, 45.42it/s][A
 77%|███████▋  | 4607/6000 [01:40<00:31, 44.17it/s][A
 77%|███████▋  | 4612/6000 [01:40<00:31, 44.41it/s][A
 77%|███████▋  | 4617/6000 [01:40<00:30, 44.62it/s][A
 77%|███████▋  | 4622/6000 [01:40<00:30, 44.75it/s][A
 77%|█████

total count: 6000
Epoch 4 | Train Loss: 3.9081 | Train Accuracy: 0.2762


100%|██████████| 5/5 [12:14<00:00, 146.84s/it]

Validation Loss: 3.8250 | Validation Accuracy: 0.2900
labels: tensor([19], device='cuda:1') torch.Size([1])
predicted: tensor([140], device='cuda:1') torch.Size([1])
labels: tensor([17], device='cuda:1') torch.Size([1])
predicted: tensor([17], device='cuda:1') torch.Size([1])
labels: tensor([31], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([30], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([32], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([129], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([131], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([130], device='cuda:1') torch.Size([1])
predicted: tensor([130], device='cuda:1') torch.Size([1])
labels: tensor([133], device='cuda:1') torch.Size([1])
pre




labels: tensor([8], device='cuda:1') torch.Size([1])
predicted: tensor([130], device='cuda:1') torch.Size([1])
labels: tensor([10], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([11], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([15], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([16], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([22], device='cuda:1') torch.Size([1])
predicted: tensor([22], device='cuda:1') torch.Size([1])
labels: tensor([28], device='cuda:1') torch.Size([1])
predicted: tensor([130], device='cuda:1') torch.Size([1])
labels: tensor([127], device='cuda:1') torch.Size([1])
predicted: tensor([1], device='cuda:1') torch.Size([1])
labels: tensor([123], device='cuda:1') torch.Size([1])
predicted: tensor([140], device='cuda:1') torch.Size([1])
la

labels: tensor([44], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([95], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([93], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([99], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([94], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([96], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([98], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([97], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([21], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: te

predicted: tensor([115], device='cuda:1') torch.Size([1])
labels: tensor([145], device='cuda:1') torch.Size([1])
predicted: tensor([145], device='cuda:1') torch.Size([1])
labels: tensor([146], device='cuda:1') torch.Size([1])
predicted: tensor([147], device='cuda:1') torch.Size([1])
labels: tensor([147], device='cuda:1') torch.Size([1])
predicted: tensor([147], device='cuda:1') torch.Size([1])
labels: tensor([75], device='cuda:1') torch.Size([1])
predicted: tensor([159], device='cuda:1') torch.Size([1])
labels: tensor([2], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([1], device='cuda:1') torch.Size([1])
predicted: tensor([1], device='cuda:1') torch.Size([1])
labels: tensor([4], device='cuda:1') torch.Size([1])
predicted: tensor([4], device='cuda:1') torch.Size([1])
labels: tensor([6], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([5], device='cuda:1') torch.Size([1])
pr

labels: tensor([134], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([136], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([138], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([141], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([142], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([143], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([144], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([148], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([149], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
l

labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0],

labels: tensor([105], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([104], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([102], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([103], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([26], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([25], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([27], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([106], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([107], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labe

labels: tensor([57], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([139], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([83], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor(

labels: tensor([80], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([81], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([85], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([84], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([82], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([86], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([87], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([88], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([90], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: te

labels: tensor([46], device='cuda:1') torch.Size([1])
predicted: tensor([46], device='cuda:1') torch.Size([1])
labels: tensor([49], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([47], device='cuda:1') torch.Size([1])
predicted: tensor([165], device='cuda:1') torch.Size([1])
labels: tensor([51], device='cuda:1') torch.Size([1])
predicted: tensor([51], device='cuda:1') torch.Size([1])
labels: tensor([50], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([53], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([52], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([54], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([55], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels

labels: tensor([147], device='cuda:1') torch.Size([1])
predicted: tensor([147], device='cuda:1') torch.Size([1])
labels: tensor([75], device='cuda:1') torch.Size([1])
predicted: tensor([159], device='cuda:1') torch.Size([1])
labels: tensor([2], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([1], device='cuda:1') torch.Size([1])
predicted: tensor([1], device='cuda:1') torch.Size([1])
labels: tensor([4], device='cuda:1') torch.Size([1])
predicted: tensor([4], device='cuda:1') torch.Size([1])
labels: tensor([6], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([5], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([3], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([91], device='cuda:1') torch.Size([1])
predicted: tensor([159], device='cuda:1') torch.Size([1])
labels: t

labels: tensor([138], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([141], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([142], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([143], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([144], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([148], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([149], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([40], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([161], device='cuda:1') torch.Size([1])
predicted: tensor([17], device='cuda:1') torch.Size([1])
l

labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0],

labels: tensor([102], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([103], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([26], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([25], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([27], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([106], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([107], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([108], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([34], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
label

labels: tensor([83], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0]

labels: tensor([85], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([84], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([82], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([86], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([87], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([88], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([90], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([89], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([100], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: t

labels: tensor([47], device='cuda:1') torch.Size([1])
predicted: tensor([165], device='cuda:1') torch.Size([1])
labels: tensor([51], device='cuda:1') torch.Size([1])
predicted: tensor([51], device='cuda:1') torch.Size([1])
labels: tensor([50], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([53], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([52], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([54], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([55], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([56], device='cuda:1') torch.Size([1])
predicted: tensor([56], device='cuda:1') torch.Size([1])
labels: tensor([73], device='cuda:1') torch.Size([1])
predicted: tensor([73], device='cuda:1') torch.Size([1])
label

predicted: tensor([159], device='cuda:1') torch.Size([1])
labels: tensor([2], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([1], device='cuda:1') torch.Size([1])
predicted: tensor([1], device='cuda:1') torch.Size([1])
labels: tensor([4], device='cuda:1') torch.Size([1])
predicted: tensor([4], device='cuda:1') torch.Size([1])
labels: tensor([6], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([5], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([3], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([91], device='cuda:1') torch.Size([1])
predicted: tensor([159], device='cuda:1') torch.Size([1])
labels: tensor([58], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([9], device='cuda:1') torch.Size([1])
predicted: te

labels: tensor([142], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([143], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([144], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([148], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([149], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([40], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([161], device='cuda:1') torch.Size([1])
predicted: tensor([17], device='cuda:1') torch.Size([1])
labels: tensor([162], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([42], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
la

predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([

labels: tensor([102], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([103], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([26], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([25], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([27], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([106], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([107], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([108], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([34], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
label

labels: tensor([83], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([0]

labels: tensor([81], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([85], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([84], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([82], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([86], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([87], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([88], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([90], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([89], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: te

labels: tensor([49], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([47], device='cuda:1') torch.Size([1])
predicted: tensor([165], device='cuda:1') torch.Size([1])
labels: tensor([51], device='cuda:1') torch.Size([1])
predicted: tensor([51], device='cuda:1') torch.Size([1])
labels: tensor([50], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([53], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([52], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([54], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([55], device='cuda:1') torch.Size([1])
predicted: tensor([0], device='cuda:1') torch.Size([1])
labels: tensor([56], device='cuda:1') torch.Size([1])
predicted: tensor([56], device='cuda:1') torch.Size([1])
labels

- 165 APs x 5 times

In [41]:
model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/495 [00:00<?, ?it/s][A
  1%|          | 6/495 [00:00<00:08, 54.97it/s][A
  2%|▏         | 12/495 [00:00<00:09, 49.18it/s][A
  3%|▎         | 17/495 [00:00<00:10, 47.19it/s][A
  4%|▍         | 22/495 [00:00<00:10, 46.62it/s][A
  5%|▌         | 27/495 [00:00<00:10, 46.05it/s][A
  6%|▋         | 32/495 [00:00<00:10, 45.72it/s][A
  7%|▋         | 37/495 [00:00<00:10, 44.37it/s][A
  8%|▊         | 42/495 [00:00<00:10, 44.27it/s][A
  9%|▉         | 47/495 [00:01<00:10, 44.59it/s][A
 11%|█         | 52/495 [00:01<00:09, 45.04it/s][A
 12%|█▏        | 57/495 [00:01<00:09, 44.97it/s][A
 13%|█▎        | 62/495 [00:01<00:09, 45.10it/s][A
 14%|█▎        | 67/495 [00:01<00:09, 45.13it/s][A
 15%|█▍        | 72/495 [00:01<00:09, 45.17it/s][A
 16%|█▌        | 77/495 [00:01<00:09, 45.13it/s][A
 17%|█▋        | 82/495 [00:01<00:09, 45.21it/s][A
 18%|█▊        | 87/495 [00:01<00:09, 45.15it/s][A
 19%|█▊        | 92/495 [00:02<00:08

total count: 495
Epoch 0 | Train Loss: 5.1254 | Train Accuracy: 0.0061


 20%|██        | 1/5 [00:11<00:46, 11.53s/it]

Validation Loss: 5.1192 | Validation Accuracy: 0.0788



  0%|          | 0/495 [00:00<?, ?it/s][A
  1%|          | 6/495 [00:00<00:08, 56.97it/s][A
  2%|▏         | 12/495 [00:00<00:09, 50.97it/s][A
  4%|▎         | 18/495 [00:00<00:09, 48.75it/s][A
  5%|▍         | 23/495 [00:00<00:09, 47.27it/s][A
  6%|▌         | 28/495 [00:00<00:10, 46.63it/s][A
  7%|▋         | 33/495 [00:00<00:10, 46.15it/s][A
  8%|▊         | 38/495 [00:00<00:09, 45.96it/s][A
  9%|▊         | 43/495 [00:00<00:09, 45.74it/s][A
 10%|▉         | 48/495 [00:01<00:09, 45.58it/s][A
 11%|█         | 53/495 [00:01<00:09, 45.32it/s][A
 12%|█▏        | 58/495 [00:01<00:09, 45.32it/s][A
 13%|█▎        | 63/495 [00:01<00:09, 45.27it/s][A
 14%|█▎        | 68/495 [00:01<00:09, 45.23it/s][A
 15%|█▍        | 73/495 [00:01<00:09, 45.29it/s][A
 16%|█▌        | 78/495 [00:01<00:09, 45.29it/s][A
 17%|█▋        | 83/495 [00:01<00:09, 45.13it/s][A
 18%|█▊        | 88/495 [00:01<00:09, 45.14it/s][A
 19%|█▉        | 93/495 [00:02<00:08, 45.21it/s][A
 20%|█▉        | 98/4

total count: 495
Epoch 1 | Train Loss: 5.1182 | Train Accuracy: 0.0545


 40%|████      | 2/5 [00:22<00:34, 11.49s/it]

Validation Loss: 5.1099 | Validation Accuracy: 0.0545



  0%|          | 0/495 [00:00<?, ?it/s][A
  1%|          | 6/495 [00:00<00:08, 59.64it/s][A
  2%|▏         | 12/495 [00:00<00:09, 50.78it/s][A
  4%|▎         | 18/495 [00:00<00:09, 48.35it/s][A
  5%|▍         | 23/495 [00:00<00:10, 47.06it/s][A
  6%|▌         | 28/495 [00:00<00:10, 46.55it/s][A
  7%|▋         | 33/495 [00:00<00:10, 46.09it/s][A
  8%|▊         | 38/495 [00:00<00:09, 47.18it/s][A
  9%|▉         | 44/495 [00:00<00:09, 48.19it/s][A
 10%|▉         | 49/495 [00:01<00:09, 47.23it/s][A
 11%|█         | 54/495 [00:01<00:09, 46.48it/s][A
 12%|█▏        | 59/495 [00:01<00:09, 46.07it/s][A
 13%|█▎        | 64/495 [00:01<00:09, 47.12it/s][A
 14%|█▍        | 69/495 [00:01<00:08, 47.73it/s][A
 15%|█▌        | 75/495 [00:01<00:08, 48.50it/s][A
 16%|█▌        | 80/495 [00:01<00:08, 47.59it/s][A
 17%|█▋        | 85/495 [00:01<00:08, 46.86it/s][A
 18%|█▊        | 90/495 [00:01<00:08, 46.32it/s][A
 19%|█▉        | 95/495 [00:02<00:08, 45.82it/s][A
 20%|██        | 100/

total count: 495
Epoch 2 | Train Loss: 5.1093 | Train Accuracy: 0.0424


 60%|██████    | 3/5 [00:34<00:23, 11.55s/it]

Validation Loss: 5.0988 | Validation Accuracy: 0.0485



  0%|          | 0/495 [00:00<?, ?it/s][A
  1%|          | 6/495 [00:00<00:08, 58.93it/s][A
  2%|▏         | 12/495 [00:00<00:09, 52.69it/s][A
  4%|▎         | 18/495 [00:00<00:10, 47.60it/s][A
  5%|▍         | 23/495 [00:00<00:10, 46.37it/s][A
  6%|▌         | 28/495 [00:00<00:10, 45.84it/s][A
  7%|▋         | 33/495 [00:00<00:10, 45.89it/s][A
  8%|▊         | 38/495 [00:00<00:10, 45.54it/s][A
  9%|▊         | 43/495 [00:00<00:09, 45.41it/s][A
 10%|▉         | 48/495 [00:01<00:09, 45.46it/s][A
 11%|█         | 53/495 [00:01<00:09, 45.34it/s][A
 12%|█▏        | 58/495 [00:01<00:09, 44.99it/s][A
 13%|█▎        | 63/495 [00:01<00:09, 45.43it/s][A
 14%|█▎        | 68/495 [00:01<00:09, 44.88it/s][A
 15%|█▍        | 73/495 [00:01<00:09, 45.11it/s][A
 16%|█▌        | 78/495 [00:01<00:09, 45.29it/s][A
 17%|█▋        | 83/495 [00:01<00:09, 45.03it/s][A
 18%|█▊        | 88/495 [00:01<00:09, 45.01it/s][A
 19%|█▉        | 93/495 [00:02<00:08, 44.78it/s][A
 20%|█▉        | 98/4

total count: 495
Epoch 3 | Train Loss: 5.0985 | Train Accuracy: 0.0303


 80%|████████  | 4/5 [00:46<00:11, 11.61s/it]

Validation Loss: 5.0868 | Validation Accuracy: 0.0424



  0%|          | 0/495 [00:00<?, ?it/s][A
  1%|          | 6/495 [00:00<00:08, 54.52it/s][A
  2%|▏         | 12/495 [00:00<00:09, 51.50it/s][A
  4%|▎         | 18/495 [00:00<00:09, 48.33it/s][A
  5%|▍         | 23/495 [00:00<00:10, 46.86it/s][A
  6%|▌         | 28/495 [00:00<00:10, 46.70it/s][A
  7%|▋         | 33/495 [00:00<00:10, 46.08it/s][A
  8%|▊         | 38/495 [00:00<00:10, 45.56it/s][A
  9%|▊         | 43/495 [00:00<00:09, 46.63it/s][A
 10%|▉         | 49/495 [00:01<00:09, 48.21it/s][A
 11%|█         | 54/495 [00:01<00:09, 47.01it/s][A
 12%|█▏        | 59/495 [00:01<00:09, 46.16it/s][A
 13%|█▎        | 64/495 [00:01<00:09, 46.17it/s][A
 14%|█▍        | 69/495 [00:01<00:09, 45.89it/s][A
 15%|█▍        | 74/495 [00:01<00:09, 45.44it/s][A
 16%|█▌        | 79/495 [00:01<00:09, 45.23it/s][A
 17%|█▋        | 84/495 [00:01<00:09, 45.41it/s][A
 18%|█▊        | 89/495 [00:01<00:08, 45.59it/s][A
 19%|█▉        | 94/495 [00:02<00:08, 45.48it/s][A
 20%|██        | 99/4

total count: 495
Epoch 4 | Train Loss: 5.0870 | Train Accuracy: 0.0303


100%|██████████| 5/5 [00:57<00:00, 11.59s/it]

Validation Loss: 5.0750 | Validation Accuracy: 0.0364
labels: tensor([19], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([17], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([31], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([30], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([32], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([129], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([131], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([130], device='cuda:1') torch.Size([1])
predicted: tensor([77], device='cuda:1') torch.Size([1])
labels: tensor([133], device='cuda:1') torch.Siz




labels: tensor([91], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([58], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([9], device='cuda:1') torch.Size([1])
predicted: tensor([22], device='cuda:1') torch.Size([1])
labels: tensor([8], device='cuda:1') torch.Size([1])
predicted: tensor([22], device='cuda:1') torch.Size([1])
labels: tensor([10], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([11], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([15], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([16], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([22], device='cuda:1') torch.Size([1])
predicted: tensor([22], device='cuda:1') torch.Size([1

labels: tensor([164], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([72], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([7], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([12], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([13], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([14], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([33], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([140], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([37], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.S

- 5 APs x 50 times

In [None]:
model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: %d %%' % (100 * correct / total))