In [1]:
import os

# If the kernel was started in a directory whose path ends with "notebooks",
# move one level up so relative paths resolve against the project root.
if os.getcwd().endswith("notebooks"):
    os.chdir(os.pardir)
    print("using project root as working dir")

using project root as working dir


In [7]:
import math
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple

import networkx as nx
import numpy as np
from tqdm.notebook import tqdm


@dataclass
class Args:
    """Run configuration for graph generation, dataset building and training.

    Every attribute carries a type annotation so that ``@dataclass`` actually
    registers it as a field. Without annotations they are plain class
    attributes, and ``Args(batch_size=32)`` raises ``TypeError``.
    """

    # Seed applied to the RNGs before building data/model; None = unseeded.
    random_seed: Optional[int] = None
    # torch
    batch_size: int = 64
    epochs: int = 30
    layers: int = 10
    layer_size: int = 16  # width of the hidden layers
    train_size: float = 0.7  # fraction of pairs used for training
    wandb: bool = False
    # graph
    graph_size: int = 1000  # number of nodes
    graph_shape: str = 'disc'
    rg_radius: float = 0.05  # random-geometric-graph connection radius
    # dataset manipulation
    ds_padded: bool = True

In [9]:
# A node is its (x, y) position.
Node = Tuple[float, float]
# Node positions; a node's index in the list is its id in the graph.
Nodes = List[Node]
# Pairs of node indices (i, j).
NodeIndexPairs = List[Tuple[int, int]]

def gen_nodes(args: Args) -> Nodes:
    """Generate ``args.graph_size`` random node positions for ``args.graph_shape``.

    Args:
        args: run configuration; reads ``graph_shape`` and ``graph_size``.

    Returns:
        List of ``(x, y)`` node positions.

    Raises:
        ValueError: if ``args.graph_shape`` is not a supported shape.
    """
    if args.graph_shape == 'disc':
        return __gen_nodes_disc(args.graph_size)
    # Must raise an exception instance: `raise "<str>"` itself fails with a
    # TypeError and the message below would never reach the user.
    raise ValueError(f'unsupported node shape: {args.graph_shape}')


def __gen_nodes_disc(amount: int) -> Nodes:
    """Rejection-sample ``amount`` points from the disc of radius 0.5 centred at (0.5, 0.5).

    Candidates are drawn with ``random.uniform(0, 1)`` on each axis. A
    candidate whose distance from the centre exceeds 0.5 is discarded and
    redrawn, so exactly ``amount`` points are returned.
    """
    accepted = []
    progress = tqdm(total=amount, desc="generating random-uniform nodes on disc")
    with progress:
        while len(accepted) < amount:
            candidate = (random.uniform(0, 1), random.uniform(0, 1))
            dx = candidate[0] - 0.5
            dy = candidate[1] - 0.5
            # points exactly on the rim (distance == 0.5) are kept
            if math.sqrt(dx * dx + dy * dy) <= 0.5:
                accepted.append(candidate)
                progress.update(1)
    return accepted


def get_node_pairs(n_nodes: int) -> NodeIndexPairs:
    """Return every index pair ``(i, j)`` with ``0 <= i < j < n_nodes``.

    Pairs are in row-major upper-triangle order, with
    ``n_nodes * (n_nodes - 1) // 2`` entries for ``n_nodes >= 1``.
    """
    pairs = []
    for first in tqdm(range(n_nodes), desc="generating node pairs"):
        pairs.extend((first, second) for second in range(first + 1, n_nodes))
    return pairs


# https://stackoverflow.com/a/36460020/10619052
def list_to_dict(items: list) -> dict:
    """Map each element's position in ``items`` to the element (``{index: item}``)."""
    return dict(enumerate(tqdm(items, desc="creating dict from list")))

In [10]:
import torch
from torch import nn
from torch.utils.data import Dataset, TensorDataset, DataLoader


# Choose the training device: CUDA if available, otherwise Apple MPS, otherwise CPU.
# (To force CPU, assign device = "cpu" after this block.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"using {device} device")

using cuda device


In [14]:
# Define dataset
class GraphDataset:
    """Random geometric graph plus a pairwise edge-classification dataset.

    Every node pair ``(i, j)`` with ``i < j`` becomes one sample. The
    features are ``[x_i, y_i, x_j, y_j]``; the label is 1 if the graph has an
    edge between ``i`` and ``j`` and 0 otherwise.

    Attributes:
        nodes: node positions produced by ``gen_nodes(args)``.
        n_nodes: number of nodes.
        graph: ``nx.random_geometric_graph`` over ``nodes`` with radius ``args.rg_radius``.
        node_index_pairs: all ``(i, j)`` pairs, in ``get_node_pairs`` order.
        dataset: ``TensorDataset(values, labels)``, row-aligned with ``node_index_pairs``.
    """

    def __init__(self, args: Args):
        # generate graph
        self.nodes = gen_nodes(args)
        self.n_nodes = len(self.nodes)
        self.graph = nx.random_geometric_graph(
            self.n_nodes,
            args.rg_radius,
            pos=list_to_dict(self.nodes),
        )
        self.node_index_pairs = get_node_pairs(self.n_nodes)

        # generate dataset
        # Gather both endpoints' coordinates by tensor indexing instead of a
        # Python loop over ~n²/2 pairs. The values are identical (same
        # float -> default-dtype conversion). reshape keeps the (N, 2)/(N, 4)
        # shapes even when there are no nodes or pairs.
        positions = torch.tensor(self.nodes).reshape(-1, 2)
        pair_index = torch.tensor(self.node_index_pairs, dtype=torch.long).reshape(-1, 2)
        ds_values = torch.cat(
            (positions[pair_index[:, 0]], positions[pair_index[:, 1]]),
            dim=1,
        )
        ds_labels = torch.tensor(
            [
                int(self.graph.has_edge(i0, i1))
                for (i0, i1) in tqdm(self.node_index_pairs, desc="generating dataset labels from node pairs")
            ],
            dtype=torch.long,
        )
        self.dataset = TensorDataset(ds_values, ds_labels)

