# Test of GAT
- use DGL

In [2]:
import dgl
import json
import torch
import torch as th
from tqdm import tqdm
import torch.nn as nn
from dgl.nn import GraphConv, GATConv
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup

- check the GPU and assign the GPU by the best memory usage

In [3]:
import subprocess
import torch

def get_free_gpu():
    try:
        # Run nvidia-smi command to get GPU details
        _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
        command = "nvidia-smi --query-gpu=memory.free --format=csv,nounits,noheader"
        memory_free_info = _output_to_list(subprocess.check_output(command.split())) 
        memory_free_values = [int(x) for i, x in enumerate(memory_free_info)]
        
        # Get the GPU with the maximum free memory
        best_gpu_id = memory_free_values.index(max(memory_free_values))
        return best_gpu_id
    except:
        # If any exception occurs, default to GPU 0 (this handles cases where nvidia-smi isn't installed)
        return 0

if torch.cuda.is_available():
    # Get the best GPU ID based on free memory and set it
    best_gpu_id = get_free_gpu()
    device = torch.device(f"cuda:{best_gpu_id}")
else:
    device = torch.device("cpu")

print(device)


cuda:0


## Fix the seed

In [4]:
import numpy as np
import torch
import random

#fix seed
def same_seeds(seed = 8787):
    torch.manual_seed(seed)
    # random.seed(seed) 
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  
    np.random.seed(seed)  
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

## Data Loader

In [5]:
class GraphDataset(Dataset):
    def __init__(self, data_list, device):
        self.data_list = data_list
        self.device = device

    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        data = self.data_list[idx]

        g = dgl.graph((th.tensor(data["edge_index"][0]), th.tensor(data["edge_index"][1])), num_nodes=data["num_nodes"]).to(self.device)

        g.ndata['feat'] = th.tensor(data["node_feat"]).to(self.device)
        g.edata['feat'] = th.tensor(data["edge_attr"]).to(self.device)  # Add edge features to graph

        return g, th.tensor(data["label"]).to(self.device)


def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)


In [7]:
datasets = ['train', 'valid', 'test']
dataloaders = {}

for dataset_name in tqdm(datasets):
#     file_path = f"../data/training_data/repeated_{dataset_name}.jsonl"
#     file_path = f"../data/test_10(500times)/repeated_{dataset_name}.jsonl"
    file_path = f"../../data_processing/dgl/data/test/repeated_{dataset_name}.jsonl"
    
    print(file_path)
    with open(file_path) as f:
#         data_list = [json.loads(line) for line in f]
        data_list = [json.loads(line) for line in tqdm(f, position=0, leave=True)]
    
    dataset = GraphDataset(data_list, device)
    dataloaders[dataset_name] = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=collate)
    
print("Done!")

  0%|          | 0/3 [00:00<?, ?it/s]

../../data_processing/dgl/data/test/repeated_train.jsonl


100%|██████████| 3/3 [00:26<00:00,  8.98s/it]

../../data_processing/dgl/data/test/repeated_valid.jsonl
../../data_processing/dgl/data/test/repeated_test.jsonl
Done!





### Model

In [7]:
# class GAT(nn.Module):
#     def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout_prob=0.25):
#         super(GAT, self).__init__()
        
#         # do not check the zero in_degree since we have all the complete graph
#         self.layer1 = GATConv(in_dim, hidden_dim, num_heads=num_heads, activation=F.relu, allow_zero_in_degree=True)
#         self.layer2 = GATConv(hidden_dim * num_heads, out_dim, num_heads=num_heads, allow_zero_in_degree=True)
        
#         # Adding Batch Normalization after each GAT layer
#         self.batchnorm1 = nn.BatchNorm1d(hidden_dim * num_heads)
#         self.batchnorm2 = nn.BatchNorm1d(out_dim)
        
#         # Adding Dropout for regularization
#         self.dropout = nn.Dropout(dropout_prob)

#     def forward(self, g, h):
#         # Apply GAT layers
#         h = self.layer1(g, h)
#         h = h.view(h.shape[0], -1)
#         h = F.relu(h)
#         h = self.dropout(h)
#         h = self.layer2(g, h).squeeze(1)
        
#         # Store the output as a new node feature
#         g.ndata['h_out'] = h

#         # Use mean pooling to aggregate this new node feature
#         h_agg = dgl.mean_nodes(g, feat='h_out')
#         return h_agg


class GAT(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout_prob=0.25):
        super(GAT, self).__init__()
        
        # do not check the zero in_degree since we have all the complete graph
        self.layer1 = GATConv(in_dim, hidden_dim, num_heads=num_heads, activation=F.relu, allow_zero_in_degree=True)
        self.layer2 = GATConv(hidden_dim * num_heads, hidden_dim, num_heads=num_heads, allow_zero_in_degree=True)
        self.layer3 = GATConv(hidden_dim * num_heads, out_dim, num_heads=num_heads, allow_zero_in_degree=True)
         
        # Adding Batch Normalization after each GAT layer
        self.batchnorm1 = nn.BatchNorm1d(hidden_dim * num_heads)
        self.batchnorm2 = nn.BatchNorm1d(hidden_dim * num_heads)
        self.batchnorm3 = nn.BatchNorm1d(out_dim)
        
        # Adding Dropout for regularization
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, g, h):
        # Layer 1
        h1 = self.layer1(g, h)
        h1 = h1.view(h1.shape[0], -1)
        h1 = self.batchnorm1(h1)
        h1 = F.relu(h1)
        h1 = self.dropout(h1)
        
        # Layer 2 with residual connection
        h2 = self.layer2(g, h1)
        h2_res = h2.view(h2.shape[0], -1) + h1
        h2_res = self.batchnorm2(h2_res)
        h2_res = F.relu(h2_res)
        h2_res = self.dropout(h2_res)

        # Layer 3 with residual connection
        h3 = self.layer3(g, h2_res).squeeze(1)
        h3_res = h3 + h2_res[:, :h3.shape[1]]  # ensure the dimensions match for addition
        h3_res = self.batchnorm3(h3_res)
        h3_res = self.dropout(h3_res)

        # Aggregate
        g.ndata['h_out'] = h3_res
        h_agg = dgl.mean_nodes(g, feat='h_out')
        return h_agg
    

- Model Forward  

In [6]:
def model_fn(data, model, criterion, device, count=1):
    """Forward a batch through the model."""
    batched_g, labels = data
#     print(batch_g)
    batched_g = batched_g.to(device)
    
    labels = labels.to(device)
    logits = model(batched_g, batched_g.ndata['feat'].float()) # for GAT
    logits = logits.mean(dim=1)
#     print(logits)
    
    loss = criterion(logits, labels)
#     print(batched_g.ndata['feat'].dtype)
#     print("Logits shape:", logits.shape)  # Expected: (batch_size, 168)
#     print("Labels shape:", labels.shape)  # Expected: (batch_size)

    # Get the class id with the highest probability.
    preds = logits.argmax(1)
    
    # Compute accuracy.
    accuracy = torch.mean((preds == labels).float())

#     return loss, accuracy
    return loss, accuracy, preds

### Training

- Fix the seed and save the model.state_dict that contains the initial weight

In [11]:
seed = 8787
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
torch.save(model.state_dict(), 'model_initial/initial_weight.pth')

In [12]:
model.layer1.fc.weight

Parameter containing:
tensor([[-0.1806, -0.0598,  0.0091,  ...,  0.0719,  0.2496,  0.0873],
        [ 0.1694, -0.0015, -0.0139,  ...,  0.0147,  0.0892,  0.0146],
        [ 0.0969, -0.0595, -0.0115,  ..., -0.0474,  0.0529, -0.0565],
        ...,
        [-0.0433, -0.2248,  0.3002,  ...,  0.0850,  0.1621,  0.0422],
        [ 0.2097, -0.2492,  0.0612,  ..., -0.0041,  0.0365, -0.1483],
        [ 0.0971, -0.2221,  0.1652,  ..., -0.1312, -0.2610,  0.0077]],
       requires_grad=True)

- Check if model really load the model_dict

In [13]:
model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))
model.layer1.fc.weight

Parameter containing:
tensor([[-0.1806, -0.0598,  0.0091,  ...,  0.0719,  0.2496,  0.0873],
        [ 0.1694, -0.0015, -0.0139,  ...,  0.0147,  0.0892,  0.0146],
        [ 0.0969, -0.0595, -0.0115,  ..., -0.0474,  0.0529, -0.0565],
        ...,
        [-0.0433, -0.2248,  0.3002,  ...,  0.0850,  0.1621,  0.0422],
        [ 0.2097, -0.2492,  0.0612,  ..., -0.0041,  0.0365, -0.1483],
        [ 0.0971, -0.2221,  0.1652,  ..., -0.1312, -0.2610,  0.0077]],
       requires_grad=True)

- 6 AP x 500 times trial: 118, 120, 121, 122, 128, 83

In [39]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
#     for data in dataloaders['train']:
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")


# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0) # labels.size(0) is the batch size
        
        correct += (predicted == labels).sum().item()
        # (predicted == labels).sum() -> # of the matched prediction
        # .item() -> turn the tensor to the regular number
        
    print('Test Accuracy: %d %%' % (100 * correct / total))

 20%|██        | 1/5 [01:00<04:02, 60.61s/it]

total count: 3000
Epoch 0 | Train Loss: 3.0207 | Train Accuracy: 0.1810
Validation Loss: 1.8520 | Validation Accuracy: 0.1667


 40%|████      | 2/5 [02:01<03:01, 60.62s/it]

total count: 3000
Epoch 1 | Train Loss: 1.8075 | Train Accuracy: 0.2073
Validation Loss: 1.7841 | Validation Accuracy: 0.5000


 60%|██████    | 3/5 [03:01<02:01, 60.63s/it]

total count: 3000
Epoch 2 | Train Loss: 1.7761 | Train Accuracy: 0.3123
Validation Loss: 1.7661 | Validation Accuracy: 0.8333


 80%|████████  | 4/5 [04:02<01:00, 60.73s/it]

total count: 3000
Epoch 3 | Train Loss: 1.7575 | Train Accuracy: 0.4563
Validation Loss: 1.7463 | Validation Accuracy: 0.8333


100%|██████████| 5/5 [05:03<00:00, 60.72s/it]

total count: 3000
Epoch 4 | Train Loss: 1.7330 | Train Accuracy: 0.5960
Validation Loss: 1.7179 | Validation Accuracy: 0.8333
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
Test Accuracy: 83 %





- 6 AP x 50 times trial: 118, 120, 121, 122, 128, 139

In [12]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
#     for data in tqdm(dataloaders['train']):
    for data in dataloaders['train']:
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")


# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0) # labels.size(0) is the batch size
        
        correct += (predicted == labels).sum().item()
        # (predicted == labels).sum() -> # of the matched prediction
        # .item() -> turn the tensor to the regular number
        
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]

total count: 180
Epoch 0 | Train Loss: 5.1072 | Train Accuracy: 0.2111


 20%|██        | 1/5 [00:04<00:16,  4.10s/it]

Validation Loss: 5.0831 | Validation Accuracy: 0.1667
total count: 180
Epoch 1 | Train Loss: 5.0285 | Train Accuracy: 0.2444


 40%|████      | 2/5 [00:07<00:11,  3.97s/it]

Validation Loss: 4.9493 | Validation Accuracy: 0.1667
total count: 180
Epoch 2 | Train Loss: 4.8220 | Train Accuracy: 0.1889


 60%|██████    | 3/5 [00:11<00:07,  3.93s/it]

Validation Loss: 4.6700 | Validation Accuracy: 0.1667
total count: 180
Epoch 3 | Train Loss: 4.4770 | Train Accuracy: 0.1611


 80%|████████  | 4/5 [00:15<00:03,  3.90s/it]

Validation Loss: 4.2651 | Validation Accuracy: 0.1667
total count: 180
Epoch 4 | Train Loss: 4.0337 | Train Accuracy: 0.1778


100%|██████████| 5/5 [00:19<00:00,  3.92s/it]

