In [1]:
import torch_geometric

In [2]:
import matplotlib.pyplot as plt
import matplotlib as mpl
%matplotlib inline
# Global plotting defaults for every figure in this notebook.
mpl.rcParams.update({
    'figure.figsize': [8, 6],
    'font.size': 16,
    'axes.grid': True,
})

import torch
# Use the 'spawn' start method for any worker processes (PyTorch's CUDA runtime
# does not support 'fork'). force=True makes this cell safe to re-run: without
# it, a second execution raises RuntimeError because the context is already set.
torch.multiprocessing.set_start_method('spawn', force=True)

# Prefer the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device.type)
# Only query the GPU name when a GPU was actually selected. Calling
# get_device_name on a CPU-only machine raises instead of falling back.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(device))

import numpy as np
# One seed for every RNG the notebook draws from.
SEED = 12345
np.random.seed(SEED)         # numpy's legacy global RNG
_ = torch.manual_seed(SEED)  # torch default generator on all devices; returns a torch.Generator

from sklearn.manifold import TSNE

cuda
NVIDIA A100 80GB PCIe MIG 2g.20gb


In [3]:
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load files
# you produced yourself or otherwise trust. On PyTorch >= 2.6 the default
# weights_only=True may refuse pickled graph objects; confirm which torch
# version this notebook targets.
train_graphs = torch.load("train_graphs_2.pt")

# Attach a node-feature matrix to every graph: `pos` with `value` appended as
# one extra column. It must end up with 3 columns to match GCNConv(3, ...) in
# the model, so presumably `pos` is 2-D and `value` is 1-D (TODO confirm).
for graph in train_graphs:
    graph['feature'] = torch.cat([graph['pos'], graph['value'][:, None]], dim=1)

In [4]:
from torch_geometric.loader import DataLoader

# Mini-batches of 64 graphs, reshuffled every epoch. PyG collates each batch into
# a single disjoint graph plus a `batch` vector mapping nodes to their graph.
train_loader = DataLoader(dataset=train_graphs, batch_size=64, shuffle=True)

In [5]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

