In [2]:
from torch_geometric.datasets import TUDataset
import sys
sys.path.append('..')
from dhg import Hypergraph
from hypgs.models.hsn_pyg import HSN
from hypgs.utils.data import HGDataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import random_split
from pytorch_lightning.loggers import TensorBoardLogger


In [3]:
original_dataset = TUDataset(root='../data/', name="MUTAG", use_node_attr=True)
to_hg_func = lambda g: Hypergraph.from_graph_kHop(g, 1)
dataset = HGDataset(original_dataset, to_hg_func)

In [4]:
# Assuming 'dataset' is your PyTorch Dataset
dataset_size = len(dataset)
train_size = int(0.7 * dataset_size)
val_size = int(0.15 * dataset_size)
test_size = dataset_size - (train_size + val_size)

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [4]:
model = HSN(in_channels=7, 
              hidden_channels=16,
              out_channels = 2, 
              trainable_laziness = False,
              trainable_scales = False, 
              activation = 'modulus', 
              fixed_weights=True, 
              layout=['hsm'], 
              normalize='right', 
              pooling='max',
              task='classification',
        )

In [5]:
# Early stopping callback based on validation loss
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=10, verbose=True, mode='min')
logger = TensorBoardLogger("lightning_logs", name="my_model")

trainer = pl.Trainer(
    max_epochs=100,  # Adjust as necessary
    logger=logger,
    callbacks=[early_stopping_callback]
)

# Assuming you have defined train_loader and val_loader
trainer.fit(model, train_loader, val_loader)  # Corrected argument names
# Test the model
trainer.test(model, test_loader)

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.1 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: lightning_logs/my_model
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type               | Params
---------------------------------------------------
0 | layers      | ModuleList         | 102   
1 | batch_norm  | BatchNorm          | 84    
2 | fc1         | Linear             | 903   
3 | fc2         | Linear             | 44    
4 | relu        | ReLU               | 0     
5 | batch_norm1 | BatchNorm1d        | 42    
6 | mlp         | Sequential

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 516. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages

Training: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 578. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 563. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 580. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.684
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 576. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 551. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 575. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to inf

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.684
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 552. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 608. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 610. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utiliti

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.679
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 573. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 584. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 565. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utiliti

Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 574. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 557. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 604. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection

Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 593. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 559. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 55. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.017 >= min_delta = 0.0. New best score: 0.662
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 600. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 564. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 577. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utiliti

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 0.649
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 571. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 579. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 582. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utiliti

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.639
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 566. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 592. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 553. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utiliti

Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 556. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 572. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 63. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 558. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 605. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 45. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 554. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 594. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 560. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 618. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 567. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 548. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 569. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 607. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 536. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 602. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection

Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 634. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 561. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 514. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection

Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 562. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 601. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 52. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 10 records. Best score: 0.639. Signaling Trainer to stop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.7586206793785095
     test_auc_epoch          0.773809552192688
     test_loss_epoch        1.6698592901229858
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


/home/sumry2023_cqx3/.conda/envs/hyper/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 505. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'test_loss_epoch': 1.6698592901229858,
  'test_acc_epoch': 0.7586206793785095,
  'test_auc_epoch': 0.773809552192688}]

In [6]:
# to see the logs run
!tensorboard --logdir lightning_logs

TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.1 at http://localhost:6006/ (Press CTRL+C to quit)
^C


In [6]:
dataset[0].x.dtype

torch.float32