Validation Loss: 3.7754 | Validation Accuracy: 0.1667
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([139], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') tor




labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([139], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0')

- 6 AP x 50 times trial: 118, 120, 121, 122, 83, 139 

In [23]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
#     for data in tqdm(dataloaders['train']):
    for data in dataloaders['train']:
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")


# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0) # labels.size(0) is the batch size
        
        correct += (predicted == labels).sum().item()
        # (predicted == labels).sum() -> # of the matched prediction
        # .item() -> turn the tensor to the regular number
        
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]

total count: 180
Epoch 0 | Train Loss: 5.1086 | Train Accuracy: 0.2167


 20%|██        | 1/5 [00:01<00:06,  1.56s/it]

Validation Loss: 5.0876 | Validation Accuracy: 0.3333
total count: 180
Epoch 1 | Train Loss: 5.0450 | Train Accuracy: 0.2722


 40%|████      | 2/5 [00:03<00:04,  1.57s/it]

Validation Loss: 4.9846 | Validation Accuracy: 0.1667
total count: 180
Epoch 2 | Train Loss: 4.8851 | Train Accuracy: 0.2389


 60%|██████    | 3/5 [00:04<00:03,  1.56s/it]

Validation Loss: 4.7632 | Validation Accuracy: 0.1667
total count: 180
Epoch 3 | Train Loss: 4.6005 | Train Accuracy: 0.1667


 80%|████████  | 4/5 [00:06<00:01,  1.56s/it]

Validation Loss: 4.4211 | Validation Accuracy: 0.1667
total count: 180
Epoch 4 | Train Loss: 4.2216 | Train Accuracy: 0.1944


100%|██████████| 5/5 [00:07<00:00,  1.56s/it]

Validation Loss: 3.9941 | Validation Accuracy: 0.1667
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([139], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch




labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([139], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') t

- 5 APs x 50 times: 83, 118, 120, 121, 122

In [30]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/150 [00:00<?, ?it/s][A
  5%|▌         | 8/150 [00:00<00:01, 73.74it/s][A
 11%|█         | 16/150 [00:00<00:02, 58.19it/s][A
 15%|█▌        | 23/150 [00:00<00:02, 52.77it/s][A
 19%|█▉        | 29/150 [00:00<00:02, 50.04it/s][A
 23%|██▎       | 35/150 [00:00<00:02, 48.58it/s][A
 27%|██▋       | 40/150 [00:00<00:02, 47.38it/s][A
 30%|███       | 45/150 [00:00<00:02, 47.02it/s][A
 33%|███▎      | 50/150 [00:01<00:02, 46.46it/s][A
 37%|███▋      | 55/150 [00:01<00:02, 45.85it/s][A
 40%|████      | 60/150 [00:01<00:01, 45.75it/s][A
 43%|████▎     | 65/150 [00:01<00:01, 45.72it/s][A
 47%|████▋     | 70/150 [00:01<00:01, 45.59it/s][A
 50%|█████     | 75/150 [00:01<00:01, 45.47it/s][A
 53%|█████▎    | 80/150 [00:01<00:01, 45.30it/s][A
 57%|█████▋    | 85/150 [00:01<00:01, 45.36it/s][A
 60%|██████    | 90/150 [00:01<00:01, 45.06it/s][A
 63%|██████▎   | 95/150 [00:02<00:01, 45.37it/s][A
 67%|██████▋   | 100/150 [00:02<00:0

total count: 150
Epoch 0 | Train Loss: 5.0899 | Train Accuracy: 0.2600


 20%|██        | 1/5 [00:03<00:13,  3.47s/it]

Validation Loss: 5.0295 | Validation Accuracy: 0.2000



  0%|          | 0/150 [00:00<?, ?it/s][A
  5%|▍         | 7/150 [00:00<00:02, 57.95it/s][A
  9%|▊         | 13/150 [00:00<00:02, 50.34it/s][A
 13%|█▎        | 19/150 [00:00<00:02, 48.22it/s][A
 16%|█▌        | 24/150 [00:00<00:02, 47.10it/s][A
 19%|█▉        | 29/150 [00:00<00:02, 46.52it/s][A
 23%|██▎       | 34/150 [00:00<00:02, 45.46it/s][A
 26%|██▌       | 39/150 [00:00<00:02, 45.86it/s][A
 29%|██▉       | 44/150 [00:00<00:02, 45.78it/s][A
 33%|███▎      | 49/150 [00:01<00:02, 45.44it/s][A
 36%|███▌      | 54/150 [00:01<00:02, 45.50it/s][A
 39%|███▉      | 59/150 [00:01<00:02, 45.30it/s][A
 43%|████▎     | 64/150 [00:01<00:01, 45.21it/s][A
 46%|████▌     | 69/150 [00:01<00:01, 45.27it/s][A
 49%|████▉     | 74/150 [00:01<00:01, 45.31it/s][A
 53%|█████▎    | 79/150 [00:01<00:01, 45.22it/s][A
 56%|█████▌    | 84/150 [00:01<00:01, 45.12it/s][A
 59%|█████▉    | 89/150 [00:01<00:01, 45.26it/s][A
 63%|██████▎   | 94/150 [00:02<00:01, 45.07it/s][A
 66%|██████▌   | 99/1

total count: 150
Epoch 1 | Train Loss: 4.8697 | Train Accuracy: 0.2933


 40%|████      | 2/5 [00:06<00:10,  3.50s/it]

Validation Loss: 4.6389 | Validation Accuracy: 0.2000



  0%|          | 0/150 [00:00<?, ?it/s][A
  5%|▍         | 7/150 [00:00<00:02, 60.28it/s][A
  9%|▉         | 14/150 [00:00<00:02, 50.88it/s][A
 13%|█▎        | 20/150 [00:00<00:02, 48.87it/s][A
 17%|█▋        | 25/150 [00:00<00:02, 47.57it/s][A
 20%|██        | 30/150 [00:00<00:02, 46.65it/s][A
 23%|██▎       | 35/150 [00:00<00:02, 46.12it/s][A
 27%|██▋       | 40/150 [00:00<00:02, 45.89it/s][A
 30%|███       | 45/150 [00:00<00:02, 45.76it/s][A
 33%|███▎      | 50/150 [00:01<00:02, 45.58it/s][A
 37%|███▋      | 55/150 [00:01<00:02, 45.45it/s][A
 40%|████      | 60/150 [00:01<00:01, 45.37it/s][A
 43%|████▎     | 65/150 [00:01<00:01, 45.18it/s][A
 47%|████▋     | 70/150 [00:01<00:01, 45.33it/s][A
 50%|█████     | 75/150 [00:01<00:01, 44.84it/s][A
 53%|█████▎    | 80/150 [00:01<00:01, 44.93it/s][A
 57%|█████▋    | 86/150 [00:01<00:01, 46.85it/s][A
 61%|██████    | 91/150 [00:01<00:01, 46.32it/s][A
 64%|██████▍   | 96/150 [00:02<00:01, 45.99it/s][A
 67%|██████▋   | 101/

total count: 150
Epoch 2 | Train Loss: 4.2803 | Train Accuracy: 0.2800


 60%|██████    | 3/5 [00:10<00:06,  3.48s/it]

Validation Loss: 3.8678 | Validation Accuracy: 0.2000



  0%|          | 0/150 [00:00<?, ?it/s][A
  4%|▍         | 6/150 [00:00<00:02, 57.12it/s][A
  8%|▊         | 12/150 [00:00<00:02, 48.99it/s][A
 11%|█▏        | 17/150 [00:00<00:02, 47.95it/s][A
 15%|█▍        | 22/150 [00:00<00:02, 46.98it/s][A
 18%|█▊        | 27/150 [00:00<00:02, 46.32it/s][A
 21%|██▏       | 32/150 [00:00<00:02, 45.78it/s][A
 25%|██▍       | 37/150 [00:00<00:02, 45.73it/s][A
 28%|██▊       | 42/150 [00:00<00:02, 45.55it/s][A
 31%|███▏      | 47/150 [00:01<00:02, 45.38it/s][A
 35%|███▍      | 52/150 [00:01<00:02, 45.23it/s][A
 38%|███▊      | 57/150 [00:01<00:02, 45.23it/s][A
 41%|████▏     | 62/150 [00:01<00:01, 45.24it/s][A
 45%|████▍     | 67/150 [00:01<00:01, 45.36it/s][A
 48%|████▊     | 72/150 [00:01<00:01, 45.29it/s][A
 51%|█████▏    | 77/150 [00:01<00:01, 44.82it/s][A
 55%|█████▌    | 83/150 [00:01<00:01, 46.86it/s][A
 59%|█████▊    | 88/150 [00:01<00:01, 46.02it/s][A
 62%|██████▏   | 93/150 [00:02<00:01, 45.05it/s][A
 65%|██████▌   | 98/1

total count: 150
Epoch 3 | Train Loss: 3.4123 | Train Accuracy: 0.1933


 80%|████████  | 4/5 [00:13<00:03,  3.48s/it]

Validation Loss: 2.9617 | Validation Accuracy: 0.2000



  0%|          | 0/150 [00:00<?, ?it/s][A
  4%|▍         | 6/150 [00:00<00:02, 59.13it/s][A
  8%|▊         | 12/150 [00:00<00:02, 54.42it/s][A
 12%|█▏        | 18/150 [00:00<00:02, 49.99it/s][A
 16%|█▌        | 24/150 [00:00<00:02, 48.23it/s][A
 19%|█▉        | 29/150 [00:00<00:02, 47.44it/s][A
 23%|██▎       | 34/150 [00:00<00:02, 46.57it/s][A
 26%|██▌       | 39/150 [00:00<00:02, 46.00it/s][A
 29%|██▉       | 44/150 [00:00<00:02, 45.78it/s][A
 33%|███▎      | 49/150 [00:01<00:02, 45.62it/s][A
 36%|███▌      | 54/150 [00:01<00:02, 46.02it/s][A
 39%|███▉      | 59/150 [00:01<00:01, 46.56it/s][A
 43%|████▎     | 64/150 [00:01<00:01, 46.06it/s][A
 46%|████▌     | 69/150 [00:01<00:01, 45.66it/s][A
 49%|████▉     | 74/150 [00:01<00:01, 45.31it/s][A
 53%|█████▎    | 79/150 [00:01<00:01, 45.70it/s][A
 56%|█████▌    | 84/150 [00:01<00:01, 46.82it/s][A
 59%|█████▉    | 89/150 [00:01<00:01, 47.70it/s][A
 63%|██████▎   | 94/150 [00:02<00:01, 46.89it/s][A
 66%|██████▌   | 99/1

total count: 150
Epoch 4 | Train Loss: 2.5998 | Train Accuracy: 0.2200


100%|██████████| 5/5 [00:17<00:00,  3.48s/it]

Validation Loss: 2.2961 | Validation Accuracy: 0.2000
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch




labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') t

- 5 APs x 500 times: 83, 118, 120, 121, 122

In [36]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  1%|          | 8/1500 [00:00<00:21, 70.90it/s][A
  1%|          | 16/1500 [00:00<00:27, 54.01it/s][A
  1%|▏         | 22/1500 [00:00<00:29, 50.52it/s][A
  2%|▏         | 28/1500 [00:00<00:30, 48.37it/s][A
  2%|▏         | 33/1500 [00:00<00:30, 47.94it/s][A
  3%|▎         | 38/1500 [00:00<00:31, 47.06it/s][A
  3%|▎         | 43/1500 [00:00<00:31, 46.33it/s][A
  3%|▎         | 49/1500 [00:00<00:30, 47.60it/s][A
  4%|▎         | 54/1500 [00:01<00:31, 46.23it/s][A
  4%|▍         | 59/1500 [00:01<00:31, 46.40it/s][A
  4%|▍         | 64/1500 [00:01<00:31, 46.18it/s][A
  5%|▍         | 69/1500 [00:01<00:31, 45.88it/s][A
  5%|▍         | 74/1500 [00:01<00:31, 45.13it/s][A
  5%|▌         | 79/1500 [00:01<00:31, 44.79it/s][A
  6%|▌         | 84/1500 [00:01<00:30, 45.74it/s][A
  6%|▌         | 89/1500 [00:01<00:31, 45.47it/s][A
  6%|▋         | 94/1500 [00:01<00:30, 45.91it/s][A
  7%|▋         | 9

total count: 1500
Epoch 0 | Train Loss: 2.9499 | Train Accuracy: 0.2360


 20%|██        | 1/5 [00:34<02:19, 34.90s/it]

Validation Loss: 1.6792 | Validation Accuracy: 0.2000



  0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 6/1500 [00:00<00:25, 58.76it/s][A
  1%|          | 12/1500 [00:00<00:28, 51.52it/s][A
  1%|          | 18/1500 [00:00<00:30, 49.00it/s][A
  2%|▏         | 23/1500 [00:00<00:29, 49.34it/s][A
  2%|▏         | 28/1500 [00:00<00:30, 48.95it/s][A
  2%|▏         | 33/1500 [00:00<00:30, 47.93it/s][A
  3%|▎         | 38/1500 [00:00<00:31, 47.10it/s][A
  3%|▎         | 43/1500 [00:00<00:31, 46.58it/s][A
  3%|▎         | 48/1500 [00:01<00:31, 46.06it/s][A
  4%|▎         | 53/1500 [00:01<00:31, 45.94it/s][A
  4%|▍         | 58/1500 [00:01<00:31, 45.64it/s][A
  4%|▍         | 63/1500 [00:01<00:31, 45.50it/s][A
  5%|▍         | 68/1500 [00:01<00:31, 45.26it/s][A
  5%|▍         | 73/1500 [00:01<00:32, 44.56it/s][A
  5%|▌         | 78/1500 [00:01<00:31, 45.30it/s][A
  6%|▌         | 83/1500 [00:01<00:31, 45.54it/s][A
  6%|▌         | 88/1500 [00:01<00:31, 45.32it/s][A
  6%|▌         | 93/1500 [00:01<00:30, 45.99it/s][A
  

total count: 1500
Epoch 1 | Train Loss: 1.6329 | Train Accuracy: 0.2533


 40%|████      | 2/5 [01:09<01:44, 34.86s/it]

Validation Loss: 1.6067 | Validation Accuracy: 0.8000



  0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 6/1500 [00:00<00:26, 56.49it/s][A
  1%|          | 12/1500 [00:00<00:30, 49.54it/s][A
  1%|          | 18/1500 [00:00<00:30, 47.96it/s][A
  2%|▏         | 23/1500 [00:00<00:31, 46.78it/s][A
  2%|▏         | 28/1500 [00:00<00:31, 46.35it/s][A
  2%|▏         | 33/1500 [00:00<00:32, 45.82it/s][A
  3%|▎         | 38/1500 [00:00<00:31, 45.80it/s][A
  3%|▎         | 43/1500 [00:00<00:32, 45.46it/s][A
  3%|▎         | 48/1500 [00:01<00:32, 45.25it/s][A
  4%|▎         | 53/1500 [00:01<00:31, 45.33it/s][A
  4%|▍         | 58/1500 [00:01<00:31, 45.76it/s][A
  4%|▍         | 63/1500 [00:01<00:31, 46.12it/s][A
  5%|▍         | 68/1500 [00:01<00:31, 46.02it/s][A
  5%|▍         | 73/1500 [00:01<00:31, 45.92it/s][A
  5%|▌         | 78/1500 [00:01<00:31, 45.70it/s][A
  6%|▌         | 83/1500 [00:01<00:31, 45.53it/s][A
  6%|▌         | 88/1500 [00:01<00:31, 45.45it/s][A
  6%|▌         | 93/1500 [00:02<00:31, 45.37it/s][A
  

total count: 1500
Epoch 2 | Train Loss: 1.5974 | Train Accuracy: 0.3327


 60%|██████    | 3/5 [01:44<01:09, 34.91s/it]

Validation Loss: 1.5876 | Validation Accuracy: 0.8000



  0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 6/1500 [00:00<00:25, 59.51it/s][A
  1%|          | 12/1500 [00:00<00:28, 51.77it/s][A
  1%|          | 18/1500 [00:00<00:28, 51.16it/s][A
  2%|▏         | 24/1500 [00:00<00:29, 50.34it/s][A
  2%|▏         | 30/1500 [00:00<00:29, 50.50it/s][A
  2%|▏         | 36/1500 [00:00<00:30, 48.65it/s][A
  3%|▎         | 41/1500 [00:00<00:30, 47.73it/s][A
  3%|▎         | 46/1500 [00:00<00:30, 46.94it/s][A
  3%|▎         | 51/1500 [00:01<00:30, 47.73it/s][A
  4%|▎         | 56/1500 [00:01<00:29, 48.37it/s][A
  4%|▍         | 61/1500 [00:01<00:30, 47.31it/s][A
  4%|▍         | 66/1500 [00:01<00:30, 46.68it/s][A
  5%|▍         | 71/1500 [00:01<00:30, 46.19it/s][A
  5%|▌         | 76/1500 [00:01<00:31, 45.88it/s][A
  5%|▌         | 81/1500 [00:01<00:31, 45.39it/s][A
  6%|▌         | 86/1500 [00:01<00:31, 45.46it/s][A
  6%|▌         | 91/1500 [00:01<00:30, 45.63it/s][A
  6%|▋         | 96/1500 [00:02<00:30, 45.41it/s][A
  

total count: 1500
Epoch 3 | Train Loss: 1.5800 | Train Accuracy: 0.4307


 80%|████████  | 4/5 [02:19<00:34, 34.89s/it]

Validation Loss: 1.5683 | Validation Accuracy: 0.8000



  0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 6/1500 [00:00<00:25, 57.78it/s][A
  1%|          | 12/1500 [00:00<00:29, 49.98it/s][A
  1%|          | 18/1500 [00:00<00:30, 48.15it/s][A
  2%|▏         | 23/1500 [00:00<00:31, 47.01it/s][A
  2%|▏         | 28/1500 [00:00<00:31, 46.29it/s][A
  2%|▏         | 33/1500 [00:00<00:31, 45.90it/s][A
  3%|▎         | 38/1500 [00:00<00:31, 45.80it/s][A
  3%|▎         | 43/1500 [00:00<00:32, 45.49it/s][A
  3%|▎         | 48/1500 [00:01<00:31, 45.52it/s][A
  4%|▎         | 53/1500 [00:01<00:31, 45.23it/s][A
  4%|▍         | 58/1500 [00:01<00:31, 45.37it/s][A
  4%|▍         | 63/1500 [00:01<00:32, 44.69it/s][A
  5%|▍         | 68/1500 [00:01<00:31, 45.15it/s][A
  5%|▍         | 74/1500 [00:01<00:30, 46.85it/s][A
  5%|▌         | 79/1500 [00:01<00:30, 46.19it/s][A
  6%|▌         | 84/1500 [00:01<00:30, 45.98it/s][A
  6%|▌         | 89/1500 [00:01<00:30, 45.74it/s][A
  6%|▋         | 94/1500 [00:02<00:31, 45.25it/s][A
  

total count: 1500
Epoch 4 | Train Loss: 1.5565 | Train Accuracy: 0.5833


100%|██████████| 5/5 [02:54<00:00, 34.92s/it]

Validation Loss: 1.5401 | Validation Accuracy: 0.8000
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch




labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') t

labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') t

labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') t

labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') t

labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') t

labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') 

labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') t

- 5 APs x 500 times: 128, 118, 120, 121, 122

In [17]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  1%|          | 13/1500 [00:00<00:11, 125.60it/s][A
  2%|▏         | 27/1500 [00:00<00:11, 129.20it/s][A
  3%|▎         | 40/1500 [00:00<00:11, 129.43it/s][A
  4%|▎         | 53/1500 [00:00<00:11, 129.00it/s][A
  4%|▍         | 66/1500 [00:00<00:11, 129.28it/s][A
  5%|▌         | 80/1500 [00:00<00:10, 129.59it/s][A
  6%|▌         | 93/1500 [00:00<00:10, 129.54it/s][A
  7%|▋         | 106/1500 [00:00<00:10, 129.57it/s][A
  8%|▊         | 119/1500 [00:00<00:10, 129.50it/s][A
  9%|▉         | 132/1500 [00:01<00:10, 129.17it/s][A
 10%|▉         | 145/1500 [00:01<00:10, 128.69it/s][A
 11%|█         | 158/1500 [00:01<00:10, 128.75it/s][A
 11%|█▏        | 171/1500 [00:01<00:10, 128.80it/s][A
 12%|█▏        | 184/1500 [00:01<00:10, 128.72it/s][A
 13%|█▎        | 197/1500 [00:01<00:10, 128.90it/s][A
 14%|█▍        | 210/1500 [00:01<00:10, 128.50it/s][A
 15%|█▍        | 223/1500 [00:01<00:09, 128.8

total count: 1500
Epoch 0 | Train Loss: 2.8167 | Train Accuracy: 0.2267


 20%|██        | 1/5 [00:13<00:54, 13.70s/it]

Validation Loss: 1.6525 | Validation Accuracy: 0.2000



  0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 6/1500 [00:00<00:28, 52.36it/s][A
  1%|          | 12/1500 [00:00<00:29, 50.99it/s][A
  1%|          | 18/1500 [00:00<00:30, 48.57it/s][A
  2%|▏         | 23/1500 [00:00<00:31, 46.71it/s][A
  2%|▏         | 28/1500 [00:00<00:31, 46.53it/s][A
  2%|▏         | 34/1500 [00:00<00:30, 48.13it/s][A
  3%|▎         | 39/1500 [00:00<00:31, 47.08it/s][A
  3%|▎         | 44/1500 [00:00<00:31, 46.31it/s][A
  3%|▎         | 49/1500 [00:01<00:31, 46.71it/s][A
  4%|▎         | 54/1500 [00:01<00:30, 46.68it/s][A
  4%|▍         | 59/1500 [00:01<00:31, 46.29it/s][A
  4%|▍         | 64/1500 [00:01<00:31, 46.11it/s][A
  5%|▍         | 69/1500 [00:01<00:31, 45.72it/s][A
  5%|▍         | 74/1500 [00:01<00:31, 45.52it/s][A
  5%|▌         | 79/1500 [00:01<00:31, 45.48it/s][A
  6%|▌         | 84/1500 [00:01<00:31, 45.44it/s][A
  6%|▌         | 89/1500 [00:01<00:30, 46.20it/s][A
  6%|▋         | 95/1500 [00:02<00:29, 47.68it/s][A
  

total count: 1500
Epoch 1 | Train Loss: 1.6262 | Train Accuracy: 0.2533


 40%|████      | 2/5 [00:48<01:17, 25.99s/it]

Validation Loss: 1.6104 | Validation Accuracy: 0.2000



  0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 6/1500 [00:00<00:25, 59.56it/s][A
  1%|          | 12/1500 [00:00<00:29, 50.41it/s][A
  1%|          | 18/1500 [00:00<00:30, 48.08it/s][A
  2%|▏         | 23/1500 [00:00<00:31, 47.05it/s][A
  2%|▏         | 28/1500 [00:00<00:31, 46.50it/s][A
  2%|▏         | 33/1500 [00:00<00:32, 45.74it/s][A
  3%|▎         | 38/1500 [00:00<00:31, 45.75it/s][A
  3%|▎         | 43/1500 [00:00<00:31, 46.96it/s][A
  3%|▎         | 48/1500 [00:01<00:30, 47.78it/s][A
  4%|▎         | 53/1500 [00:01<00:30, 48.17it/s][A
  4%|▍         | 58/1500 [00:01<00:29, 48.60it/s][A
  4%|▍         | 63/1500 [00:01<00:30, 47.61it/s][A
  5%|▍         | 68/1500 [00:01<00:30, 46.87it/s][A
  5%|▍         | 73/1500 [00:01<00:29, 47.73it/s][A
  5%|▌         | 79/1500 [00:01<00:29, 48.35it/s][A
  6%|▌         | 84/1500 [00:01<00:29, 47.22it/s][A
  6%|▌         | 89/1500 [00:01<00:30, 46.61it/s][A
  6%|▋         | 94/1500 [00:01<00:30, 46.26it/s][A
  

total count: 1500
Epoch 2 | Train Loss: 1.6033 | Train Accuracy: 0.3813


 60%|██████    | 3/5 [01:23<00:59, 29.99s/it]

Validation Loss: 1.5957 | Validation Accuracy: 1.0000



  0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 6/1500 [00:00<00:26, 55.62it/s][A
  1%|          | 12/1500 [00:00<00:30, 48.08it/s][A
  1%|          | 17/1500 [00:00<00:31, 47.45it/s][A
  2%|▏         | 23/1500 [00:00<00:30, 48.76it/s][A
  2%|▏         | 28/1500 [00:00<00:31, 47.14it/s][A
  2%|▏         | 33/1500 [00:00<00:31, 46.73it/s][A
  3%|▎         | 38/1500 [00:00<00:31, 46.33it/s][A
  3%|▎         | 43/1500 [00:00<00:31, 45.97it/s][A
  3%|▎         | 48/1500 [00:01<00:31, 45.60it/s][A
  4%|▎         | 53/1500 [00:01<00:31, 45.46it/s][A
  4%|▍         | 58/1500 [00:01<00:31, 45.52it/s][A
  4%|▍         | 63/1500 [00:01<00:31, 45.32it/s][A
  5%|▍         | 68/1500 [00:01<00:31, 45.21it/s][A
  5%|▍         | 73/1500 [00:01<00:31, 44.86it/s][A
  5%|▌         | 78/1500 [00:01<00:30, 46.06it/s][A
  6%|▌         | 83/1500 [00:01<00:30, 46.51it/s][A
  6%|▌         | 88/1500 [00:01<00:30, 46.02it/s][A
  6%|▌         | 93/1500 [00:02<00:30, 45.76it/s][A
  

total count: 1500
Epoch 3 | Train Loss: 1.5874 | Train Accuracy: 0.5320


 80%|████████  | 4/5 [01:57<00:31, 31.82s/it]

Validation Loss: 1.5751 | Validation Accuracy: 1.0000



  0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 6/1500 [00:00<00:25, 58.90it/s][A
  1%|          | 12/1500 [00:00<00:27, 54.27it/s][A
  1%|          | 18/1500 [00:00<00:29, 50.10it/s][A
  2%|▏         | 24/1500 [00:00<00:30, 48.20it/s][A
  2%|▏         | 29/1500 [00:00<00:31, 47.19it/s][A
  2%|▏         | 34/1500 [00:00<00:30, 48.03it/s][A
  3%|▎         | 40/1500 [00:00<00:30, 48.67it/s][A
  3%|▎         | 45/1500 [00:00<00:30, 47.57it/s][A
  3%|▎         | 50/1500 [00:01<00:30, 46.84it/s][A
  4%|▎         | 55/1500 [00:01<00:31, 46.23it/s][A
  4%|▍         | 60/1500 [00:01<00:31, 45.87it/s][A
  4%|▍         | 66/1500 [00:01<00:30, 47.50it/s][A
  5%|▍         | 71/1500 [00:01<00:30, 47.44it/s][A
  5%|▌         | 77/1500 [00:01<00:29, 48.58it/s][A
  5%|▌         | 82/1500 [00:01<00:29, 47.56it/s][A
  6%|▌         | 87/1500 [00:01<00:30, 46.69it/s][A
  6%|▌         | 92/1500 [00:01<00:30, 46.24it/s][A
  7%|▋         | 98/1500 [00:02<00:29, 47.56it/s][A
  

total count: 1500
Epoch 4 | Train Loss: 1.5602 | Train Accuracy: 0.7287


100%|██████████| 5/5 [02:32<00:00, 30.48s/it]

Validation Loss: 1.5405 | Validation Accuracy: 1.0000
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') tor




labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0')

labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0')

labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0')

labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0')

labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0')

labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0')

- 165 APs + benign x 50 times

In [8]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 12


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
#     for data in tqdm(dataloaders['train']):
    for data in dataloaders['train']:
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")


# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0) # labels.size(0) is the batch size
        
        correct += (predicted == labels).sum().item()
        # (predicted == labels).sum() -> # of the matched prediction
        # .item() -> turn the tensor to the regular number
        
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/12 [00:00<?, ?it/s]

total count: 24000
Epoch 6 | Train Loss: 4.2231 | Train Accuracy: 0.2003


 58%|█████▊    | 7/12 [1:00:18<43:03, 516.79s/it]

Validation Loss: 4.1661 | Validation Accuracy: 0.2126
total count: 24000
Epoch 7 | Train Loss: 4.1462 | Train Accuracy: 0.2112


 67%|██████▋   | 8/12 [1:08:55<34:27, 516.89s/it]

Validation Loss: 4.0871 | Validation Accuracy: 0.2223
total count: 24000
Epoch 8 | Train Loss: 4.0734 | Train Accuracy: 0.2235


 75%|███████▌  | 9/12 [1:17:31<25:50, 516.81s/it]

Validation Loss: 4.0145 | Validation Accuracy: 0.2370
total count: 24000
Epoch 9 | Train Loss: 4.0052 | Train Accuracy: 0.2324


 83%|████████▎ | 10/12 [1:26:08<17:13, 516.85s/it]

Validation Loss: 3.9474 | Validation Accuracy: 0.2462
total count: 24000
Epoch 10 | Train Loss: 3.9453 | Train Accuracy: 0.2429


 92%|█████████▏| 11/12 [1:34:45<08:36, 516.84s/it]

Validation Loss: 3.8858 | Validation Accuracy: 0.2629
total count: 24000
Epoch 11 | Train Loss: 3.8886 | Train Accuracy: 0.2525


100%|██████████| 12/12 [1:43:22<00:00, 516.85s/it]

Validation Loss: 3.8293 | Validation Accuracy: 0.2681
labels: tensor([88], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([67], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([100], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([139], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([29], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([130], device='cuda:0') torch.Size([1])
predicte




labels: tensor([52], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([125], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([64], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([9], device='cuda:0') torch.Size([1])
predicted: tensor([9], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([135], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([10], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tens

labels: tensor([78], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([53], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([107], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([11], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([28], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([87], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([138], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([36], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([110], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([9], device='cuda:0') torch.Size([1])
predicted: tensor([9], device='cuda:0') torch.Size([1])
labels: tensor([97], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([70], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([82], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([98], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([135], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([55], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: t

labels: tensor([119], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([60], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([40], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([80], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([108], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([24], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([129], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels

labels: tensor([151], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([109], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([57], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([55], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([78], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([113], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([131], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([147], device='cuda:0') torch.Size([1])
predicted: tensor([147], device='cuda:0') torch.Size([1])
labels: tensor([86], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([12], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([71], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([107], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([150], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([71], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
label

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([151], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([3], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([149], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([3], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([5], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([46], device='cuda:0') torch.Size([1])
predicted: tensor([46], device='cuda:0') torch.Size([1])
labels: tensor([143], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([51], device='cuda:0') torch.Size([1])
predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: t

labels: tensor([39], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([30], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([157], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([135], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([136], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([42], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([37], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([8], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([130], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([166], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([2], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([29], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([21], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tenso

predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([117], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([74], device='cuda:0') torch.Size([1])
predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([30], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([43], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([136], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([131], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([40], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([45], device='cuda:0') torch.Size([1])
predicted: tensor([45], device='cuda:0') torch.Size([1])
labels: tensor([116], device='cuda:0') torch.Size([1])
pred

labels: tensor([59], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([129], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([136], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([88], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([78], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([24], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([77], device='cuda:0') torch.Size([1])
predicted: tensor([77], device='cuda:0') torch.Size([1])
labels: tensor([97], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([164], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([50], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([139], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([38], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([51], device='cuda:0') torch.Size([1])
predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([114], device='cuda:0') torch.Size([1])
predicted: tensor([18], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([161], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([119], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([112], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([116], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([145], device='cuda:0') torch.Size([1])
predicted: tensor([145], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([74], device='cuda:0') torch.Size([1])
predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
label

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([88], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([35], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([126], device='cuda:0') torch.Size([1])
predicted: tensor([126], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([84], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([6], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([11], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: t

labels: tensor([159], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([141], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([60], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([82], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([27], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([18], device='cuda:0') torch.Size([1])
predicted: tensor([18], device='cuda:0') torch.Size([1])
labels: tensor([48], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([90], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([159], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
label

labels: tensor([144], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([80], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([67], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([108], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([65], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([11], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([48], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([116], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([45], device='cuda:0') torch.Size([1])
predicted: tensor([45], device='cuda:0') torch.Size([1])
labels: tensor([4], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([81], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([58], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([96], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: ten

predicted: tensor([77], device='cuda:0') torch.Size([1])
labels: tensor([81], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([110], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([149], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([35], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([144], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([153], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([8], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([154], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([116], device='cuda:0') torch.Size([1])
pred

labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([107], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([157], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([16], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([142], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([146], device='cuda:0') torch.Size([1])
predicted: tensor([146], device='cuda:0') torch.Size([1])
labels: tensor([21], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
lab

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([114], device='cuda:0') torch.Size([1])
predicted: tensor([18], device='cuda:0') torch.Size([1])
labels: tensor([61], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([61], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([52], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([54], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([124], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([153], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([108], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([35], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([85], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([158], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([53], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([113], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([69], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([66], device='cuda:0') torch.Size([1])
predic

predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([115], device='cuda:0') torch.Size([1])
predicted: tensor([115], device='cuda:0') torch.Size([1])
labels: tensor([80], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([40], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([102], device='cuda:0') torch.Size([1])
predicted:

labels: tensor([69], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([109], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([6], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([22], device='cuda:0') torch.Size([1])
predicted: tensor([22], device='cuda:0') torch.Size([1])
labels: tensor([98], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([59], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([1], device='cuda:0') torch.Size([1])
predicted: tensor([1], device='cuda:0') torch.Size([1])
labels: tens

labels: tensor([147], device='cuda:0') torch.Size([1])
predicted: tensor([147], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([158], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([154], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([59], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([13], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([101], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels:

labels: tensor([111], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([127], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([91], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([19], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([90], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([62], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([159], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([133], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels:

predicted: tensor([73], device='cuda:0') torch.Size([1])
labels: tensor([139], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([24], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([50], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([84], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([163], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([34], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([87], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([80], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicte

labels: tensor([95], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([99], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([160], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([165], device='cuda:0') torch.Size([1])
predicted: tensor([165], device='cuda:0') torch.Size([1])
labels: tensor([93], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([165], device='cuda:0') torch.Size([1])
predicted: tensor([165], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labe

labels: tensor([85], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([143], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([140], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([95], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([89], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([126], device='cuda:0') torch.Size([1])
predicted: tensor([126], device='cuda:0') torch.Size([1])
labels: tensor([146], device='cuda:0') torch.Size([1])
predicted: tensor([146], device='cuda:0') torch.Size([1])
labels: tensor([145], device='cuda:0') torch.Size([1])
predicted: tensor([145], device='cuda:0') torch.Size([1])


labels: tensor([131], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([38], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([51], device='cuda:0') torch.Size([1])
predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([62], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: ten

labels: tensor([164], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([4], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([165], device='cuda:0') torch.Size([1])
predicted: tensor([165], device='cuda:0') torch.Size([1])
labels: tensor([135], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([37], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([58], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([75], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([30], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labe

predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([135], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([34], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([110], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([72], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([46], device='cuda:0') torch.Size([1])
predicted: tensor([46], device='cuda:0') torch.Size([1])
labels: tensor([58], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([15], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predic

labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([150], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([1], device='cuda:0') torch.Size([1])
predicted: tensor([1], device='cuda:0') torch.Size([1])
labels: tensor([143], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([161], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([68], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([93], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([123], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([110], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([31], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([165], device='cuda:0') torch.Size([1])
predicted: tensor([165], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([98], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: t

predicted: tensor([18], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([89], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([55], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([20], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([159], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([159], device='cuda:0') torch.Size([1])
predict

labels: tensor([60], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([167], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([18], device='cuda:0') torch.Size([1])
predicted: tensor([18], device='cuda:0') torch.Size([1])
labels: tensor([77], device='cuda:0') torch.Size([1])
predicted: tensor([77], device='cuda:0') torch.Size([1])
labels: tensor([167], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([40], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([4], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([114], device='cuda:0') torch.Size([1])
predicted: tensor([18], device='cuda:0') torch.Size([1])
label

labels: tensor([49], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([78], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([146], device='cuda:0') torch.Size([1])
predicted: tensor([146], device='cuda:0') torch.Size([1])
labels: tensor([87], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([130], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([126], device='cuda:0') torch.Size([1])
predicted: tensor([126], device='cuda:0') torch.Size([1])
labels: tensor([17], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([12], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
la

labels: tensor([64], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([110], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([161], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([116], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([1], device='cuda:0') torch.Size([1])
predicted: tensor([1], device='cuda:0') torch.Size([1])
labels: tensor([123], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([166], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([17], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([1], device='cuda:0') torch.Size([1])
predicted: tensor([1], device='cuda:0') torch.Size([1])
labels: tensor([74], device='cuda:0') torch.Size([1])
predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([103], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([97], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([101], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([125], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([164], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([125], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([3], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([33], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([78], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([99], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([127], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([52], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: t

labels: tensor([16], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([147], device='cuda:0') torch.Size([1])
predicted: tensor([147], device='cuda:0') torch.Size([1])
labels: tensor([79], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([59], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([162], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([140], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([98], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([24], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([67], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([126], device='cuda:0') torch.Size([1])
predicted: tensor([126], device='cuda:0') torch.Size([1])
labels: tensor([55], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([82], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([38], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([144], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels:

labels: tensor([135], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([15], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([69], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([101], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([164], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([48], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([87], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([24], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([124], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([1], device='cuda:0') torch.Size([1])
predicted: tensor([1], device='cuda:0') torch.Size([1])
labels: tensor([97], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([146], device='cuda:0') torch.Size([1])
predicted: tensor([146], device='cuda:0') torch.Size([1])
labels: tensor([164], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([113], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([60], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([107], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labe

labels: tensor([10], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([154], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([114], device='cuda:0') torch.Size([1])
predicted: tensor([18], device='cuda:0') torch.Size([1])
labels: tensor([132], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([50], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([151], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([113], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([105], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([1], device='cuda:0') torch.Size([1])
predicted: tensor([1], device='cuda:0') torch.Size([1])
labels: tensor([72], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([16], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([89], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([108], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([129], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([116], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([72], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([103], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([106], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([44], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([127], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([38], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels

labels: tensor([123], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([129], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([41], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([138], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([11], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([60], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([163], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
label

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([140], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([51], device='cuda:0') torch.Size([1])
predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([66], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([96], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([13], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([85], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([47], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([85], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([1], device='cuda:0') torch.Size([1])
predicted: tensor([1], device='cuda:0') torch.Size([1])
labels: tensor([33], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([57], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([63], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tenso

labels: tensor([93], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([108], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([24], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([38], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([159], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([144], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: t

labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([34], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([78], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([73], device='cuda:0') torch.Size([1])
predicted: tensor([73], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([138], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([66], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([149], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([164], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels:

labels: tensor([16], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([86], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([41], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([123], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([117], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([117], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([130], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([53], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([137], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([90], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([50], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([58], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([95], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([64], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([103], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([14], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([79], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([153], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([13], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([5], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([19], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([27], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([101], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([67], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([38], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([44], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([21], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([89], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([90], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([81], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([117], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([43], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: t

labels: tensor([8], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([71], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([57], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([104], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([57], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([154], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([72], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([158], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: t

labels: tensor([163], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([131], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([155], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([102], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([36], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([153], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([115], device='cuda:0') torch.Size([1])
predicted: tensor([115], device='cuda:0') torch.Size([1])
labe

labels: tensor([6], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([39], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([42], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([143], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([140], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([34], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([33], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([15], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([148], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([63], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([115], device='cuda:0') torch.Size([1])
predicted: tensor([115], device='cuda:0') torch.Size([1])
labels: tensor([9], device='cuda:0') torch.Size([1])
predicted: tensor([9], device='cuda:0') torch.Size([1])
labels: tensor([31], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([7], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([116], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([87], device='cuda:0') torch.Size([1])
predicted:

labels: tensor([119], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([99], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([19], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([129], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([47], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([98], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([60], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([139], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([45], device='cuda:0') torch.Size([1])
predicted: tensor([45], device='cuda:0') torch.Size([1])
labels: tensor([72], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([65], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([94], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([50], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([11], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([25], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([8], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: t

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([90], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([58], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([141], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([144], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: ten

labels: tensor([99], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([113], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([107], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([69], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([103], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: ten

labels: tensor([154], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([64], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([26], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([154], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([143], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([41], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([66], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([127], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels:

predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([167], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([35], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([53], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([32], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([73], device='cuda:0') torch.Size([1])
predicted: tensor([73], device='cuda:0') torch.Size([1])
labels: tensor([21], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([102], device='cuda:0') torch.Size([1])
predicted

predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([157], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([89], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([63], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([50], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([35], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([141], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicte

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([143], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([56], device='cuda:0') torch.Size([1])
predicted: tensor([56], device='cuda:0') torch.Size([1])
labels: tensor([81], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([37], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([33], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([12], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([134], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([137], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([70], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([54], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([27], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([94], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([12], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([24], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: ten

labels: tensor([17], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([91], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([88], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([147], device='cuda:0') torch.Size([1])
predicted: tensor([147], device='cuda:0') torch.Size([1])
labels: tensor([137], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([20], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([167], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels:

labels: tensor([117], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([10], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([132], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([104], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([67], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([1], device='cuda:0') torch.Size([1])
predicted: tensor([1], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([115], device='cuda:0') torch.Size([1])
predicted: tensor([115], device='cuda:0') torch.Size([1])
labels: tensor([40], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels

labels: tensor([150], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([86], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([60], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([31], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([135], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([29], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels:

labels: tensor([4], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([20], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([132], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([1], device='cuda:0') torch.Size([1])
predicted: tensor([1], device='cuda:0') torch.Size([1])
labels: tensor([123], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([138], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([33], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: ten

labels: tensor([127], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([62], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([82], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([161], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([165], device='cuda:0') torch.Size([1])
predicted: tensor([165], device='cuda:0') torch.Size([1])
labels: tensor([79], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([161], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([81], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([153], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
lab

labels: tensor([32], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([41], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([55], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([151], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([10], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([7], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([6], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([15], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: ten

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([148], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([116], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([29], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([28], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([24], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([157], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([24], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([161], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([147], device='cuda:0') torch.Size([1])
predicted: tensor([147], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([69], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([155], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([41], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([10], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([27], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels

labels: tensor([103], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([149], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([138], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([127], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([38], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([70], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([127], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([164], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([72], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labe

labels: tensor([139], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([45], device='cuda:0') torch.Size([1])
predicted: tensor([45], device='cuda:0') torch.Size([1])
labels: tensor([91], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([155], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([63], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([11], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([109], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([42], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([144], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([119], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([144], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([130], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels:

labels: tensor([88], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([51], device='cuda:0') torch.Size([1])
predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([43], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([117], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([64], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([56], device='cuda:0') torch.Size([1])
predicted: tensor([56], device='cuda:0') torch.Size([1])
labels: tensor([138], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labe

labels: tensor([98], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([111], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([43], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([79], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([147], device='cuda:0') torch.Size([1])
predicted: tensor([147], device='cuda:0') torch.Size([1])
labels: tensor([111], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([18], device='cuda:0') torch.Size([1])
predicted: tensor([18], device='cuda:0') torch.Size([1])
labels

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([134], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([158], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([86], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([62], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([89], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([159], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([132], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([153], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([115], device='cuda:0') torch.Size([1])
predicted: tensor([115], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([9], device='cuda:0') torch.Size([1])
predicted: tensor([9], device='cuda:0') torch.Size([1])
labels: tensor([6], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([101], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([46], device='cuda:0') torch.Size([1])
predicted: tensor([46], device='cuda:0') torch.Size([1])
labels: t

labels: tensor([67], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([33], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([87], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([153], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([124], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([111], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([165], device='cuda:0') torch.Size([1])
predicted: tensor([165], device='cuda:0') torch.Size([1])
labels: tensor([119], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labe

labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([135], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([80], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([152], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([14], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([73], device='cuda:0') torch.Size([1])
predicted: tensor([73], device='cuda:0') torch.Size([1])
labels: tensor([14], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: t

labels: tensor([94], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([28], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([148], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([167], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([163], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([2], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([39], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([3], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([87], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([114], device='cuda:0') torch.Size([1])
predicted: tensor([18], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([49], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([60], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tens

predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([135], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([131], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([146], device='cuda:0') torch.Size([1])
predicted: tensor([146], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([51], device='cuda:0') torch.Size([1])
predicted: tensor([74], device='cuda:0') torch.Size([1])
labels: tensor([40], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([109], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([14], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
pred

labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([49], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([124], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([112], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([127], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([10], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([69], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels

labels: tensor([35], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([141], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([4], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([38], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([124], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([72], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: te

labels: tensor([136], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([73], device='cuda:0') torch.Size([1])
predicted: tensor([73], device='cuda:0') torch.Size([1])
labels: tensor([157], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([148], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([101], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([129], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([140], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
lab

labels: tensor([25], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([72], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([108], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([116], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([45], device='cuda:0') torch.Size([1])
predicted: tensor([45], device='cuda:0') torch.Size([1])
labels: tensor([22], device='cuda:0') torch.Size([1])
predicted: tensor([22], device='cuda:0') torch.Size([1])
labels: tensor([145], device='cuda:0') torch.Size([1])
predicted: tensor([145], device='cuda:0') torch.Size([1])
labels: tensor([5], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
lab

predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([70], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([131], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([49], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([39], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([38], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([77], device='cuda:0') torch.Size([1])
predicted: tensor([77], device='cuda:0') torch.Size([1])
labels: tensor([47], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([20], device='cuda:0') torch.Size([1])
predicte

labels: tensor([152], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([148], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([91], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([134], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([25], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([141], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([130], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
label

labels: tensor([52], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([134], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([37], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([49], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([114], device='cuda:0') torch.Size([1])
predicted: tensor([18], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([53], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([102], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([39], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels

predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([2], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([133], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([59], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([6], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([112], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([65], device='cuda:0') torch.Size([1])
predicted: te

labels: tensor([52], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([106], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([103], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([11], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([110], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([113], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([90], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: 

labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([94], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([6], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([35], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([54], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([3], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([91], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([166], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([146], device='cuda:0') torch.Size([1])
predicted: tensor([146], device='cuda:0') torch.Size([1])
labels:

labels: tensor([107], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([113], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([3], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([108], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([152], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([47], device='cuda:0') torch.Size([1])
predicted: tensor([37], device='cuda:0') torch.Size([1])
labels: tensor([50], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([167], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([154], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labe

labels: tensor([93], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([64], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([6], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([81], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([36], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([38], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([23], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([19], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tens

labels: tensor([110], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([160], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([135], device='cuda:0') torch.Size([1])
predicted: tensor([4], device='cuda:0') torch.Size([1])
labels: tensor([44], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([52], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([162], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([126], device='cuda:0') torch.Size([1])
predicted: tensor([126], device='cuda:0') torch.Size([1])
lab

predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([84], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([19], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([80], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([119], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([36], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([14], device='cuda:0') torch.Size([1])
predicted: t

predicted: tensor([165], device='cuda:0') torch.Size([1])
labels: tensor([8], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([155], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([125], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([156], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([82], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([111], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([129], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([0], device='cuda:0') torch.Size([1])
predicted: tensor([0], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predi

- 165 APs x 5 times

In [41]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/495 [00:00<?, ?it/s][A
  1%|          | 6/495 [00:00<00:08, 54.97it/s][A
  2%|▏         | 12/495 [00:00<00:09, 49.18it/s][A
  3%|▎         | 17/495 [00:00<00:10, 47.19it/s][A
  4%|▍         | 22/495 [00:00<00:10, 46.62it/s][A
  5%|▌         | 27/495 [00:00<00:10, 46.05it/s][A
  6%|▋         | 32/495 [00:00<00:10, 45.72it/s][A
  7%|▋         | 37/495 [00:00<00:10, 44.37it/s][A
  8%|▊         | 42/495 [00:00<00:10, 44.27it/s][A
  9%|▉         | 47/495 [00:01<00:10, 44.59it/s][A
 11%|█         | 52/495 [00:01<00:09, 45.04it/s][A
 12%|█▏        | 57/495 [00:01<00:09, 44.97it/s][A
 13%|█▎        | 62/495 [00:01<00:09, 45.10it/s][A
 14%|█▎        | 67/495 [00:01<00:09, 45.13it/s][A
 15%|█▍        | 72/495 [00:01<00:09, 45.17it/s][A
 16%|█▌        | 77/495 [00:01<00:09, 45.13it/s][A
 17%|█▋        | 82/495 [00:01<00:09, 45.21it/s][A
 18%|█▊        | 87/495 [00:01<00:09, 45.15it/s][A
 19%|█▊        | 92/495 [00:02<00:08

total count: 495
Epoch 0 | Train Loss: 5.1254 | Train Accuracy: 0.0061


 20%|██        | 1/5 [00:11<00:46, 11.53s/it]

Validation Loss: 5.1192 | Validation Accuracy: 0.0788



  0%|          | 0/495 [00:00<?, ?it/s][A
  1%|          | 6/495 [00:00<00:08, 56.97it/s][A
  2%|▏         | 12/495 [00:00<00:09, 50.97it/s][A
  4%|▎         | 18/495 [00:00<00:09, 48.75it/s][A
  5%|▍         | 23/495 [00:00<00:09, 47.27it/s][A
  6%|▌         | 28/495 [00:00<00:10, 46.63it/s][A
  7%|▋         | 33/495 [00:00<00:10, 46.15it/s][A
  8%|▊         | 38/495 [00:00<00:09, 45.96it/s][A
  9%|▊         | 43/495 [00:00<00:09, 45.74it/s][A
 10%|▉         | 48/495 [00:01<00:09, 45.58it/s][A
 11%|█         | 53/495 [00:01<00:09, 45.32it/s][A
 12%|█▏        | 58/495 [00:01<00:09, 45.32it/s][A
 13%|█▎        | 63/495 [00:01<00:09, 45.27it/s][A
 14%|█▎        | 68/495 [00:01<00:09, 45.23it/s][A
 15%|█▍        | 73/495 [00:01<00:09, 45.29it/s][A
 16%|█▌        | 78/495 [00:01<00:09, 45.29it/s][A
 17%|█▋        | 83/495 [00:01<00:09, 45.13it/s][A
 18%|█▊        | 88/495 [00:01<00:09, 45.14it/s][A
 19%|█▉        | 93/495 [00:02<00:08, 45.21it/s][A
 20%|█▉        | 98/4

total count: 495
Epoch 1 | Train Loss: 5.1182 | Train Accuracy: 0.0545


 40%|████      | 2/5 [00:22<00:34, 11.49s/it]

Validation Loss: 5.1099 | Validation Accuracy: 0.0545



  0%|          | 0/495 [00:00<?, ?it/s][A
  1%|          | 6/495 [00:00<00:08, 59.64it/s][A
  2%|▏         | 12/495 [00:00<00:09, 50.78it/s][A
  4%|▎         | 18/495 [00:00<00:09, 48.35it/s][A
  5%|▍         | 23/495 [00:00<00:10, 47.06it/s][A
  6%|▌         | 28/495 [00:00<00:10, 46.55it/s][A
  7%|▋         | 33/495 [00:00<00:10, 46.09it/s][A
  8%|▊         | 38/495 [00:00<00:09, 47.18it/s][A
  9%|▉         | 44/495 [00:00<00:09, 48.19it/s][A
 10%|▉         | 49/495 [00:01<00:09, 47.23it/s][A
 11%|█         | 54/495 [00:01<00:09, 46.48it/s][A
 12%|█▏        | 59/495 [00:01<00:09, 46.07it/s][A
 13%|█▎        | 64/495 [00:01<00:09, 47.12it/s][A
 14%|█▍        | 69/495 [00:01<00:08, 47.73it/s][A
 15%|█▌        | 75/495 [00:01<00:08, 48.50it/s][A
 16%|█▌        | 80/495 [00:01<00:08, 47.59it/s][A
 17%|█▋        | 85/495 [00:01<00:08, 46.86it/s][A
 18%|█▊        | 90/495 [00:01<00:08, 46.32it/s][A
 19%|█▉        | 95/495 [00:02<00:08, 45.82it/s][A
 20%|██        | 100/

total count: 495
Epoch 2 | Train Loss: 5.1093 | Train Accuracy: 0.0424


 60%|██████    | 3/5 [00:34<00:23, 11.55s/it]

Validation Loss: 5.0988 | Validation Accuracy: 0.0485



  0%|          | 0/495 [00:00<?, ?it/s][A
  1%|          | 6/495 [00:00<00:08, 58.93it/s][A
  2%|▏         | 12/495 [00:00<00:09, 52.69it/s][A
  4%|▎         | 18/495 [00:00<00:10, 47.60it/s][A
  5%|▍         | 23/495 [00:00<00:10, 46.37it/s][A
  6%|▌         | 28/495 [00:00<00:10, 45.84it/s][A
  7%|▋         | 33/495 [00:00<00:10, 45.89it/s][A
  8%|▊         | 38/495 [00:00<00:10, 45.54it/s][A
  9%|▊         | 43/495 [00:00<00:09, 45.41it/s][A
 10%|▉         | 48/495 [00:01<00:09, 45.46it/s][A
 11%|█         | 53/495 [00:01<00:09, 45.34it/s][A
 12%|█▏        | 58/495 [00:01<00:09, 44.99it/s][A
 13%|█▎        | 63/495 [00:01<00:09, 45.43it/s][A
 14%|█▎        | 68/495 [00:01<00:09, 44.88it/s][A
 15%|█▍        | 73/495 [00:01<00:09, 45.11it/s][A
 16%|█▌        | 78/495 [00:01<00:09, 45.29it/s][A
 17%|█▋        | 83/495 [00:01<00:09, 45.03it/s][A
 18%|█▊        | 88/495 [00:01<00:09, 45.01it/s][A
 19%|█▉        | 93/495 [00:02<00:08, 44.78it/s][A
 20%|█▉        | 98/4

total count: 495
Epoch 3 | Train Loss: 5.0985 | Train Accuracy: 0.0303


 80%|████████  | 4/5 [00:46<00:11, 11.61s/it]

Validation Loss: 5.0868 | Validation Accuracy: 0.0424



  0%|          | 0/495 [00:00<?, ?it/s][A
  1%|          | 6/495 [00:00<00:08, 54.52it/s][A
  2%|▏         | 12/495 [00:00<00:09, 51.50it/s][A
  4%|▎         | 18/495 [00:00<00:09, 48.33it/s][A
  5%|▍         | 23/495 [00:00<00:10, 46.86it/s][A
  6%|▌         | 28/495 [00:00<00:10, 46.70it/s][A
  7%|▋         | 33/495 [00:00<00:10, 46.08it/s][A
  8%|▊         | 38/495 [00:00<00:10, 45.56it/s][A
  9%|▊         | 43/495 [00:00<00:09, 46.63it/s][A
 10%|▉         | 49/495 [00:01<00:09, 48.21it/s][A
 11%|█         | 54/495 [00:01<00:09, 47.01it/s][A
 12%|█▏        | 59/495 [00:01<00:09, 46.16it/s][A
 13%|█▎        | 64/495 [00:01<00:09, 46.17it/s][A
 14%|█▍        | 69/495 [00:01<00:09, 45.89it/s][A
 15%|█▍        | 74/495 [00:01<00:09, 45.44it/s][A
 16%|█▌        | 79/495 [00:01<00:09, 45.23it/s][A
 17%|█▋        | 84/495 [00:01<00:09, 45.41it/s][A
 18%|█▊        | 89/495 [00:01<00:08, 45.59it/s][A
 19%|█▉        | 94/495 [00:02<00:08, 45.48it/s][A
 20%|██        | 99/4

total count: 495
Epoch 4 | Train Loss: 5.0870 | Train Accuracy: 0.0303


100%|██████████| 5/5 [00:57<00:00, 11.59s/it]

Validation Loss: 5.0750 | Validation Accuracy: 0.0364
labels: tensor([19], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([17], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([31], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([30], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([32], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([129], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([131], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([130], device='cuda:1') torch.Size([1])
predicted: tensor([77], device='cuda:1') torch.Size([1])
labels: tensor([133], device='cuda:1') torch.Siz




labels: tensor([91], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([58], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([9], device='cuda:1') torch.Size([1])
predicted: tensor([22], device='cuda:1') torch.Size([1])
labels: tensor([8], device='cuda:1') torch.Size([1])
predicted: tensor([22], device='cuda:1') torch.Size([1])
labels: tensor([10], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([11], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([15], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([16], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([22], device='cuda:1') torch.Size([1])
predicted: tensor([22], device='cuda:1') torch.Size([1

labels: tensor([164], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([72], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([7], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([12], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([13], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([14], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([33], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([140], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.Size([1])
labels: tensor([37], device='cuda:1') torch.Size([1])
predicted: tensor([146], device='cuda:1') torch.S

- 10 APs x 5times

In [50]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/30 [00:00<?, ?it/s][A
 20%|██        | 6/30 [00:00<00:00, 58.12it/s][A
 40%|████      | 12/30 [00:00<00:00, 49.98it/s][A
 60%|██████    | 18/30 [00:00<00:00, 48.06it/s][A
 77%|███████▋  | 23/30 [00:00<00:00, 47.01it/s][A
100%|██████████| 30/30 [00:00<00:00, 47.16it/s][A
 20%|██        | 1/5 [00:00<00:02,  1.41it/s]

total count: 30
Epoch 0 | Train Loss: 5.1210 | Train Accuracy: 0.1000
Validation Loss: 5.1164 | Validation Accuracy: 0.1000



  0%|          | 0/30 [00:00<?, ?it/s][A
 23%|██▎       | 7/30 [00:00<00:00, 59.65it/s][A
 43%|████▎     | 13/30 [00:00<00:00, 50.91it/s][A
 63%|██████▎   | 19/30 [00:00<00:00, 48.31it/s][A
 83%|████████▎ | 25/30 [00:00<00:00, 49.27it/s][A
100%|██████████| 30/30 [00:00<00:00, 48.73it/s][A
 40%|████      | 2/5 [00:01<00:02,  1.44it/s]

total count: 30
Epoch 1 | Train Loss: 5.1118 | Train Accuracy: 0.1000
Validation Loss: 5.1063 | Validation Accuracy: 0.1000



  0%|          | 0/30 [00:00<?, ?it/s][A
 20%|██        | 6/30 [00:00<00:00, 59.03it/s][A
 40%|████      | 12/30 [00:00<00:00, 51.84it/s][A
 60%|██████    | 18/30 [00:00<00:00, 48.34it/s][A
 77%|███████▋  | 23/30 [00:00<00:00, 47.06it/s][A
100%|██████████| 30/30 [00:00<00:00, 47.81it/s][A
 60%|██████    | 3/5 [00:02<00:01,  1.44it/s]

total count: 30
Epoch 2 | Train Loss: 5.1001 | Train Accuracy: 0.1000
Validation Loss: 5.0921 | Validation Accuracy: 0.1000



  0%|          | 0/30 [00:00<?, ?it/s][A
 20%|██        | 6/30 [00:00<00:00, 58.56it/s][A
 40%|████      | 12/30 [00:00<00:00, 50.13it/s][A
 60%|██████    | 18/30 [00:00<00:00, 48.19it/s][A
 77%|███████▋  | 23/30 [00:00<00:00, 47.00it/s][A
100%|██████████| 30/30 [00:00<00:00, 47.24it/s][A
 80%|████████  | 4/5 [00:02<00:00,  1.43it/s]

total count: 30
Epoch 3 | Train Loss: 5.0825 | Train Accuracy: 0.1000
Validation Loss: 5.0698 | Validation Accuracy: 0.1000



  0%|          | 0/30 [00:00<?, ?it/s][A
 20%|██        | 6/30 [00:00<00:00, 59.29it/s][A
 40%|████      | 12/30 [00:00<00:00, 50.49it/s][A
 60%|██████    | 18/30 [00:00<00:00, 48.47it/s][A
 77%|███████▋  | 23/30 [00:00<00:00, 47.15it/s][A
100%|██████████| 30/30 [00:00<00:00, 46.73it/s][A
100%|██████████| 5/5 [00:03<00:00,  1.42it/s]

total count: 30
Epoch 4 | Train Loss: 5.0550 | Train Accuracy: 0.1000
Validation Loss: 5.0359 | Validation Accuracy: 0.1000
labels: tensor([118], device='cuda:1') torch.Size([1])
predicted: tensor([128], device='cuda:1') torch.Size([1])
labels: tensor([121], device='cuda:1') torch.Size([1])
predicted: tensor([128], device='cuda:1') torch.Size([1])
labels: tensor([128], device='cuda:1') torch.Size([1])
predicted: tensor([128], device='cuda:1') torch.Size([1])
labels: tensor([122], device='cuda:1') torch.Size([1])
predicted: tensor([128], device='cuda:1') torch.Size([1])
labels: tensor([120], device='cuda:1') torch.Size([1])
predicted: tensor([128], device='cuda:1') torch.Size([1])
labels: tensor([139], device='cuda:1') torch.Size([1])
predicted: tensor([128], device='cuda:1') torch.Size([1])
labels: tensor([83], device='cuda:1') torch.Size([1])
predicted: tensor([128], device='cuda:1') torch.Size([1])
labels: tensor([74], device='cuda:1') torch.Size([1])
predicted: tensor([128], device=




- 10 APs same as above x 50 times

In [54]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0) # label.size(0) is the batch size
        correct += (predicted == labels).sum().item() 
        # (predicted == labels).sum() would return how many of them are equal; 
        # .item() would make the tensor to the regular value
        
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/300 [00:00<?, ?it/s][A
  2%|▏         | 7/300 [00:00<00:04, 59.94it/s][A
  4%|▍         | 13/300 [00:00<00:05, 51.01it/s][A
  6%|▋         | 19/300 [00:00<00:05, 48.86it/s][A
  8%|▊         | 24/300 [00:00<00:05, 47.47it/s][A
 10%|▉         | 29/300 [00:00<00:05, 46.70it/s][A
 11%|█▏        | 34/300 [00:00<00:05, 46.12it/s][A
 13%|█▎        | 39/300 [00:00<00:05, 45.91it/s][A
 15%|█▍        | 44/300 [00:00<00:05, 45.59it/s][A
 16%|█▋        | 49/300 [00:01<00:05, 45.57it/s][A
 18%|█▊        | 54/300 [00:01<00:05, 45.37it/s][A
 20%|█▉        | 59/300 [00:01<00:05, 45.38it/s][A
 21%|██▏       | 64/300 [00:01<00:05, 45.33it/s][A
 23%|██▎       | 69/300 [00:01<00:05, 45.00it/s][A
 25%|██▍       | 74/300 [00:01<00:04, 45.20it/s][A
 26%|██▋       | 79/300 [00:01<00:04, 45.35it/s][A
 28%|██▊       | 84/300 [00:01<00:04, 45.28it/s][A
 30%|██▉       | 89/300 [00:01<00:04, 45.18it/s][A
 31%|███▏      | 94/300 [00:02<00:04

total count: 300
Epoch 0 | Train Loss: 4.9787 | Train Accuracy: 0.1700


 20%|██        | 1/5 [00:07<00:28,  7.03s/it]

Validation Loss: 4.6262 | Validation Accuracy: 0.1000



  0%|          | 0/300 [00:00<?, ?it/s][A
  2%|▏         | 5/300 [00:00<00:05, 49.35it/s][A
  4%|▎         | 11/300 [00:00<00:05, 49.85it/s][A
  5%|▌         | 16/300 [00:00<00:05, 47.90it/s][A
  7%|▋         | 21/300 [00:00<00:05, 46.67it/s][A
  9%|▊         | 26/300 [00:00<00:05, 46.28it/s][A
 10%|█         | 31/300 [00:00<00:05, 45.89it/s][A
 12%|█▏        | 36/300 [00:00<00:05, 45.55it/s][A
 14%|█▎        | 41/300 [00:00<00:05, 46.23it/s][A
 15%|█▌        | 46/300 [00:00<00:05, 46.45it/s][A
 17%|█▋        | 51/300 [00:01<00:05, 46.00it/s][A
 19%|█▊        | 56/300 [00:01<00:05, 45.79it/s][A
 20%|██        | 61/300 [00:01<00:05, 45.65it/s][A
 22%|██▏       | 66/300 [00:01<00:05, 45.38it/s][A
 24%|██▎       | 71/300 [00:01<00:05, 44.93it/s][A
 25%|██▌       | 76/300 [00:01<00:04, 45.38it/s][A
 27%|██▋       | 81/300 [00:01<00:04, 44.99it/s][A
 29%|██▊       | 86/300 [00:01<00:04, 45.54it/s][A
 30%|███       | 91/300 [00:01<00:04, 45.26it/s][A
 32%|███▏      | 96/3

total count: 300
Epoch 1 | Train Loss: 3.9137 | Train Accuracy: 0.1000


 40%|████      | 2/5 [00:13<00:20,  6.97s/it]

Validation Loss: 3.1964 | Validation Accuracy: 0.1000



  0%|          | 0/300 [00:00<?, ?it/s][A
  2%|▏         | 6/300 [00:00<00:05, 55.24it/s][A
  4%|▍         | 12/300 [00:00<00:05, 49.28it/s][A
  6%|▌         | 17/300 [00:00<00:05, 47.53it/s][A
  7%|▋         | 22/300 [00:00<00:05, 46.64it/s][A
  9%|▉         | 27/300 [00:00<00:05, 46.02it/s][A
 11%|█         | 32/300 [00:00<00:05, 45.81it/s][A
 12%|█▏        | 37/300 [00:00<00:05, 45.50it/s][A
 14%|█▍        | 42/300 [00:00<00:05, 45.37it/s][A
 16%|█▌        | 47/300 [00:01<00:05, 45.46it/s][A
 17%|█▋        | 52/300 [00:01<00:05, 45.26it/s][A
 19%|█▉        | 57/300 [00:01<00:05, 45.27it/s][A
 21%|██        | 62/300 [00:01<00:05, 44.35it/s][A
 22%|██▏       | 67/300 [00:01<00:05, 44.48it/s][A
 24%|██▍       | 72/300 [00:01<00:04, 45.67it/s][A
 26%|██▌       | 77/300 [00:01<00:04, 45.42it/s][A
 27%|██▋       | 82/300 [00:01<00:04, 44.84it/s][A
 29%|██▉       | 87/300 [00:01<00:04, 45.07it/s][A
 31%|███       | 92/300 [00:02<00:04, 45.61it/s][A
 32%|███▏      | 97/3

total count: 300
Epoch 2 | Train Loss: 2.8079 | Train Accuracy: 0.1033


 60%|██████    | 3/5 [00:20<00:13,  6.99s/it]

Validation Loss: 2.5638 | Validation Accuracy: 0.1000



  0%|          | 0/300 [00:00<?, ?it/s][A
  2%|▏         | 6/300 [00:00<00:05, 56.73it/s][A
  4%|▍         | 12/300 [00:00<00:05, 51.90it/s][A
  6%|▌         | 18/300 [00:00<00:05, 48.88it/s][A
  8%|▊         | 23/300 [00:00<00:05, 47.50it/s][A
  9%|▉         | 28/300 [00:00<00:05, 46.59it/s][A
 11%|█         | 33/300 [00:00<00:05, 46.24it/s][A
 13%|█▎        | 38/300 [00:00<00:05, 45.92it/s][A
 14%|█▍        | 43/300 [00:00<00:05, 45.58it/s][A
 16%|█▌        | 48/300 [00:01<00:05, 45.52it/s][A
 18%|█▊        | 53/300 [00:01<00:05, 45.42it/s][A
 19%|█▉        | 58/300 [00:01<00:05, 45.19it/s][A
 21%|██        | 63/300 [00:01<00:05, 44.70it/s][A
 23%|██▎       | 68/300 [00:01<00:05, 44.99it/s][A
 25%|██▍       | 74/300 [00:01<00:04, 47.08it/s][A
 26%|██▋       | 79/300 [00:01<00:04, 46.45it/s][A
 28%|██▊       | 84/300 [00:01<00:04, 45.89it/s][A
 30%|██▉       | 89/300 [00:01<00:04, 45.58it/s][A
 31%|███▏      | 94/300 [00:02<00:04, 45.63it/s][A
 33%|███▎      | 99/3

total count: 300
Epoch 3 | Train Loss: 2.4699 | Train Accuracy: 0.1000


 80%|████████  | 4/5 [00:27<00:06,  6.99s/it]

Validation Loss: 2.4045 | Validation Accuracy: 0.1000



  0%|          | 0/300 [00:00<?, ?it/s][A
  2%|▏         | 7/300 [00:00<00:04, 60.20it/s][A
  5%|▍         | 14/300 [00:00<00:05, 53.31it/s][A
  7%|▋         | 20/300 [00:00<00:05, 50.23it/s][A
  9%|▊         | 26/300 [00:00<00:05, 48.53it/s][A
 10%|█         | 31/300 [00:00<00:05, 47.30it/s][A
 12%|█▏        | 36/300 [00:00<00:05, 46.72it/s][A
 14%|█▎        | 41/300 [00:00<00:05, 46.13it/s][A
 15%|█▌        | 46/300 [00:00<00:05, 45.82it/s][A
 17%|█▋        | 51/300 [00:01<00:05, 45.74it/s][A
 19%|█▊        | 56/300 [00:01<00:05, 45.48it/s][A
 20%|██        | 61/300 [00:01<00:05, 45.06it/s][A
 22%|██▏       | 66/300 [00:01<00:05, 45.41it/s][A
 24%|██▎       | 71/300 [00:01<00:05, 45.48it/s][A
 25%|██▌       | 76/300 [00:01<00:04, 45.28it/s][A
 27%|██▋       | 81/300 [00:01<00:04, 45.35it/s][A
 29%|██▊       | 86/300 [00:01<00:04, 45.15it/s][A
 30%|███       | 91/300 [00:01<00:04, 45.15it/s][A
 32%|███▏      | 96/300 [00:02<00:04, 44.82it/s][A
 34%|███▎      | 101/

total count: 300
Epoch 4 | Train Loss: 2.3734 | Train Accuracy: 0.1000


100%|██████████| 5/5 [00:34<00:00,  6.99s/it]

Validation Loss: 2.3481 | Validation Accuracy: 0.1000
labels: tensor([118], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([121], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([128], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([122], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([120], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([139], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([83], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([74], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([57], device='cuda:1') torch.Size([1]




labels: tensor([83], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([74], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([57], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([0], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([118], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([121], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([128], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([122], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1])
labels: tensor([120], device='cuda:1') torch.Size([1])
predicted: tensor([74], device='cuda:1') torch.Size([1

- 10 APs same as above x 500 times

In [41]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0) # label.size(0) is the batch size
        correct += (predicted == labels).sum().item() 
        # (predicted == labels).sum() would return how many of them are equal; 
        # .item() would make the tensor to the regular value
        
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/5000 [00:00<?, ?it/s][A
  0%|          | 9/5000 [00:00<01:04, 77.25it/s][A
  0%|          | 17/5000 [00:00<01:28, 56.25it/s][A
  0%|          | 23/5000 [00:00<01:35, 52.06it/s][A
  1%|          | 29/5000 [00:00<01:40, 49.56it/s][A
  1%|          | 35/5000 [00:00<01:41, 48.77it/s][A
  1%|          | 40/5000 [00:00<01:42, 48.54it/s][A
  1%|          | 45/5000 [00:00<01:44, 47.53it/s][A
  1%|          | 50/5000 [00:00<01:42, 48.17it/s][A
  1%|          | 56/5000 [00:01<01:41, 48.68it/s][A
  1%|          | 61/5000 [00:01<01:43, 47.56it/s][A
  1%|▏         | 66/5000 [00:01<01:45, 46.98it/s][A
  1%|▏         | 71/5000 [00:01<01:46, 46.44it/s][A
  2%|▏         | 76/5000 [00:01<01:47, 45.95it/s][A
  2%|▏         | 81/5000 [00:01<01:46, 46.14it/s][A
  2%|▏         | 86/5000 [00:01<01:45, 46.45it/s][A
  2%|▏         | 92/5000 [00:01<01:43, 47.28it/s][A
  2%|▏         | 97/5000 [00:02<01:44, 47.03it/s][A
  2%|▏         | 1

 31%|███       | 1526/5000 [00:33<01:15, 46.12it/s][A
 31%|███       | 1531/5000 [00:33<01:14, 46.38it/s][A
 31%|███       | 1536/5000 [00:33<01:14, 46.73it/s][A
 31%|███       | 1541/5000 [00:33<01:14, 46.21it/s][A
 31%|███       | 1546/5000 [00:33<01:15, 46.04it/s][A
 31%|███       | 1551/5000 [00:33<01:15, 45.81it/s][A
 31%|███       | 1556/5000 [00:33<01:15, 45.53it/s][A
 31%|███       | 1561/5000 [00:34<01:16, 44.98it/s][A
 31%|███▏      | 1566/5000 [00:34<01:15, 45.42it/s][A
 31%|███▏      | 1571/5000 [00:34<01:15, 45.47it/s][A
 32%|███▏      | 1576/5000 [00:34<01:15, 45.27it/s][A
 32%|███▏      | 1581/5000 [00:34<01:14, 45.72it/s][A
 32%|███▏      | 1586/5000 [00:34<01:13, 46.31it/s][A
 32%|███▏      | 1591/5000 [00:34<01:14, 45.90it/s][A
 32%|███▏      | 1596/5000 [00:34<01:14, 45.74it/s][A
 32%|███▏      | 1601/5000 [00:34<01:14, 45.39it/s][A
 32%|███▏      | 1606/5000 [00:35<01:14, 45.40it/s][A
 32%|███▏      | 1611/5000 [00:35<01:14, 45.50it/s][A
 32%|███▏ 

 60%|██████    | 3018/5000 [01:06<00:42, 46.35it/s][A
 60%|██████    | 3023/5000 [01:06<00:43, 45.93it/s][A
 61%|██████    | 3028/5000 [01:06<00:43, 45.78it/s][A
 61%|██████    | 3033/5000 [01:06<00:43, 45.45it/s][A
 61%|██████    | 3038/5000 [01:06<00:43, 45.42it/s][A
 61%|██████    | 3043/5000 [01:06<00:44, 44.26it/s][A
 61%|██████    | 3048/5000 [01:06<00:43, 44.42it/s][A
 61%|██████    | 3053/5000 [01:06<00:42, 45.94it/s][A
 61%|██████    | 3058/5000 [01:06<00:41, 46.90it/s][A
 61%|██████▏   | 3063/5000 [01:07<00:41, 46.50it/s][A
 61%|██████▏   | 3068/5000 [01:07<00:42, 45.98it/s][A
 61%|██████▏   | 3073/5000 [01:07<00:42, 45.81it/s][A
 62%|██████▏   | 3078/5000 [01:07<00:42, 45.66it/s][A
 62%|██████▏   | 3083/5000 [01:07<00:42, 45.49it/s][A
 62%|██████▏   | 3088/5000 [01:07<00:42, 44.93it/s][A
 62%|██████▏   | 3093/5000 [01:07<00:42, 45.31it/s][A
 62%|██████▏   | 3098/5000 [01:07<00:41, 45.30it/s][A
 62%|██████▏   | 3103/5000 [01:07<00:41, 45.25it/s][A
 62%|█████

 90%|█████████ | 4508/5000 [01:38<00:10, 46.20it/s][A
 90%|█████████ | 4513/5000 [01:38<00:10, 46.11it/s][A
 90%|█████████ | 4518/5000 [01:38<00:10, 45.81it/s][A
 90%|█████████ | 4523/5000 [01:39<00:10, 45.64it/s][A
 91%|█████████ | 4528/5000 [01:39<00:10, 45.08it/s][A
 91%|█████████ | 4533/5000 [01:39<00:10, 45.01it/s][A
 91%|█████████ | 4538/5000 [01:39<00:10, 45.37it/s][A
 91%|█████████ | 4543/5000 [01:39<00:10, 45.39it/s][A
 91%|█████████ | 4548/5000 [01:39<00:09, 45.40it/s][A
 91%|█████████ | 4553/5000 [01:39<00:09, 45.38it/s][A
 91%|█████████ | 4558/5000 [01:39<00:09, 45.30it/s][A
 91%|█████████▏| 4563/5000 [01:39<00:09, 45.16it/s][A
 91%|█████████▏| 4568/5000 [01:40<00:09, 46.48it/s][A
 91%|█████████▏| 4574/5000 [01:40<00:08, 47.59it/s][A
 92%|█████████▏| 4579/5000 [01:40<00:08, 46.86it/s][A
 92%|█████████▏| 4584/5000 [01:40<00:08, 46.36it/s][A
 92%|█████████▏| 4589/5000 [01:40<00:08, 45.88it/s][A
 92%|█████████▏| 4594/5000 [01:40<00:08, 45.64it/s][A
 92%|█████

total count: 5000
Epoch 0 | Train Loss: 2.6194 | Train Accuracy: 0.1024
Validation Loss: 2.2653 | Validation Accuracy: 0.1000



  0%|          | 0/5000 [00:00<?, ?it/s][A
  0%|          | 5/5000 [00:00<01:44, 47.62it/s][A
  0%|          | 11/5000 [00:00<01:40, 49.49it/s][A
  0%|          | 16/5000 [00:00<01:44, 47.57it/s][A
  0%|          | 21/5000 [00:00<01:47, 46.51it/s][A
  1%|          | 26/5000 [00:00<01:48, 45.81it/s][A
  1%|          | 31/5000 [00:00<01:48, 45.87it/s][A
  1%|          | 36/5000 [00:00<01:48, 45.69it/s][A
  1%|          | 41/5000 [00:00<01:49, 45.36it/s][A
  1%|          | 46/5000 [00:01<01:49, 45.20it/s][A
  1%|          | 51/5000 [00:01<01:49, 45.16it/s][A
  1%|          | 56/5000 [00:01<01:48, 45.44it/s][A
  1%|          | 61/5000 [00:01<01:48, 45.35it/s][A
  1%|▏         | 66/5000 [00:01<01:46, 46.54it/s][A
  1%|▏         | 71/5000 [00:01<01:43, 47.43it/s][A
  2%|▏         | 77/5000 [00:01<01:41, 48.48it/s][A
  2%|▏         | 82/5000 [00:01<01:41, 48.66it/s][A
  2%|▏         | 87/5000 [00:01<01:43, 47.54it/s][A
  2%|▏         | 92/5000 [00:01<01:44, 46.78it/s][A
  

 30%|███       | 1519/5000 [00:33<01:16, 45.27it/s][A
 30%|███       | 1524/5000 [00:33<01:16, 45.25it/s][A
 31%|███       | 1529/5000 [00:33<01:17, 44.72it/s][A
 31%|███       | 1534/5000 [00:33<01:16, 45.35it/s][A
 31%|███       | 1539/5000 [00:33<01:15, 45.82it/s][A
 31%|███       | 1544/5000 [00:33<01:14, 46.63it/s][A
 31%|███       | 1549/5000 [00:33<01:14, 46.07it/s][A
 31%|███       | 1554/5000 [00:33<01:15, 45.73it/s][A
 31%|███       | 1559/5000 [00:33<01:15, 45.63it/s][A
 31%|███▏      | 1564/5000 [00:34<01:15, 45.37it/s][A
 31%|███▏      | 1569/5000 [00:34<01:13, 46.39it/s][A
 32%|███▏      | 1575/5000 [00:34<01:11, 47.68it/s][A
 32%|███▏      | 1580/5000 [00:34<01:11, 47.96it/s][A
 32%|███▏      | 1586/5000 [00:34<01:09, 48.86it/s][A
 32%|███▏      | 1591/5000 [00:34<01:11, 47.53it/s][A
 32%|███▏      | 1596/5000 [00:34<01:12, 46.84it/s][A
 32%|███▏      | 1601/5000 [00:34<01:13, 46.55it/s][A
 32%|███▏      | 1606/5000 [00:34<01:13, 46.16it/s][A
 32%|███▏ 

 60%|██████    | 3006/5000 [01:05<00:43, 45.93it/s][A
 60%|██████    | 3011/5000 [01:05<00:43, 45.41it/s][A
 60%|██████    | 3016/5000 [01:05<00:43, 45.63it/s][A
 60%|██████    | 3021/5000 [01:05<00:43, 45.90it/s][A
 61%|██████    | 3026/5000 [01:06<00:43, 45.33it/s][A
 61%|██████    | 3031/5000 [01:06<00:43, 45.13it/s][A
 61%|██████    | 3036/5000 [01:06<00:43, 45.16it/s][A
 61%|██████    | 3041/5000 [01:06<00:43, 45.36it/s][A
 61%|██████    | 3046/5000 [01:06<00:44, 44.12it/s][A
 61%|██████    | 3051/5000 [01:06<00:43, 44.30it/s][A
 61%|██████    | 3056/5000 [01:06<00:43, 44.59it/s][A
 61%|██████    | 3061/5000 [01:06<00:43, 44.79it/s][A
 61%|██████▏   | 3066/5000 [01:06<00:42, 45.01it/s][A
 61%|██████▏   | 3071/5000 [01:07<00:42, 45.25it/s][A
 62%|██████▏   | 3076/5000 [01:07<00:41, 46.06it/s][A
 62%|██████▏   | 3081/5000 [01:07<00:41, 45.92it/s][A
 62%|██████▏   | 3086/5000 [01:07<00:41, 45.59it/s][A
 62%|██████▏   | 3091/5000 [01:07<00:42, 45.45it/s][A
 62%|█████

 90%|████████▉ | 4499/5000 [01:38<00:11, 45.38it/s][A
 90%|█████████ | 4504/5000 [01:38<00:10, 45.56it/s][A
 90%|█████████ | 4509/5000 [01:38<00:10, 45.35it/s][A
 90%|█████████ | 4514/5000 [01:38<00:10, 45.23it/s][A
 90%|█████████ | 4519/5000 [01:38<00:10, 44.66it/s][A
 90%|█████████ | 4524/5000 [01:38<00:10, 45.23it/s][A
 91%|█████████ | 4529/5000 [01:38<00:10, 45.25it/s][A
 91%|█████████ | 4534/5000 [01:39<00:10, 45.29it/s][A
 91%|█████████ | 4539/5000 [01:39<00:10, 45.33it/s][A
 91%|█████████ | 4544/5000 [01:39<00:09, 46.02it/s][A
 91%|█████████ | 4549/5000 [01:39<00:09, 46.34it/s][A
 91%|█████████ | 4554/5000 [01:39<00:09, 45.94it/s][A
 91%|█████████ | 4559/5000 [01:39<00:09, 45.70it/s][A
 91%|█████████▏| 4564/5000 [01:39<00:09, 45.56it/s][A
 91%|█████████▏| 4569/5000 [01:39<00:09, 46.15it/s][A
 91%|█████████▏| 4574/5000 [01:39<00:09, 46.43it/s][A
 92%|█████████▏| 4579/5000 [01:40<00:09, 45.99it/s][A
 92%|█████████▏| 4584/5000 [01:40<00:09, 45.81it/s][A
 92%|█████

total count: 5000
Epoch 1 | Train Loss: 2.2353 | Train Accuracy: 0.1022
Validation Loss: 2.1931 | Validation Accuracy: 0.1000



  0%|          | 0/5000 [00:00<?, ?it/s][A
  0%|          | 6/5000 [00:00<01:26, 57.58it/s][A
  0%|          | 12/5000 [00:00<01:36, 51.49it/s][A
  0%|          | 18/5000 [00:00<01:41, 48.94it/s][A
  0%|          | 23/5000 [00:00<01:41, 48.93it/s][A
  1%|          | 29/5000 [00:00<01:40, 49.43it/s][A
  1%|          | 34/5000 [00:00<01:40, 49.26it/s][A
  1%|          | 40/5000 [00:00<01:40, 49.49it/s][A
  1%|          | 45/5000 [00:00<01:42, 48.40it/s][A
  1%|          | 50/5000 [00:01<01:42, 48.11it/s][A
  1%|          | 56/5000 [00:01<01:40, 49.06it/s][A
  1%|          | 61/5000 [00:01<01:43, 47.85it/s][A
  1%|▏         | 66/5000 [00:01<01:44, 47.36it/s][A
  1%|▏         | 71/5000 [00:01<01:44, 47.39it/s][A
  2%|▏         | 76/5000 [00:01<01:44, 47.05it/s][A
  2%|▏         | 81/5000 [00:01<01:43, 47.59it/s][A
  2%|▏         | 86/5000 [00:01<01:41, 48.28it/s][A
  2%|▏         | 91/5000 [00:01<01:43, 47.65it/s][A
  2%|▏         | 96/5000 [00:01<01:42, 47.92it/s][A
  

 31%|███       | 1550/5000 [00:33<01:12, 47.40it/s][A
 31%|███       | 1555/5000 [00:33<01:13, 46.81it/s][A
 31%|███       | 1560/5000 [00:33<01:14, 46.48it/s][A
 31%|███▏      | 1565/5000 [00:33<01:13, 46.85it/s][A
 31%|███▏      | 1570/5000 [00:33<01:13, 46.56it/s][A
 32%|███▏      | 1575/5000 [00:33<01:14, 46.24it/s][A
 32%|███▏      | 1580/5000 [00:33<01:13, 46.52it/s][A
 32%|███▏      | 1585/5000 [00:33<01:12, 47.36it/s][A
 32%|███▏      | 1590/5000 [00:33<01:13, 46.56it/s][A
 32%|███▏      | 1595/5000 [00:33<01:12, 46.74it/s][A
 32%|███▏      | 1600/5000 [00:34<01:13, 46.36it/s][A
 32%|███▏      | 1605/5000 [00:34<01:11, 47.30it/s][A
 32%|███▏      | 1610/5000 [00:34<01:10, 47.96it/s][A
 32%|███▏      | 1615/5000 [00:34<01:11, 47.16it/s][A
 32%|███▏      | 1620/5000 [00:34<01:12, 46.56it/s][A
 32%|███▎      | 1625/5000 [00:34<01:11, 47.38it/s][A
 33%|███▎      | 1630/5000 [00:34<01:10, 48.09it/s][A
 33%|███▎      | 1635/5000 [00:34<01:11, 47.15it/s][A
 33%|███▎ 

 61%|██████▏   | 3069/5000 [01:05<00:42, 45.07it/s][A
 61%|██████▏   | 3074/5000 [01:05<00:42, 45.44it/s][A
 62%|██████▏   | 3079/5000 [01:05<00:42, 45.55it/s][A
 62%|██████▏   | 3084/5000 [01:06<00:42, 44.89it/s][A
 62%|██████▏   | 3089/5000 [01:06<00:42, 44.85it/s][A
 62%|██████▏   | 3094/5000 [01:06<00:42, 45.17it/s][A
 62%|██████▏   | 3099/5000 [01:06<00:41, 45.56it/s][A
 62%|██████▏   | 3104/5000 [01:06<00:42, 45.10it/s][A
 62%|██████▏   | 3109/5000 [01:06<00:42, 44.96it/s][A
 62%|██████▏   | 3114/5000 [01:06<00:41, 45.40it/s][A
 62%|██████▏   | 3119/5000 [01:06<00:41, 45.08it/s][A
 62%|██████▏   | 3124/5000 [01:06<00:41, 45.11it/s][A
 63%|██████▎   | 3129/5000 [01:07<00:40, 46.34it/s][A
 63%|██████▎   | 3134/5000 [01:07<00:39, 46.87it/s][A
 63%|██████▎   | 3139/5000 [01:07<00:39, 47.02it/s][A
 63%|██████▎   | 3144/5000 [01:07<00:40, 45.77it/s][A
 63%|██████▎   | 3149/5000 [01:07<00:39, 46.43it/s][A
 63%|██████▎   | 3154/5000 [01:07<00:39, 46.54it/s][A
 63%|█████

total count: 5000
Epoch 2 | Train Loss: 2.1262 | Train Accuracy: 0.2748
Validation Loss: 2.0385 | Validation Accuracy: 0.6000



  0%|          | 0/5000 [00:00<?, ?it/s][A
  0%|          | 6/5000 [00:00<01:23, 59.93it/s][A
  0%|          | 12/5000 [00:00<01:35, 52.48it/s][A
  0%|          | 18/5000 [00:00<01:41, 49.15it/s][A
  0%|          | 23/5000 [00:00<01:43, 48.07it/s][A
  1%|          | 28/5000 [00:00<01:43, 47.88it/s][A
  1%|          | 34/5000 [00:00<01:41, 49.14it/s][A
  1%|          | 39/5000 [00:00<01:44, 47.49it/s][A
  1%|          | 44/5000 [00:00<01:45, 47.16it/s][A
  1%|          | 49/5000 [00:01<01:45, 47.00it/s][A
  1%|          | 54/5000 [00:01<01:44, 47.12it/s][A
  1%|          | 59/5000 [00:01<01:45, 46.70it/s][A
  1%|▏         | 64/5000 [00:01<01:46, 46.23it/s][A
  1%|▏         | 69/5000 [00:01<01:47, 45.93it/s][A
  1%|▏         | 74/5000 [00:01<01:47, 45.99it/s][A
  2%|▏         | 79/5000 [00:01<01:45, 46.73it/s][A
  2%|▏         | 84/5000 [00:01<01:45, 46.49it/s][A
  2%|▏         | 89/5000 [00:01<01:44, 47.22it/s][A
  2%|▏         | 94/5000 [00:01<01:45, 46.55it/s][A
  

 31%|███       | 1534/5000 [00:33<01:15, 45.86it/s][A
 31%|███       | 1539/5000 [00:33<01:15, 45.59it/s][A
 31%|███       | 1544/5000 [00:33<01:15, 45.88it/s][A
 31%|███       | 1549/5000 [00:33<01:15, 45.99it/s][A
 31%|███       | 1554/5000 [00:33<01:14, 46.34it/s][A
 31%|███       | 1559/5000 [00:33<01:13, 47.03it/s][A
 31%|███▏      | 1565/5000 [00:33<01:11, 48.21it/s][A
 31%|███▏      | 1570/5000 [00:33<01:12, 47.13it/s][A
 32%|███▏      | 1575/5000 [00:33<01:13, 46.67it/s][A
 32%|███▏      | 1580/5000 [00:34<01:14, 46.10it/s][A
 32%|███▏      | 1585/5000 [00:34<01:14, 45.96it/s][A
 32%|███▏      | 1590/5000 [00:34<01:14, 45.58it/s][A
 32%|███▏      | 1595/5000 [00:34<01:14, 45.93it/s][A
 32%|███▏      | 1600/5000 [00:34<01:13, 45.99it/s][A
 32%|███▏      | 1605/5000 [00:34<01:13, 46.34it/s][A
 32%|███▏      | 1610/5000 [00:34<01:11, 47.21it/s][A
 32%|███▏      | 1615/5000 [00:34<01:14, 45.39it/s][A
 32%|███▏      | 1620/5000 [00:34<01:14, 45.24it/s][A
 32%|███▎ 

 61%|██████    | 3027/5000 [01:05<00:42, 46.48it/s][A
 61%|██████    | 3032/5000 [01:05<00:42, 45.96it/s][A
 61%|██████    | 3037/5000 [01:05<00:42, 45.86it/s][A
 61%|██████    | 3042/5000 [01:05<00:42, 45.62it/s][A
 61%|██████    | 3047/5000 [01:06<00:42, 45.44it/s][A
 61%|██████    | 3052/5000 [01:06<00:43, 44.83it/s][A
 61%|██████    | 3057/5000 [01:06<00:43, 45.08it/s][A
 61%|██████    | 3062/5000 [01:06<00:42, 45.43it/s][A
 61%|██████▏   | 3067/5000 [01:06<00:41, 46.03it/s][A
 61%|██████▏   | 3072/5000 [01:06<00:41, 46.33it/s][A
 62%|██████▏   | 3077/5000 [01:06<00:42, 45.56it/s][A
 62%|██████▏   | 3082/5000 [01:06<00:42, 45.57it/s][A
 62%|██████▏   | 3087/5000 [01:06<00:41, 45.73it/s][A
 62%|██████▏   | 3092/5000 [01:07<00:41, 45.75it/s][A
 62%|██████▏   | 3097/5000 [01:07<00:41, 45.45it/s][A
 62%|██████▏   | 3102/5000 [01:07<00:41, 45.37it/s][A
 62%|██████▏   | 3107/5000 [01:07<00:41, 45.27it/s][A
 62%|██████▏   | 3112/5000 [01:07<00:41, 45.69it/s][A
 62%|█████

 90%|█████████ | 4517/5000 [01:38<00:10, 46.23it/s][A
 90%|█████████ | 4522/5000 [01:38<00:10, 45.87it/s][A
 91%|█████████ | 4527/5000 [01:38<00:10, 45.64it/s][A
 91%|█████████ | 4532/5000 [01:38<00:10, 45.51it/s][A
 91%|█████████ | 4538/5000 [01:38<00:09, 47.16it/s][A
 91%|█████████ | 4543/5000 [01:38<00:09, 47.71it/s][A
 91%|█████████ | 4548/5000 [01:38<00:09, 46.80it/s][A
 91%|█████████ | 4553/5000 [01:39<00:09, 46.39it/s][A
 91%|█████████ | 4558/5000 [01:39<00:09, 45.91it/s][A
 91%|█████████▏| 4563/5000 [01:39<00:09, 45.68it/s][A
 91%|█████████▏| 4568/5000 [01:39<00:09, 45.67it/s][A
 91%|█████████▏| 4573/5000 [01:39<00:09, 45.41it/s][A
 92%|█████████▏| 4578/5000 [01:39<00:09, 45.33it/s][A
 92%|█████████▏| 4583/5000 [01:39<00:09, 46.28it/s][A
 92%|█████████▏| 4589/5000 [01:39<00:08, 47.14it/s][A
 92%|█████████▏| 4594/5000 [01:39<00:08, 46.80it/s][A
 92%|█████████▏| 4600/5000 [01:40<00:08, 47.99it/s][A
 92%|█████████▏| 4605/5000 [01:40<00:08, 47.14it/s][A
 92%|█████

total count: 5000
Epoch 3 | Train Loss: 1.9214 | Train Accuracy: 0.5590
Validation Loss: 1.7800 | Validation Accuracy: 0.6000



  0%|          | 0/5000 [00:00<?, ?it/s][A
  0%|          | 6/5000 [00:00<01:24, 59.02it/s][A
  0%|          | 12/5000 [00:00<01:39, 50.26it/s][A
  0%|          | 18/5000 [00:00<01:43, 47.98it/s][A
  0%|          | 23/5000 [00:00<01:45, 46.96it/s][A
  1%|          | 28/5000 [00:00<01:45, 47.20it/s][A
  1%|          | 33/5000 [00:00<01:45, 47.10it/s][A
  1%|          | 38/5000 [00:00<01:47, 46.09it/s][A
  1%|          | 43/5000 [00:00<01:48, 45.79it/s][A
  1%|          | 48/5000 [00:01<01:48, 45.67it/s][A
  1%|          | 53/5000 [00:01<01:47, 45.86it/s][A
  1%|          | 58/5000 [00:01<01:46, 46.40it/s][A
  1%|▏         | 63/5000 [00:01<01:45, 46.60it/s][A
  1%|▏         | 68/5000 [00:01<01:49, 45.25it/s][A
  1%|▏         | 73/5000 [00:01<01:47, 45.93it/s][A
  2%|▏         | 78/5000 [00:01<01:47, 45.83it/s][A
  2%|▏         | 83/5000 [00:01<01:47, 45.79it/s][A
  2%|▏         | 88/5000 [00:01<01:47, 45.61it/s][A
  2%|▏         | 93/5000 [00:02<01:48, 45.41it/s][A
  

 31%|███       | 1561/5000 [00:32<01:13, 46.96it/s][A
 31%|███▏      | 1567/5000 [00:32<01:11, 48.28it/s][A
 31%|███▏      | 1572/5000 [00:33<01:12, 47.40it/s][A
 32%|███▏      | 1577/5000 [00:33<01:13, 46.69it/s][A
 32%|███▏      | 1582/5000 [00:33<01:13, 46.35it/s][A
 32%|███▏      | 1587/5000 [00:33<01:13, 46.70it/s][A
 32%|███▏      | 1593/5000 [00:33<01:10, 48.19it/s][A
 32%|███▏      | 1598/5000 [00:33<01:12, 47.23it/s][A
 32%|███▏      | 1603/5000 [00:33<01:12, 46.93it/s][A
 32%|███▏      | 1608/5000 [00:33<01:11, 47.40it/s][A
 32%|███▏      | 1613/5000 [00:33<01:12, 47.00it/s][A
 32%|███▏      | 1618/5000 [00:34<01:11, 47.45it/s][A
 32%|███▏      | 1623/5000 [00:34<01:12, 46.75it/s][A
 33%|███▎      | 1628/5000 [00:34<01:12, 46.27it/s][A
 33%|███▎      | 1633/5000 [00:34<01:13, 45.97it/s][A
 33%|███▎      | 1638/5000 [00:34<01:13, 45.95it/s][A
 33%|███▎      | 1643/5000 [00:34<01:12, 46.38it/s][A
 33%|███▎      | 1649/5000 [00:34<01:10, 47.85it/s][A
 33%|███▎ 

 62%|██████▏   | 3104/5000 [01:05<00:38, 49.30it/s][A
 62%|██████▏   | 3109/5000 [01:05<00:39, 47.44it/s][A
 62%|██████▏   | 3114/5000 [01:05<00:39, 47.44it/s][A
 62%|██████▏   | 3119/5000 [01:05<00:40, 46.83it/s][A
 62%|██████▏   | 3124/5000 [01:05<00:40, 46.58it/s][A
 63%|██████▎   | 3129/5000 [01:05<00:40, 46.63it/s][A
 63%|██████▎   | 3134/5000 [01:05<00:39, 46.70it/s][A
 63%|██████▎   | 3139/5000 [01:05<00:40, 46.45it/s][A
 63%|██████▎   | 3144/5000 [01:06<00:39, 47.05it/s][A
 63%|██████▎   | 3150/5000 [01:06<00:38, 48.12it/s][A
 63%|██████▎   | 3155/5000 [01:06<00:38, 47.92it/s][A
 63%|██████▎   | 3161/5000 [01:06<00:37, 48.80it/s][A
 63%|██████▎   | 3166/5000 [01:06<00:38, 47.92it/s][A
 63%|██████▎   | 3171/5000 [01:06<00:38, 47.39it/s][A
 64%|██████▎   | 3176/5000 [01:06<00:38, 47.27it/s][A
 64%|██████▎   | 3182/5000 [01:06<00:37, 48.72it/s][A
 64%|██████▎   | 3187/5000 [01:06<00:38, 47.65it/s][A
 64%|██████▍   | 3192/5000 [01:07<00:37, 47.86it/s][A
 64%|█████

 92%|█████████▏| 4610/5000 [01:37<00:08, 46.19it/s][A
 92%|█████████▏| 4615/5000 [01:37<00:08, 46.10it/s][A
 92%|█████████▏| 4620/5000 [01:38<00:08, 46.06it/s][A
 92%|█████████▎| 4625/5000 [01:38<00:08, 45.82it/s][A
 93%|█████████▎| 4630/5000 [01:38<00:08, 45.19it/s][A
 93%|█████████▎| 4635/5000 [01:38<00:08, 45.22it/s][A
 93%|█████████▎| 4640/5000 [01:38<00:07, 45.48it/s][A
 93%|█████████▎| 4645/5000 [01:38<00:07, 46.43it/s][A
 93%|█████████▎| 4651/5000 [01:38<00:07, 47.54it/s][A
 93%|█████████▎| 4656/5000 [01:38<00:07, 47.13it/s][A
 93%|█████████▎| 4661/5000 [01:38<00:07, 46.44it/s][A
 93%|█████████▎| 4666/5000 [01:39<00:07, 46.03it/s][A
 93%|█████████▎| 4671/5000 [01:39<00:07, 45.79it/s][A
 94%|█████████▎| 4676/5000 [01:39<00:06, 46.55it/s][A
 94%|█████████▎| 4681/5000 [01:39<00:06, 46.61it/s][A
 94%|█████████▎| 4686/5000 [01:39<00:06, 46.15it/s][A
 94%|█████████▍| 4691/5000 [01:39<00:06, 45.88it/s][A
 94%|█████████▍| 4696/5000 [01:39<00:06, 45.59it/s][A
 94%|█████

total count: 5000
Epoch 4 | Train Loss: 1.6244 | Train Accuracy: 0.6692
Validation Loss: 1.4473 | Validation Accuracy: 0.7000
labels: tensor([118], device='cuda:0') torch.Size([1])
predicted: tensor([118], device='cuda:0') torch.Size([1])
labels: tensor([121], device='cuda:0') torch.Size([1])
predicted: tensor([121], device='cuda:0') torch.Size([1])
labels: tensor([128], device='cuda:0') torch.Size([1])
predicted: tensor([128], device='cuda:0') torch.Size([1])
labels: tensor([122], device='cuda:0') torch.Size([1])
predicted: tensor([122], device='cuda:0') torch.Size([1])
labels: tensor([120], device='cuda:0') torch.Size([1])
predicted: tensor([120], device='cuda:0') torch.Size([1])
labels: tensor([139], device='cuda:0') torch.Size([1])
predicted: tensor([83], device='cuda:0') torch.Size([1])
labels: tensor([83], device='cuda:0') torch.Size([1])
predicted: tensor([83], device='cuda:0') torch.Size([1])
labels: tensor([74], device='cuda:0') torch.Size([1])
predicted: tensor([74], device='




- 10 APs same as above x 500 times and batch size = 2

In [43]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train']):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0) # label.size(0) is the batch size
        correct += (predicted == labels).sum().item() 
        # (predicted == labels).sum() would return how many of them are equal; 
        # .item() would make the tensor to the regular value
        
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/2500 [00:00<?, ?it/s][A
  0%|          | 7/2500 [00:00<00:40, 61.20it/s][A
  1%|          | 14/2500 [00:00<00:45, 54.45it/s][A
  1%|          | 20/2500 [00:00<00:48, 50.80it/s][A
  1%|          | 26/2500 [00:00<00:49, 50.15it/s][A
  1%|▏         | 32/2500 [00:00<00:49, 49.56it/s][A
  2%|▏         | 38/2500 [00:00<00:49, 50.08it/s][A
  2%|▏         | 44/2500 [00:00<00:49, 49.15it/s][A
  2%|▏         | 49/2500 [00:00<00:50, 48.95it/s][A
  2%|▏         | 54/2500 [00:01<00:49, 48.94it/s][A
  2%|▏         | 59/2500 [00:01<00:49, 48.86it/s][A
  3%|▎         | 65/2500 [00:01<00:49, 49.49it/s][A
  3%|▎         | 70/2500 [00:01<00:49, 49.41it/s][A
  3%|▎         | 75/2500 [00:01<00:49, 48.91it/s][A
  3%|▎         | 81/2500 [00:01<00:48, 49.85it/s][A
  3%|▎         | 86/2500 [00:01<00:49, 49.11it/s][A
  4%|▎         | 91/2500 [00:01<00:49, 48.76it/s][A
  4%|▍         | 96/2500 [00:01<00:49, 48.65it/s][A
  4%|▍         | 1

 64%|██████▍   | 1595/2500 [00:32<00:18, 47.97it/s][A
 64%|██████▍   | 1600/2500 [00:32<00:19, 47.16it/s][A
 64%|██████▍   | 1605/2500 [00:32<00:19, 46.89it/s][A
 64%|██████▍   | 1610/2500 [00:33<00:18, 47.28it/s][A
 65%|██████▍   | 1616/2500 [00:33<00:18, 48.52it/s][A
 65%|██████▍   | 1621/2500 [00:33<00:18, 48.36it/s][A
 65%|██████▌   | 1627/2500 [00:33<00:17, 49.17it/s][A
 65%|██████▌   | 1632/2500 [00:33<00:17, 49.20it/s][A
 65%|██████▌   | 1637/2500 [00:33<00:17, 49.40it/s][A
 66%|██████▌   | 1642/2500 [00:33<00:17, 49.37it/s][A
 66%|██████▌   | 1647/2500 [00:33<00:17, 49.55it/s][A
 66%|██████▌   | 1652/2500 [00:33<00:17, 49.03it/s][A
 66%|██████▋   | 1658/2500 [00:34<00:16, 49.86it/s][A
 67%|██████▋   | 1663/2500 [00:34<00:17, 48.38it/s][A
 67%|██████▋   | 1668/2500 [00:34<00:17, 48.44it/s][A
 67%|██████▋   | 1674/2500 [00:34<00:16, 49.10it/s][A
 67%|██████▋   | 1679/2500 [00:34<00:17, 47.97it/s][A
 67%|██████▋   | 1684/2500 [00:34<00:17, 47.89it/s][A
 68%|█████

total count: 2500
Epoch 0 | Train Loss: 2.8673 | Train Accuracy: 0.1054
Validation Loss: 2.2894 | Validation Accuracy: 0.1000



  0%|          | 0/2500 [00:00<?, ?it/s][A
  0%|          | 6/2500 [00:00<00:42, 58.46it/s][A
  0%|          | 12/2500 [00:00<00:46, 53.48it/s][A
  1%|          | 18/2500 [00:00<00:49, 49.68it/s][A
  1%|          | 24/2500 [00:00<00:51, 48.10it/s][A
  1%|          | 29/2500 [00:00<00:52, 47.11it/s][A
  1%|▏         | 34/2500 [00:00<00:53, 46.35it/s][A
  2%|▏         | 39/2500 [00:00<00:52, 46.94it/s][A
  2%|▏         | 45/2500 [00:00<00:50, 48.26it/s][A
  2%|▏         | 50/2500 [00:01<00:50, 48.59it/s][A
  2%|▏         | 56/2500 [00:01<00:49, 49.08it/s][A
  2%|▏         | 61/2500 [00:01<00:50, 47.85it/s][A
  3%|▎         | 66/2500 [00:01<00:50, 48.28it/s][A
  3%|▎         | 71/2500 [00:01<00:49, 48.75it/s][A
  3%|▎         | 76/2500 [00:01<00:50, 47.71it/s][A
  3%|▎         | 81/2500 [00:01<00:51, 46.72it/s][A
  3%|▎         | 86/2500 [00:01<00:51, 46.48it/s][A
  4%|▎         | 91/2500 [00:01<00:50, 47.28it/s][A
  4%|▍         | 96/2500 [00:01<00:50, 47.92it/s][A
  

 63%|██████▎   | 1581/2500 [00:32<00:19, 48.37it/s][A
 63%|██████▎   | 1587/2500 [00:32<00:18, 48.96it/s][A
 64%|██████▎   | 1592/2500 [00:33<00:18, 48.96it/s][A
 64%|██████▍   | 1598/2500 [00:33<00:18, 49.29it/s][A
 64%|██████▍   | 1603/2500 [00:33<00:18, 48.13it/s][A
 64%|██████▍   | 1608/2500 [00:33<00:18, 47.21it/s][A
 65%|██████▍   | 1613/2500 [00:33<00:18, 46.92it/s][A
 65%|██████▍   | 1618/2500 [00:33<00:18, 47.18it/s][A
 65%|██████▍   | 1623/2500 [00:33<00:18, 46.50it/s][A
 65%|██████▌   | 1628/2500 [00:33<00:18, 46.34it/s][A
 65%|██████▌   | 1633/2500 [00:33<00:18, 46.84it/s][A
 66%|██████▌   | 1639/2500 [00:34<00:17, 48.19it/s][A
 66%|██████▌   | 1644/2500 [00:34<00:17, 48.28it/s][A
 66%|██████▌   | 1650/2500 [00:34<00:17, 49.05it/s][A
 66%|██████▌   | 1655/2500 [00:34<00:17, 48.82it/s][A
 66%|██████▋   | 1661/2500 [00:34<00:16, 49.49it/s][A
 67%|██████▋   | 1666/2500 [00:34<00:17, 48.24it/s][A
 67%|██████▋   | 1671/2500 [00:34<00:17, 48.26it/s][A
 67%|█████

total count: 2500
Epoch 1 | Train Loss: 2.2693 | Train Accuracy: 0.1000
Validation Loss: 2.2475 | Validation Accuracy: 0.1000



  0%|          | 0/2500 [00:00<?, ?it/s][A
  0%|          | 6/2500 [00:00<00:45, 54.36it/s][A
  0%|          | 12/2500 [00:00<00:50, 48.96it/s][A
  1%|          | 17/2500 [00:00<00:50, 49.20it/s][A
  1%|          | 22/2500 [00:00<00:50, 49.48it/s][A
  1%|          | 27/2500 [00:00<00:51, 47.88it/s][A
  1%|▏         | 32/2500 [00:00<00:51, 48.33it/s][A
  1%|▏         | 37/2500 [00:00<00:50, 48.80it/s][A
  2%|▏         | 42/2500 [00:00<00:51, 47.59it/s][A
  2%|▏         | 47/2500 [00:00<00:52, 46.86it/s][A
  2%|▏         | 52/2500 [00:01<00:51, 47.52it/s][A
  2%|▏         | 57/2500 [00:01<00:50, 48.13it/s][A
  2%|▏         | 62/2500 [00:01<00:51, 47.41it/s][A
  3%|▎         | 67/2500 [00:01<00:50, 47.99it/s][A
  3%|▎         | 72/2500 [00:01<00:50, 48.50it/s][A
  3%|▎         | 77/2500 [00:01<00:50, 48.42it/s][A
  3%|▎         | 83/2500 [00:01<00:49, 49.10it/s][A
  4%|▎         | 88/2500 [00:01<00:50, 48.05it/s][A
  4%|▎         | 93/2500 [00:01<00:49, 48.37it/s][A
  

 63%|██████▎   | 1578/2500 [00:32<00:18, 49.25it/s][A
 63%|██████▎   | 1584/2500 [00:32<00:18, 49.77it/s][A
 64%|██████▎   | 1589/2500 [00:32<00:18, 48.67it/s][A
 64%|██████▍   | 1594/2500 [00:32<00:18, 48.76it/s][A
 64%|██████▍   | 1599/2500 [00:33<00:18, 48.91it/s][A
 64%|██████▍   | 1605/2500 [00:33<00:18, 49.25it/s][A
 64%|██████▍   | 1610/2500 [00:33<00:18, 49.26it/s][A
 65%|██████▍   | 1616/2500 [00:33<00:17, 49.49it/s][A
 65%|██████▍   | 1621/2500 [00:33<00:17, 49.42it/s][A
 65%|██████▌   | 1627/2500 [00:33<00:17, 49.91it/s][A
 65%|██████▌   | 1632/2500 [00:33<00:17, 49.36it/s][A
 65%|██████▌   | 1637/2500 [00:33<00:17, 48.32it/s][A
 66%|██████▌   | 1642/2500 [00:33<00:17, 48.71it/s][A
 66%|██████▌   | 1648/2500 [00:34<00:17, 49.07it/s][A
 66%|██████▌   | 1653/2500 [00:34<00:17, 49.18it/s][A
 66%|██████▋   | 1658/2500 [00:34<00:17, 49.29it/s][A
 67%|██████▋   | 1663/2500 [00:34<00:17, 49.02it/s][A
 67%|██████▋   | 1669/2500 [00:34<00:16, 49.57it/s][A
 67%|█████

total count: 2500
Epoch 2 | Train Loss: 2.2223 | Train Accuracy: 0.1018
Validation Loss: 2.1883 | Validation Accuracy: 0.1000



  0%|          | 0/2500 [00:00<?, ?it/s][A
  0%|          | 6/2500 [00:00<00:46, 53.49it/s][A
  0%|          | 12/2500 [00:00<00:51, 48.69it/s][A
  1%|          | 17/2500 [00:00<00:53, 46.71it/s][A
  1%|          | 22/2500 [00:00<00:53, 46.28it/s][A
  1%|          | 27/2500 [00:00<00:53, 46.07it/s][A
  1%|▏         | 32/2500 [00:00<00:53, 45.78it/s][A
  1%|▏         | 37/2500 [00:00<00:54, 45.29it/s][A
  2%|▏         | 42/2500 [00:00<00:54, 45.31it/s][A
  2%|▏         | 47/2500 [00:01<00:53, 45.50it/s][A
  2%|▏         | 52/2500 [00:01<00:53, 45.37it/s][A
  2%|▏         | 57/2500 [00:01<00:54, 45.19it/s][A
  2%|▏         | 62/2500 [00:01<00:54, 45.04it/s][A
  3%|▎         | 67/2500 [00:01<00:53, 45.20it/s][A
  3%|▎         | 72/2500 [00:01<00:53, 45.35it/s][A
  3%|▎         | 77/2500 [00:01<00:53, 45.02it/s][A
  3%|▎         | 82/2500 [00:01<00:53, 44.79it/s][A
  3%|▎         | 87/2500 [00:01<00:52, 45.63it/s][A
  4%|▎         | 92/2500 [00:02<00:53, 45.31it/s][A
  

 61%|██████    | 1518/2500 [00:33<00:21, 45.12it/s][A
 61%|██████    | 1523/2500 [00:33<00:21, 45.04it/s][A
 61%|██████    | 1528/2500 [00:33<00:21, 44.51it/s][A
 61%|██████▏   | 1533/2500 [00:33<00:21, 45.50it/s][A
 62%|██████▏   | 1539/2500 [00:33<00:20, 47.04it/s][A
 62%|██████▏   | 1544/2500 [00:34<00:20, 45.98it/s][A
 62%|██████▏   | 1549/2500 [00:34<00:20, 45.80it/s][A
 62%|██████▏   | 1554/2500 [00:34<00:20, 45.62it/s][A
 62%|██████▏   | 1559/2500 [00:34<00:20, 45.80it/s][A
 63%|██████▎   | 1564/2500 [00:34<00:20, 45.57it/s][A
 63%|██████▎   | 1569/2500 [00:34<00:20, 45.45it/s][A
 63%|██████▎   | 1574/2500 [00:34<00:20, 45.08it/s][A
 63%|██████▎   | 1579/2500 [00:34<00:20, 45.10it/s][A
 63%|██████▎   | 1584/2500 [00:34<00:20, 45.11it/s][A
 64%|██████▎   | 1589/2500 [00:35<00:20, 45.37it/s][A
 64%|██████▍   | 1594/2500 [00:35<00:20, 44.80it/s][A
 64%|██████▍   | 1599/2500 [00:35<00:20, 44.35it/s][A
 64%|██████▍   | 1604/2500 [00:35<00:20, 44.19it/s][A
 64%|█████

total count: 2500
Epoch 3 | Train Loss: 2.1429 | Train Accuracy: 0.2272
Validation Loss: 2.0846 | Validation Accuracy: 0.4000



  0%|          | 0/2500 [00:00<?, ?it/s][A
  0%|          | 7/2500 [00:00<00:41, 60.51it/s][A
  1%|          | 14/2500 [00:00<00:46, 53.73it/s][A
  1%|          | 20/2500 [00:00<00:47, 51.83it/s][A
  1%|          | 26/2500 [00:00<00:48, 50.93it/s][A
  1%|▏         | 32/2500 [00:00<00:47, 52.27it/s][A
  2%|▏         | 38/2500 [00:00<00:48, 51.09it/s][A
  2%|▏         | 44/2500 [00:00<00:48, 50.67it/s][A
  2%|▏         | 50/2500 [00:00<00:48, 50.43it/s][A
  2%|▏         | 56/2500 [00:01<00:48, 50.01it/s][A
  2%|▏         | 62/2500 [00:01<00:47, 51.35it/s][A
  3%|▎         | 68/2500 [00:01<00:47, 50.86it/s][A
  3%|▎         | 74/2500 [00:01<00:47, 50.56it/s][A
  3%|▎         | 80/2500 [00:01<00:48, 49.66it/s][A
  3%|▎         | 85/2500 [00:01<00:49, 48.76it/s][A
  4%|▎         | 90/2500 [00:01<00:49, 48.83it/s][A
  4%|▍         | 96/2500 [00:01<00:48, 49.43it/s][A
  4%|▍         | 102/2500 [00:02<00:46, 51.09it/s][A
  4%|▍         | 108/2500 [00:02<00:47, 50.50it/s][A


 68%|██████▊   | 1699/2500 [00:34<00:15, 50.79it/s][A
 68%|██████▊   | 1705/2500 [00:34<00:15, 50.45it/s][A
 68%|██████▊   | 1711/2500 [00:34<00:15, 50.22it/s][A
 69%|██████▊   | 1717/2500 [00:34<00:15, 49.49it/s][A
 69%|██████▉   | 1723/2500 [00:34<00:15, 51.01it/s][A
 69%|██████▉   | 1730/2500 [00:34<00:14, 53.66it/s][A
 69%|██████▉   | 1736/2500 [00:34<00:14, 52.33it/s][A
 70%|██████▉   | 1742/2500 [00:34<00:14, 51.58it/s][A
 70%|██████▉   | 1748/2500 [00:35<00:15, 49.59it/s][A
 70%|███████   | 1753/2500 [00:35<00:15, 49.26it/s][A
 70%|███████   | 1758/2500 [00:35<00:15, 48.66it/s][A
 71%|███████   | 1763/2500 [00:35<00:15, 48.85it/s][A
 71%|███████   | 1768/2500 [00:35<00:14, 48.84it/s][A
 71%|███████   | 1774/2500 [00:35<00:14, 49.53it/s][A
 71%|███████   | 1779/2500 [00:35<00:14, 49.51it/s][A
 71%|███████▏  | 1784/2500 [00:35<00:14, 48.71it/s][A
 72%|███████▏  | 1790/2500 [00:35<00:14, 49.66it/s][A
 72%|███████▏  | 1796/2500 [00:36<00:14, 49.83it/s][A
 72%|█████

total count: 2500
Epoch 4 | Train Loss: 2.0096 | Train Accuracy: 0.4924
Validation Loss: 1.9196 | Validation Accuracy: 0.6000
labels: tensor([118, 121], device='cuda:0') torch.Size([2])
predicted: tensor([118, 121], device='cuda:0') torch.Size([2])
labels: tensor([128, 122], device='cuda:0') torch.Size([2])
predicted: tensor([128, 122], device='cuda:0') torch.Size([2])
labels: tensor([120, 139], device='cuda:0') torch.Size([2])
predicted: tensor([120, 120], device='cuda:0') torch.Size([2])
labels: tensor([83, 74], device='cuda:0') torch.Size([2])
predicted: tensor([120,  74], device='cuda:0') torch.Size([2])
labels: tensor([57,  0], device='cuda:0') torch.Size([2])
predicted: tensor([120, 120], device='cuda:0') torch.Size([2])
Test Accuracy: 60 %


- 10 APs same as above x 500 times and batch size = 4

In [46]:
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 5


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train'], position=0, leave=True):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0) # label.size(0) is the batch size
        correct += (predicted == labels).sum().item() 
        # (predicted == labels).sum() would return how many of them are equal; 
        # .item() would make the tensor to the regular value
        
    print('Test Accuracy: %d %%' % (100 * correct / total))

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/1250 [00:00<?, ?it/s][A
  0%|          | 6/1250 [00:00<00:21, 58.03it/s][A
  1%|          | 12/1250 [00:00<00:22, 54.19it/s][A
  1%|▏         | 18/1250 [00:00<00:23, 51.52it/s][A
  2%|▏         | 24/1250 [00:00<00:23, 51.17it/s][A
  2%|▏         | 30/1250 [00:00<00:24, 50.20it/s][A
  3%|▎         | 36/1250 [00:00<00:24, 49.51it/s][A
  3%|▎         | 42/1250 [00:00<00:23, 50.54it/s][A
  4%|▍         | 48/1250 [00:00<00:24, 49.70it/s][A
  4%|▍         | 54/1250 [00:01<00:23, 50.18it/s][A
  5%|▍         | 60/1250 [00:01<00:23, 49.81it/s][A
  5%|▌         | 66/1250 [00:01<00:23, 49.91it/s][A
  6%|▌         | 71/1250 [00:01<00:23, 49.58it/s][A
  6%|▌         | 77/1250 [00:01<00:23, 49.87it/s][A
  7%|▋         | 82/1250 [00:01<00:23, 49.39it/s][A
  7%|▋         | 88/1250 [00:01<00:23, 50.04it/s][A
  7%|▋         | 93/1250 [00:01<00:23, 49.29it/s][A
  8%|▊         | 99/1250 [00:01<00:22, 50.14it/s][A
  8%|▊         | 1

total count: 1250
Epoch 0 | Train Loss: 3.3005 | Train Accuracy: 0.1106
Validation Loss: 2.3991 | Validation Accuracy: 0.0833



  0%|          | 0/1250 [00:00<?, ?it/s][A
  0%|          | 6/1250 [00:00<00:23, 53.08it/s][A
  1%|          | 12/1250 [00:00<00:24, 50.57it/s][A
  1%|▏         | 18/1250 [00:00<00:24, 50.72it/s][A
  2%|▏         | 24/1250 [00:00<00:24, 49.31it/s][A
  2%|▏         | 30/1250 [00:00<00:24, 50.60it/s][A
  3%|▎         | 36/1250 [00:00<00:25, 47.63it/s][A
  3%|▎         | 41/1250 [00:00<00:25, 48.17it/s][A
  4%|▎         | 46/1250 [00:00<00:24, 48.28it/s][A
  4%|▍         | 51/1250 [00:01<00:25, 47.53it/s][A
  4%|▍         | 56/1250 [00:01<00:25, 47.67it/s][A
  5%|▍         | 62/1250 [00:01<00:24, 48.68it/s][A
  5%|▌         | 67/1250 [00:01<00:24, 48.31it/s][A
  6%|▌         | 73/1250 [00:01<00:23, 49.26it/s][A
  6%|▌         | 78/1250 [00:01<00:23, 48.88it/s][A
  7%|▋         | 84/1250 [00:01<00:23, 49.67it/s][A
  7%|▋         | 89/1250 [00:01<00:23, 49.48it/s][A
  8%|▊         | 95/1250 [00:01<00:23, 49.93it/s][A
  8%|▊         | 100/1250 [00:02<00:23, 49.29it/s][A
 

total count: 1250
Epoch 1 | Train Loss: 2.3053 | Train Accuracy: 0.1000
Validation Loss: 2.3156 | Validation Accuracy: 0.0833



  0%|          | 0/1250 [00:00<?, ?it/s][A
  0%|          | 6/1250 [00:00<00:22, 55.25it/s][A
  1%|          | 12/1250 [00:00<00:24, 49.81it/s][A
  1%|▏         | 18/1250 [00:00<00:24, 51.28it/s][A
  2%|▏         | 24/1250 [00:00<00:24, 49.99it/s][A
  2%|▏         | 30/1250 [00:00<00:24, 50.47it/s][A
  3%|▎         | 36/1250 [00:00<00:24, 49.75it/s][A
  3%|▎         | 42/1250 [00:00<00:24, 50.16it/s][A
  4%|▍         | 48/1250 [00:00<00:24, 49.50it/s][A
  4%|▍         | 54/1250 [00:01<00:23, 50.00it/s][A
  5%|▍         | 60/1250 [00:01<00:24, 49.57it/s][A
  5%|▌         | 66/1250 [00:01<00:23, 50.00it/s][A
  6%|▌         | 72/1250 [00:01<00:23, 49.40it/s][A
  6%|▌         | 77/1250 [00:01<00:23, 49.40it/s][A
  7%|▋         | 83/1250 [00:01<00:22, 50.75it/s][A
  7%|▋         | 89/1250 [00:01<00:22, 50.98it/s][A
  8%|▊         | 95/1250 [00:01<00:22, 50.22it/s][A
  8%|▊         | 101/1250 [00:02<00:22, 50.44it/s][A
  9%|▊         | 107/1250 [00:02<00:22, 50.03it/s][A


total count: 1250
Epoch 2 | Train Loss: 2.2643 | Train Accuracy: 0.1000
Validation Loss: 2.2837 | Validation Accuracy: 0.0833



  0%|          | 0/1250 [00:00<?, ?it/s][A
  0%|          | 6/1250 [00:00<00:24, 51.28it/s][A
  1%|          | 12/1250 [00:00<00:24, 49.62it/s][A
  1%|▏         | 18/1250 [00:00<00:24, 50.56it/s][A
  2%|▏         | 24/1250 [00:00<00:24, 49.08it/s][A
  2%|▏         | 29/1250 [00:00<00:25, 48.53it/s][A
  3%|▎         | 34/1250 [00:00<00:25, 48.09it/s][A
  3%|▎         | 40/1250 [00:00<00:24, 49.37it/s][A
  4%|▎         | 45/1250 [00:00<00:24, 49.08it/s][A
  4%|▍         | 51/1250 [00:01<00:24, 49.57it/s][A
  4%|▍         | 56/1250 [00:01<00:24, 49.38it/s][A
  5%|▍         | 61/1250 [00:01<00:24, 49.54it/s][A
  5%|▌         | 67/1250 [00:01<00:23, 49.72it/s][A
  6%|▌         | 72/1250 [00:01<00:23, 49.33it/s][A
  6%|▌         | 78/1250 [00:01<00:23, 49.87it/s][A
  7%|▋         | 83/1250 [00:01<00:23, 49.28it/s][A
  7%|▋         | 89/1250 [00:01<00:23, 49.96it/s][A
  8%|▊         | 94/1250 [00:01<00:23, 49.16it/s][A
  8%|▊         | 100/1250 [00:02<00:22, 50.04it/s][A
 

total count: 1250
Epoch 3 | Train Loss: 2.2305 | Train Accuracy: 0.1004
Validation Loss: 2.2466 | Validation Accuracy: 0.0833



  0%|          | 0/1250 [00:00<?, ?it/s][A
  0%|          | 6/1250 [00:00<00:23, 51.87it/s][A
  1%|          | 12/1250 [00:00<00:26, 47.43it/s][A
  1%|▏         | 18/1250 [00:00<00:25, 49.25it/s][A
  2%|▏         | 23/1250 [00:00<00:27, 45.22it/s][A
  2%|▏         | 28/1250 [00:00<00:26, 45.86it/s][A
  3%|▎         | 33/1250 [00:00<00:27, 44.24it/s][A
  3%|▎         | 39/1250 [00:00<00:26, 46.17it/s][A
  4%|▎         | 44/1250 [00:00<00:25, 46.62it/s][A
  4%|▍         | 49/1250 [00:01<00:25, 47.56it/s][A
  4%|▍         | 55/1250 [00:01<00:23, 49.93it/s][A
  5%|▍         | 61/1250 [00:01<00:23, 49.92it/s][A
  5%|▌         | 67/1250 [00:01<00:23, 50.14it/s][A
  6%|▌         | 73/1250 [00:01<00:23, 49.25it/s][A
  6%|▌         | 78/1250 [00:01<00:24, 48.77it/s][A
  7%|▋         | 84/1250 [00:01<00:23, 49.29it/s][A
  7%|▋         | 89/1250 [00:01<00:23, 49.04it/s][A
  8%|▊         | 94/1250 [00:01<00:23, 48.99it/s][A
  8%|▊         | 99/1250 [00:02<00:23, 48.54it/s][A
  

total count: 1250
Epoch 4 | Train Loss: 2.1824 | Train Accuracy: 0.1308
Validation Loss: 2.1909 | Validation Accuracy: 0.1667
labels: tensor([118, 121, 128, 122], device='cuda:0') torch.Size([4])
predicted: tensor([ 74,  74,  74, 122], device='cuda:0') torch.Size([4])
labels: tensor([120, 139,  83,  74], device='cuda:0') torch.Size([4])
predicted: tensor([74, 74, 74, 74], device='cuda:0') torch.Size([4])
labels: tensor([57,  0], device='cuda:0') torch.Size([2])
predicted: tensor([74, 74], device='cuda:0') torch.Size([2])
Test Accuracy: 20 %





### Model

In [8]:
class GAT(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout_prob=0.25):
        super(GAT, self).__init__()
        
        # do not check the zero in_degree since we have all the complete graph
        self.layer1 = GATConv(in_dim, hidden_dim, num_heads=num_heads, activation=F.relu, allow_zero_in_degree=True)
        self.layer2 = GATConv(hidden_dim * num_heads, hidden_dim, num_heads=num_heads, allow_zero_in_degree=True)
        self.layer3 = GATConv(hidden_dim * num_heads, out_dim, num_heads=num_heads, allow_zero_in_degree=True)
         
        # Adding Batch Normalization after each GAT layer
        self.batchnorm1 = nn.BatchNorm1d(hidden_dim * num_heads)
        self.batchnorm2 = nn.BatchNorm1d(hidden_dim * num_heads)
#         self.batchnorm3 = nn.BatchNorm1d(out_dim) # there's no need to use BN3
        
        # Adding Dropout for regularization
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, g, h):
        # Layer 1
        h1 = self.layer1(g, h)
        h1 = h1.view(h1.shape[0], -1)
#         h1 = self.batchnorm1(h1)
        h1 = F.relu(h1)
#         h1 = self.dropout(h1)
        
        # Layer 2
        h2 = self.layer2(g, h1)
        h2 = h2.view(h2.shape[0], -1)
#         h2 = self.batchnorm2(h2)
        h2 = F.relu(h2)
#         h2 = self.dropout(h2)

        # Layer 3
        h3 = self.layer3(g, h2).squeeze(1)
#         h3 = self.dropout(h3)
        
        '''
        問題出現在 h3 = self.layer3(g, h2).squeeze(1)。
        在這裡，你應該得到一個形狀為 [N, num_heads, out_dim] 的tensor，但你使用了 squeeze(1)，
        如果 num_heads 是 1，你會得到 [N, out_dim]，這樣是沒問題的。
        但如果 num_heads 不是 1，那麼squeeze操作不會更改tensor的形狀，結果仍然是 [N, num_heads, out_dim]。
        因此，對這個tensor使用 batch normalization 會導致維度不匹配。
        '''
        # output layer so not need the BN
        # 不使用BN: GAT本身已經有注意力機制，所以BN不一定是必需的，尤其是在輸出層。
        # h3 = self.batchnorm3(h3)
        


        # Aggregate
        g.ndata['h_out'] = h3
        h_agg = dgl.mean_nodes(g, feat='h_out')
        return h_agg

    

In [9]:
def model_fn(data, model, criterion, device, count=1):
    """Forward a batch through the model."""
    batched_g, labels = data
#     print(batch_g)
    batched_g = batched_g.to(device)
    
    labels = labels.to(device)
    logits = model(batched_g, batched_g.ndata['feat'].float()) # for GAT
    logits = logits.mean(dim=1)
#     print(logits)
    
    loss = criterion(logits, labels)
#     print(batched_g.ndata['feat'].dtype)
#     print("Logits shape:", logits.shape)  # Expected: (batch_size, 168)
#     print("Labels shape:", labels.shape)  # Expected: (batch_size)

    # Get the class id with the highest probability.
    preds = logits.argmax(1)
    
    # Compute accuracy.
    accuracy = torch.mean((preds == labels).float())

#     return loss, accuracy
    return loss, accuracy, preds

- 22 APs same as above x 5000 times and batch size = 4

In [14]:
seed = 8787
same_seeds(seed)

model = GAT(in_dim=50, hidden_dim=16, out_dim=168, num_heads=8)
# in_dim means the dimension of the node_feat(50 dim, since the 50-dim embedding)
# out_dim means the # of the categories -> 168 for out tasks
model.load_state_dict(torch.load('model_initial/initial_weight.pth'))

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

criterion = nn.CrossEntropyLoss()
total_steps = 18


# Training Part
for epoch in tqdm(range(total_steps)):
    # Train
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    
    count = 0 
    
    for data in tqdm(dataloaders['train'], position=0, leave=True):
        
        count += 1
        loss, accuracy, _ = model_fn(data, model, criterion, device, count)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
        
#     scheduler.step()
    print(f"total count: {count}")
    
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches

    print(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')

    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batched_g in dataloaders['valid']:
            loss, accuracy, _ = model_fn(batched_g, model, criterion, device)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1

    avg_accuracy = total_accuracy / num_batches
    avg_loss = total_loss / num_batches
    print(f'Validation Loss: {avg_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}')


    # Save checkpoint
    if epoch%20 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, f"../checkpoint_GAT/checkpoint_{epoch}.pt")
    

# Testing Part
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in dataloaders['test']:
        loss, accuracy, predicted = model_fn(data, model, criterion, device)
        labels = data[1].to(device)  # Assuming labels are the second element in the tuple
        
        print(f"labels: {labels}", labels.shape)
        print(f"predicted: {predicted}", predicted.shape)
        
        total += labels.size(0) # label.size(0) is the batch size
        correct += (predicted == labels).sum().item() 
        # (predicted == labels).sum() would return how many of them are equal; 
        # .item() would make the tensor to the regular value
        
    print('Test Accuracy: %d %%' % (100 * correct / total))

100%|██████████| 27500/27500 [10:27<00:00, 43.85it/s]
 20%|██        | 1/5 [10:27<41:48, 627.23s/it]

total count: 27500
Epoch 0 | Train Loss: 3.5940 | Train Accuracy: 0.0259
Validation Loss: 3.1195 | Validation Accuracy: 0.0417


100%|██████████| 27500/27500 [10:29<00:00, 43.66it/s]
 40%|████      | 2/5 [20:57<31:26, 628.78s/it]

total count: 27500
Epoch 1 | Train Loss: 3.0977 | Train Accuracy: 0.0072
Validation Loss: 3.0915 | Validation Accuracy: 0.0417


  3%|▎         | 943/27500 [00:23<10:59, 40.25it/s]
 40%|████      | 2/5 [21:20<32:00, 640.26s/it]


KeyboardInterrupt: 