class GCN(torch.nn.Module):
    """Three-layer GCN graph classifier with mean-pool readout.

    Node features pass through three ``GCNConv`` layers (ReLU after the first
    two). They are averaged per graph with ``global_mean_pool`` and mapped to
    class logits by a linear layer. Dropout is applied before that layer in
    training mode only.

    Args:
        hidden_channels: Width of every hidden GCN layer.
        in_channels: Number of input node features (default 3).
        num_classes: Number of output logits per graph (default 4).
        dropout: Dropout probability applied to the pooled graph embedding.
        seed: If not None, ``torch.manual_seed(seed)`` is called before the
            layers are created, so weight initialisation is reproducible.
            This reseeds torch's *global* RNG, which also affects later
            shuffling and dropout. Pass ``None`` to leave the global RNG
            untouched. The default (12345) reproduces the original behaviour.
    """

    def __init__(self, hidden_channels, in_channels=3, num_classes=4,
                 dropout=0.5, seed=12345):
        super().__init__()
        if seed is not None:
            # Must run before any layer is constructed so their weight init is seeded.
            torch.manual_seed(seed)
        self.dropout = dropout
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return per-graph logits of shape [num_graphs, num_classes].

        Args:
            x: Node feature matrix, [num_nodes, in_channels].
            edge_index: Graph connectivity (COO), [2, num_edges].
            batch: Graph index of each node, [num_nodes].
        """
        # 1. Node embeddings via message passing. There is no ReLU after conv3:
        #    the raw embedding goes straight into the readout.
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)

        # 2. Readout: average node embeddings per graph -> [num_graphs, hidden_channels].
        x = global_mean_pool(x, batch)

        # 3. Classifier head. Dropout is active only while self.training is True.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

# Build a throwaway instance purely to inspect the layer stack.
model = GCN(40)
print(f"{model}")

GCN(
  (conv1): GCNConv(3, 40)
  (conv2): GCNConv(40, 40)
  (conv3): GCNConv(40, 40)
  (lin): Linear(in_features=40, out_features=4, bias=True)
)


In [8]:
# Fresh model (GCN's constructor reseeds torch, so initialisation is reproducible),
# moved to the selected device. Module.to returns the module itself.
model = GCN(hidden_channels=40).to(device)
# Adam at lr=1e-2, with cross-entropy over the per-graph class logits.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)
criterion = torch.nn.CrossEntropyLoss()

def train():
    """Run one epoch of optimisation over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. Returns the mean training loss per graph for this epoch
    (``nan`` if the loader yields no batches) so callers can log it. Callers
    that ignore the return value behave exactly as before.
    """
    model.train()  # enable dropout in GCN.forward

    total_loss = 0.0
    num_graphs = 0
    for data in train_loader:  # Iterate in batches over the training dataset.
        data = data.to(device)
        # Clear gradients *before* the backward pass so stale gradients (e.g.
        # from an interrupted earlier run) are never folded into this step.
        optimizer.zero_grad()
        out = model(data.feature, data.edge_index, data.batch)  # Forward pass.
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

        # criterion uses the default 'mean' reduction over the batch. Re-weight
        # by batch size so the epoch average is per graph, not per batch.
        total_loss += loss.item() * data.num_graphs
        num_graphs += data.num_graphs
    return total_loss / num_graphs if num_graphs else float('nan')

@torch.no_grad()
def test(loader):
    """Return the classification accuracy of ``model`` over every graph in ``loader``.

    Runs in eval mode (dropout off) under ``torch.no_grad()``, so no autograd
    graph is built during evaluation. Returns ``correct / len(loader.dataset)``.
    An empty dataset yields ``nan`` instead of raising ZeroDivisionError.
    """
    model.eval()

    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        data = data.to(device)
        out = model(data.feature, data.edge_index, data.batch)
        pred = out.argmax(dim=1)  # Highest-scoring class per graph.
        correct += int((pred == data.y).sum())
    num_graphs = len(loader.dataset)
    return correct / num_graphs if num_graphs else float('nan')


In [9]:
NUM_EPOCHS = 170

for epoch in range(1, NUM_EPOCHS + 1):
    train()
    # Accuracy on the training set itself (eval mode). There is no held-out
    # split yet. TODO: add a test_loader and report test accuracy alongside.
    train_acc = test(train_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}')

Epoch: 001
Epoch: 001, Train Acc: 0.4453
Epoch: 002
Epoch: 002, Train Acc: 0.5529
Epoch: 003
Epoch: 003, Train Acc: 0.6562
Epoch: 004
Epoch: 004, Train Acc: 0.6642
Epoch: 005
Epoch: 005, Train Acc: 0.7394
Epoch: 006
Epoch: 006, Train Acc: 0.7295
Epoch: 007
Epoch: 007, Train Acc: 0.7578
Epoch: 008
Epoch: 008, Train Acc: 0.7124
Epoch: 009
Epoch: 009, Train Acc: 0.7615
Epoch: 010
Epoch: 010, Train Acc: 0.7643
Epoch: 011
Epoch: 011, Train Acc: 0.7395
Epoch: 012
Epoch: 012, Train Acc: 0.7732
Epoch: 013
Epoch: 013, Train Acc: 0.7736
Epoch: 014
Epoch: 014, Train Acc: 0.7777
Epoch: 015
Epoch: 015, Train Acc: 0.7552
Epoch: 016
Epoch: 016, Train Acc: 0.7841
Epoch: 017
Epoch: 017, Train Acc: 0.7723
Epoch: 018
Epoch: 018, Train Acc: 0.7227
Epoch: 019
Epoch: 019, Train Acc: 0.7452
Epoch: 020
Epoch: 020, Train Acc: 0.7886
Epoch: 021
Epoch: 021, Train Acc: 0.7695
Epoch: 022
Epoch: 022, Train Acc: 0.7313
Epoch: 023
Epoch: 023, Train Acc: 0.7720
Epoch: 024
Epoch: 024, Train Acc: 0.7680
Epoch: 025
Epoch