In [1]:
import torch
from torch_geometric.loader import DataLoader
from src.training.train import train
from src.model.gae import GAE
from src.data.loader import GraphDataset
from src.utils import commons

device = commons.get_device()
torch.cuda.empty_cache()
config = commons.get_config('configs/reduced.yaml')
model = GAE(config=config['model']).to(device)

# optimizer
if config['training']['optimizer']['type'] == 'Adam':
    from torch.optim.adam import Adam
    optimizer = Adam(model.parameters(), lr=config['training']['optimizer']['learning_rate'])
elif config['training']['optimizer']['type'] == 'AdamW':
    from torch.optim.adamw import AdamW
    optimizer = AdamW(model.parameters(), lr=config['training']['optimizer']['learning_rate'])
else:
    raise ValueError(f"Invalid optimizer: {config['training']['optimizer']['type']}")

# scheduler 
if config['training']['scheduler']['type'] == 'StepLR':
    from torch.optim.lr_scheduler import StepLR
    scheduler = StepLR(optimizer, step_size=config['training']['scheduler']['step_size'], gamma=config['training']['scheduler']['gamma'])
elif config['training']['scheduler']['type'] == 'CosineAnnealingLR':
    from torch.optim.lr_scheduler import CosineAnnealingLR
    scheduler = CosineAnnealingLR(optimizer, T_max=config['training']['epochs'])
else:
    raise ValueError(f"Invalid scheduler: {config['training']['scheduler']['type']}")

dataset = GraphDataset()
train_loader = DataLoader(dataset=dataset, 
                          batch_size=config['training']['batch_size'], 
                          shuffle=False,
                          num_workers=config['config']['num_workers'])


In [2]:
train_history = train(model=model, 
                      optimizer=optimizer,
                      device=device,
                      scheduler=scheduler, 
                      train_loader=train_loader, 
                      config=config)

  0%|          | 0/100 [00:00<?, ?it/s]

  edge_index = torch.tensor(self.edge_list, dtype=torch.long),
  edge_attr = torch.tensor(self.edge_features, dtype=torch.float32),
  edge_weight = torch.tensor(self.edge_weights, dtype=torch.float32))
  0%|          | 0/100 [00:01<?, ?it/s]


RuntimeError: shape '[1, 128]' is invalid for input of size 1974144

In [2]:
for test_batch in train_loader:
    # print the max / min of edge_index
    print(f"feature shape: {test_batch.x.shape}")
    print(f"pos shape: {test_batch.pos.shape}")
    print(f"edge_attr shape: {test_batch.edge_attr.shape}")
    print(f"edge_index shape: {test_batch.edge_index.shape}")
    print(f"Max: {test_batch.edge_index.max().item()}, Min: {test_batch.edge_index.min().item()}")
    print(f"edge_index : {test_batch.edge_index[:,:6]}")
    print(f"edge_weight shape: {test_batch.edge_weight.shape}")
    break

feature shape: torch.Size([30846, 2])
pos shape: torch.Size([30846, 2])
edge_attr shape: torch.Size([46015])
edge_index shape: torch.Size([2, 46015])
Max: 30845, Min: 0
edge_index : tensor([[    0,     0,     0,     1,     1,     1],
        [23316,     5, 30417,   182,  1260, 14815]])
edge_weight shape: torch.Size([46015])


  edge_index = torch.tensor(self.edge_list, dtype=torch.long),
  edge_attr = torch.tensor(self.edge_features, dtype=torch.float32),
  edge_weight = torch.tensor(self.edge_weights, dtype=torch.float32))


In [1]:
from torch_geometric.nn import GCNConv, ChebConv, GATConv, GMMConv
from src.data.loader import GraphDataset

# Create dataset and get a batch
dataset = GraphDataset()
batch = dataset[0]

# Print shapes and values
print("Edge index shape:", batch.edge_index.shape)
print("Edge index max value:", batch.edge_index.max().item())
print("Edge index min value:", batch.edge_index.min().item())
print("Edge weight shape:", batch.edge_weight.shape)
print("Number of nodes:", batch.x.shape[0])

# Try GCN with edge weights
test_model = GCNConv(2, 32)
res = test_model(batch.x, batch.edge_index, batch.edge_weight)

Edge index shape: torch.Size([2, 92538])
Edge index max value: 15677
Edge index min value: 0
Edge weight shape: torch.Size([92538])
Number of nodes: 15677


RuntimeError: index 15677 is out of bounds for dimension 0 with size 15677