In [28]:
# Define model
class NeuralNetwork(nn.Module):
    """MLP classifier: ``in_features`` -> two ReLU hidden layers -> ``n_classes`` logits.

    Args:
        layer_size: width of both hidden layers. When omitted it falls back to
            the notebook-global ``args.layer_size``, so that plain
            ``NeuralNetwork()`` behaves as before.
        in_features: input width (default 4 = two 2-D node positions).
        n_classes: number of output logits (default 2).
    """

    def __init__(self, layer_size: Optional[int] = None, in_features: int = 4, n_classes: int = 2):
        super().__init__()
        if layer_size is None:
            # Backward-compatible fallback: the original read the global `args`.
            # Prefer passing layer_size explicitly to avoid hidden notebook state.
            layer_size = args.layer_size
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, layer_size),
            nn.ReLU(),
            nn.Linear(layer_size, layer_size),
            nn.ReLU(),
            nn.Linear(layer_size, n_classes),
        )

    def forward(self, x):
        """Return unnormalised class logits for a batch of feature rows."""
        return self.linear_relu_stack(x)


def train(dataloader: DataLoader, model: nn.Module, loss_fn, optimizer, device=None):
    """Run one training epoch of ``model`` over ``dataloader``.

    Args:
        dataloader: yields ``(X, y)`` batches.
        model: module to optimise; it is switched to train mode.
        loss_fn: called as ``loss_fn(model(X), y)``; must return a scalar tensor.
        optimizer: stepped after every backward pass (presumably built over
            ``model.parameters()`` — confirm at the call site).
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU for a parameterless model), instead of relying
            on a notebook-global ``device``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    n_batches = len(dataloader)
    model.train()
    # Context manager so the progress bar is closed even if a batch raises.
    with tqdm(total=n_batches, desc="starting train...") as pbar:
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Refresh every 100 batches and on the last one. `batch` is 0-based,
            # so batch + 1 batches are complete; this lets the bar reach 100%.
            if batch % 100 == 0 or batch == n_batches - 1:
                pbar.update(batch + 1 - pbar.n)
                pbar.set_postfix_str(f"loss: {loss.item():>7f}")
                pbar.set_description(f"batch: {batch}")


def test(dataloader: DataLoader, model: nn.Module, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [35]:
args = Args()

# Seed the Python, NumPy and torch RNGs when a seed is configured. Node
# sampling uses `random`; random_split, DataLoader shuffling and weight init
# use torch.
if args.random_seed is not None:
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

graph_dataset = GraphDataset(args)
full_dataset = graph_dataset.dataset
# NOTE: train and test pairs come from the same graph. Do we want to over-fit?
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [args.train_size, 1 - args.train_size])

train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=0, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=0, shuffle=False)

model = NeuralNetwork().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(args.epochs):
    with tqdm(total=len(train_dataloader), desc="starting model...") as pbar:
        pbar.set_description(f"Epoch {epoch + 1}")

        # train
        model.train()
        n_train_batches = len(train_dataloader)
        # Refresh the bar about every 1% of batches. max(1, ...) avoids
        # `batch % 0` (ZeroDivisionError) when there are fewer than 100 batches.
        intv = max(1, n_train_batches // 100)
        for batch, (X, y) in enumerate(train_dataloader):
            X, y = X.to(device), y.to(device)
            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # update progress; `batch` is 0-based, so batch + 1 batches are done
            if batch % intv == 0 or batch == n_train_batches - 1:
                pbar.update(batch + 1 - pbar.n)
                pbar.set_postfix_str(f"loss: {loss.item():>6f}")

        # test
        model.eval()
        n_test_batches = len(test_dataloader)
        n_test_values = len(test_dataloader.dataset)
        test_loss, correct = 0.0, 0.0
        pbar.set_postfix_str("evaluating epoch...")
        with torch.no_grad():
            for X, y in test_dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).sum().item()
        # Guard against an empty test split (e.g. train_size == 1.0).
        test_loss = test_loss / n_test_batches if n_test_batches else float("nan")
        correct = correct / n_test_values if n_test_values else float("nan")
        pbar.set_postfix_str(f"epoch result: accuracy: {(100*correct):>0.1f}%, avg_loss: {test_loss:>8f}")


print("Done!")

generating random-uniform nodes on disc:   0%|          | 0/1000 [00:00<?, ?it/s]

creating dict from list:   0%|          | 0/1000 [00:00<?, ?it/s]

generating node pairs:   0%|          | 0/1000 [00:00<?, ?it/s]

generating dataset values from node pairs:   0%|          | 0/499500 [00:00<?, ?it/s]

generating dataset labels from node pairs:   0%|          | 0/499500 [00:00<?, ?it/s]

full size: 499500 | train size: 349650 | test size: 149850
NeuralNetwork(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=4, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=16, bias=True)
    (3): ReLU()
    (4): Linear(in_features=16, out_features=2, bias=True)
  )
)


starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

starting model...:   0%|          | 0/5464 [00:00<?, ?it/s]

Done!
