In [1]:
# Environment setup: expose the installed torch version to the shell so the
# (currently disabled) pip commands below can pick the matching PyG wheels.
import os
import torch

torch_version = torch.__version__
os.environ['TORCH'] = torch_version  # read as ${TORCH} by the wheel URLs below
print(torch_version)

# Uncomment to install the PyTorch Geometric stack for this torch build.
# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

1.13.0


In [2]:
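# Optional setup: fetch the GraphVision sources (augmentation branch) and copy
# its `src/` package next to this notebook so the `src.*` imports below resolve.
# !git clone -b augmentation "https://github.com/ab7289-tandon-nyu/GraphVision.git"
# !cp -r /content/GraphVision/src/ .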

In [3]:
# Reproducibility: seed every RNG this notebook can reach and pin cuDNN to
# deterministic behaviour.
import random

SEED = 1234

random.seed(SEED)                 # Python-level randomness (e.g. shuffling helpers)
torch.manual_seed(SEED)           # default torch generator
torch.cuda.manual_seed_all(SEED)  # every visible GPU; silently ignored without CUDA
torch.backends.cudnn.deterministic = True
# The cuDNN autotuner may pick different kernels from run to run; keep it off
# explicitly so the deterministic flag above is not undermined.
torch.backends.cudnn.benchmark = False

In [4]:
from src.transforms import get_transforms
from src.datasets import get_datasets, get_dataloaders

# Build the transform pipeline registered under "cifar10-slic" and use it for
# the train / validation / test splits.
transforms = get_transforms("cifar10-slic")

# NOTE(review): the saved output of this cell is "ValueError: Invalid Dataset
# name". The accepted names are defined in src.datasets (not visible here), so
# confirm that "torchvision_cifar10" is still a valid key before re-running.
train_dataset, valid_dataset, test_dataset = get_datasets(
    ".data/",
    "torchvision_cifar10",
    pre_transforms=None,
    transforms=transforms,
)
BATCH_SIZE = 128

# One batch size per loader; presumably ordered (train, valid, test) to match
# the positional datasets -- TODO confirm against get_dataloaders.
train_loader, valid_loader, test_loader = get_dataloaders(
    train_dataset,
    valid_dataset,
    test_dataset,
    batch_size=(BATCH_SIZE, 1, 1),
    drop_last=False,
)

ValueError: Invalid Dataset name

In [5]:
# Sanity check: look at the first training example. The saved output shows a
# (graph Data, integer) pair -- presumably (superpixel graph, class label).
d = train_dataset[0]

print(d)

(Data(x=[102, 3], pos=[102, 2], edge_index=[2, 918], edge_weight=[918], edge_attr=[918, 2]), 8)


In [6]:
# Pull a single batch from the training loader; it is reused below to read the
# feature and edge-attribute dimensions.
sample_batch = next(iter(train_loader))

print()
print(sample_batch)



[DataBatch(x=[13647, 3], pos=[13647, 2], edge_index=[2, 122823], edge_weight=[122823], edge_attr=[122823, 2], batch=[13647], ptr=[129]), tensor([9, 3, 2, 3, 5, 7, 9, 0, 0, 8, 0, 1, 1, 3, 9, 7, 3, 2, 7, 6, 9, 2, 4, 2,
        5, 3, 8, 6, 6, 3, 1, 5, 0, 2, 7, 9, 1, 3, 8, 1, 0, 6, 5, 2, 9, 9, 9, 6,
        1, 8, 6, 8, 0, 5, 1, 5, 5, 9, 7, 0, 9, 4, 7, 3, 4, 6, 4, 4, 1, 4, 9, 4,
        7, 3, 7, 0, 7, 3, 1, 6, 4, 8, 2, 4, 2, 4, 2, 0, 0, 9, 2, 7, 0, 4, 3, 3,
        0, 9, 6, 9, 0, 8, 8, 1, 7, 7, 7, 1, 1, 9, 5, 8, 7, 7, 9, 9, 7, 4, 1, 3,
        2, 9, 1, 5, 4, 2, 8, 6])]


In [7]:
# Read the sizes the model needs from the sample batch. Element 0 of the batch
# is the batched graph object (it carries `x` and `edge_attr`).
graph_batch = sample_batch[0]

# store edge dimension
edge_dim = graph_batch.edge_attr.size(-1)
# store number of features in graph batch
input_features = graph_batch.x.size(-1)
# store number of classes for classification
num_classes = 10

print(f"Number of features: {input_features}")
print(f"Number of classes: {num_classes}")
print(f"Edge Attr Dimension: {edge_dim}")
print(f"Steps per epoch: {len(train_loader)}")

Number of features: 3
Number of classes: 10
Edge Attr Dimension: 2
Steps per epoch: 352


In [8]:
from src.models import DeeperGCN
from src.engine import evaluate

hidden_features = 256

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture options, gathered in one place so they are easy to tweak.
model_options = dict(
    conv_type="General",
    act="relu",
    norm="layer",
    num_layers=16,
    use_cluster_pooling=False,
    readout="mean",
    dropout=0.1,
    edge_dim=edge_dim,
)
model = DeeperGCN(input_features, num_classes, hidden_features, **model_options)
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss().to(device)

# No earlier checkpoint is restored by default. To resume, load a state dict
# and record its loss so the training cell only overwrites on improvement:
prev_loss = None
# if state_dict is not None:
#   print("Loading previously saved state dictionary")
#   model.load_state_dict(state_dict)
#   prev_loss, _ = evaluate(model.to(device), test_loader, criterion, device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
# Optional cyclic learning-rate schedule (disabled):
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.0001, max_lr=0.001, verbose=True, cycle_momentum=False)

In [9]:
# Count only the parameters the optimizer will actually update, then show the
# module tree.
params = sum(
    tensor.numel()
    for tensor in model.parameters()
    if tensor.requires_grad
)
print(f"There are {params:,} trainable parameters.")
print()
print(model)

There are 1,076,746 trainable parameters.

DeeperGCN(
  (fc_in): Linear(in_features=3, out_features=256, bias=True)
  (fc_out): Linear(in_features=256, out_features=10, bias=True)
  (out_act): ReLU(inplace=True)
  (layers): ModuleList(
    (0): DeepGCNLayer(block=res+)
    (1): DeepGCNLayer(block=res+)
    (2): DeepGCNLayer(block=res+)
    (3): DeepGCNLayer(block=res+)
    (4): DeepGCNLayer(block=res+)
    (5): DeepGCNLayer(block=res+)
    (6): DeepGCNLayer(block=res+)
    (7): DeepGCNLayer(block=res+)
    (8): DeepGCNLayer(block=res+)
    (9): DeepGCNLayer(block=res+)
    (10): DeepGCNLayer(block=res+)
    (11): DeepGCNLayer(block=res+)
    (12): DeepGCNLayer(block=res+)
    (13): DeepGCNLayer(block=res+)
    (14): DeepGCNLayer(block=res+)
    (15): DeepGCNLayer(block=res+)
  )
)


In [10]:
from src.engine import train, evaluate
from src.utils import calculate_accuracy

# Checkpoint file for the best model seen so far (lowest validation loss).
path = "model.pt"

EPOCHS = 50

# Per-epoch history, consumed by the plotting cells further down.
train_loss = []
train_acc = []
valid_loss = []
valid_acc = []

# Loss a new epoch must beat before the checkpoint is overwritten. Starts at
# +inf unless a previously evaluated checkpoint supplied a baseline.
best_loss = float('inf')
if prev_loss is not None:
  print(f"Training from previous best loss: {prev_loss}")
  best_loss = prev_loss

# NOTE(review): the saved output of this cell stops in epoch 1 with a
# TypeError. The cause is not visible from this cell -- check what
# src.engine.train expects from the loader before re-running.
for epoch in range(1, EPOCHS + 1):
  print(f"\nEpoch: {epoch}\n")
  loss, acc = train(model, train_loader, criterion, optimizer, device)
  train_loss.append(loss)
  train_acc.append(acc)
  print(f"Train Loss: {train_loss[-1]:.3f}, Train Accuracy: {train_acc[-1]:.2f}")

  # scheduler.step()

  loss, acc = evaluate(model, valid_loader, criterion, device)
  valid_loss.append(loss)
  valid_acc.append(acc)
  print(f"Validation Loss: {valid_loss[-1]:.3f}, Validation Accuracy: {valid_acc[-1]:.2f}")

  # Keep only the weights with the best validation loss.
  if loss < best_loss:
    best_loss = loss
    torch.save(model.state_dict(), path)


Epoch: 1



TypeError: ignored

In [None]:
# load best model
# map_location keeps the load working when the checkpoint was written on a GPU
# but this session only has a CPU (or a different device).
best_state = torch.load(path, map_location=device)
model.load_state_dict(best_state)
test_loss, test_acc = evaluate(model.to(device), test_loader, criterion, device)
print(f"Test Loss: {test_loss:.3f}")
print(f"Test Accuracy: {test_acc:.2f}%")

In [None]:
import matplotlib.pyplot as plt

# Epoch axis, 1-based to match the "Epoch: N" lines printed during training.
x = list(range(1, len(valid_loss) + 1))

# Training vs. validation loss per epoch.
plt.plot(x, train_loss, label="Train Loss")
plt.plot(x, valid_loss, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Loss per epoch")
plt.legend()
plt.show()

In [None]:
# Training vs. validation accuracy per epoch; `x` is the epoch axis built in
# the loss-plot cell.
plt.plot(x, train_acc, label="Train Accuracy")
plt.plot(x, valid_acc, label="Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Accuracy per epoch")
plt.legend()
plt.show()