In [1]:
import os
# When the kernel starts in a directory whose path ends in 'notebooks', move
# one level up before the project imports below run (presumably so that the
# `src` package resolves from the repository root -- confirm against the
# project layout).
# NOTE(review): this is a plain suffix match, so a directory such as
# 'old_notebooks' triggers the same move.
launch_dir = os.getcwd()
if launch_dir.endswith('notebooks'):
    os.chdir(os.path.dirname(launch_dir))
    
from src.datasets import stellar_data
from src.models import vanilla_stellar
import torch
import numpy as np
import argparse



In [2]:
# Name of the graph dataset file. It is handed unchanged as `filename` to the
# StellarDataloader instances built in the cross-validation loop.
# NOTE(review): the '.pt' suffix suggests a torch-serialized file, and the
# directory it is resolved against is decided inside StellarDataloader --
# neither is visible from this notebook.
GRAPH_DATASET_FILENAME = 'stellar_graph_dataset.pt'

In [3]:
# Build the train/validation index splits consumed by the training loop.
#
# Every entry of `sample_idx_permutations` is a dict with two lists of integer
# indices: "train" (the first `n_train` entries of a random permutation of
# range(n_samples)) and "valid" (the remaining entries).
#
# NOTE(review): each split comes from its own independent permutation, so the
# validation sets of different "folds" may overlap. This is repeated random
# sub-sampling rather than a disjoint k-fold partition.
torch.manual_seed(42)  # fixed seed so the same splits are drawn on every run

n_samples = 125  # number of indices to permute; assumed to equal the dataset size -- TODO confirm
n_folds = 5
n_train = 100

# Guard against a configuration that would silently yield an empty training
# or validation slice.
if not 0 < n_train < n_samples:
    raise ValueError(
        f"n_train must be in (0, {n_samples}) to leave a validation set, got {n_train}"
    )

sample_idx_permutations = []
for fold in range(n_folds):
    idx_perm = torch.randperm(n_samples).tolist()
    sample_idx_permutations.append(
        {"train": idx_perm[:n_train], "valid": idx_perm[n_train:]}
    )

In [7]:
# Configuration namespace handed to VanillaStellarReduced in the training loop.
# The meaning of each field is defined by that model class (not visible here);
# the names suggest feature width, hidden width, class count and learning
# rate -- confirm against src.models.vanilla_stellar.
cfg_reduced = argparse.Namespace(
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),
    input_dim=40,
    hid_dim=128,  # original note: "originally 128" (same as the current value)
    num_classes=14,
    lr=1e-3,
)

In [8]:
# Hyper-parameters shared by every fold of the training loop.
batch_size = 1  # passed as `batch_size` to both the train and the validation StellarDataloader
n_epochs = 20  # passed as `epochs` to VanillaStellarReduced.train

In [9]:
# Train one freshly constructed model per split and record the validation
# accuracy that `train` hands back (requested via return_valid_acc=True).
valid_accs = np.zeros(n_folds)
for k in range(n_folds):
    print(f"Fold {k} of {n_folds}")
    fold_indices = sample_idx_permutations[k]

    # Build the train loader first and the validation loader second; only the
    # training loader shuffles its graphs.
    loaders = {
        split_name: stellar_data.StellarDataloader(
            filename=GRAPH_DATASET_FILENAME,
            batch_size=batch_size,
            shuffle=(split_name == "train"),
            graphs_idx=fold_indices[split_name],
            test=False,
        )
        for split_name in ("train", "valid")
    }

    # A new model per fold, so no weights carry over between splits.
    fold_model = vanilla_stellar.VanillaStellarReduced(cfg_reduced)
    valid_accs[k] = fold_model.train(
        train_loader=loaders["train"],
        valid_loader=loaders["valid"],
        epochs=n_epochs,
        return_valid_acc=True,
    )

Fold 0 of 5


Training - epoch 0: 100%|██████████| 100/100 [00:02<00:00, 47.84it/s, Loss=1.23, Accuracy=0.647]
Validation - epoch 0: 100%|██████████| 25/25 [00:00<00:00, 87.68it/s, Loss=0.615, Accuracy=0.815] 
Training - epoch 1: 100%|██████████| 100/100 [00:01<00:00, 56.22it/s, Loss=0.539, Accuracy=0.834]
Validation - epoch 1: 100%|██████████| 25/25 [00:00<00:00, 130.17it/s, Loss=0.369, Accuracy=0.888]
Training - epoch 2: 100%|██████████| 100/100 [00:01<00:00, 52.86it/s, Loss=0.417, Accuracy=0.864]
Validation - epoch 2: 100%|██████████| 25/25 [00:00<00:00, 130.35it/s, Loss=0.363, Accuracy=0.89] 
Training - epoch 3: 100%|██████████| 100/100 [00:01<00:00, 55.03it/s, Loss=0.366, Accuracy=0.882]
Validation - epoch 3: 100%|██████████| 25/25 [00:00<00:00, 131.67it/s, Loss=0.296, Accuracy=0.908]
Training - epoch 4: 100%|██████████| 100/100 [00:01<00:00, 52.41it/s, Loss=0.321, Accuracy=0.895]
Validation - epoch 4: 100%|██████████| 25/25 [00:00<00:00, 132.22it/s, Loss=0.307, Accuracy=0.908]
Training - epoch

Fold 1 of 5


Training - epoch 0: 100%|██████████| 100/100 [00:01<00:00, 59.58it/s, Loss=1.23, Accuracy=0.648]
Validation - epoch 0: 100%|██████████| 25/25 [00:00<00:00, 121.59it/s, Loss=0.696, Accuracy=0.78] 
Training - epoch 1: 100%|██████████| 100/100 [00:01<00:00, 55.21it/s, Loss=0.491, Accuracy=0.851]
Validation - epoch 1: 100%|██████████| 25/25 [00:00<00:00, 122.75it/s, Loss=0.484, Accuracy=0.844]
Training - epoch 2: 100%|██████████| 100/100 [00:01<00:00, 58.15it/s, Loss=0.383, Accuracy=0.877]
Validation - epoch 2: 100%|██████████| 25/25 [00:00<00:00, 119.03it/s, Loss=0.375, Accuracy=0.878]
Training - epoch 3: 100%|██████████| 100/100 [00:01<00:00, 54.87it/s, Loss=0.353, Accuracy=0.887]
Validation - epoch 3: 100%|██████████| 25/25 [00:00<00:00, 117.06it/s, Loss=0.404, Accuracy=0.861]
Training - epoch 4: 100%|██████████| 100/100 [00:01<00:00, 53.98it/s, Loss=0.32, Accuracy=0.896]
Validation - epoch 4: 100%|██████████| 25/25 [00:00<00:00, 119.77it/s, Loss=0.341, Accuracy=0.886]
Training - epoch 

Fold 2 of 5


Training - epoch 0: 100%|██████████| 100/100 [00:01<00:00, 55.42it/s, Loss=1.09, Accuracy=0.696]
Validation - epoch 0: 100%|██████████| 25/25 [00:00<00:00, 130.15it/s, Loss=0.509, Accuracy=0.836]
Training - epoch 1: 100%|██████████| 100/100 [00:01<00:00, 54.28it/s, Loss=0.497, Accuracy=0.844]
Validation - epoch 1: 100%|██████████| 25/25 [00:00<00:00, 127.74it/s, Loss=0.351, Accuracy=0.893]
Training - epoch 2: 100%|██████████| 100/100 [00:01<00:00, 57.15it/s, Loss=0.395, Accuracy=0.871]
Validation - epoch 2: 100%|██████████| 25/25 [00:00<00:00, 125.80it/s, Loss=0.269, Accuracy=0.915]
Training - epoch 3: 100%|██████████| 100/100 [00:01<00:00, 51.73it/s, Loss=0.341, Accuracy=0.891]
Validation - epoch 3: 100%|██████████| 25/25 [00:00<00:00, 104.11it/s, Loss=0.27, Accuracy=0.909] 
Training - epoch 4: 100%|██████████| 100/100 [00:01<00:00, 52.36it/s, Loss=0.335, Accuracy=0.89]
Validation - epoch 4: 100%|██████████| 25/25 [00:00<00:00, 63.44it/s, Loss=0.251, Accuracy=0.917] 
Training - epoch 

Fold 3 of 5


Training - epoch 0: 100%|██████████| 100/100 [00:01<00:00, 53.78it/s, Loss=1.18, Accuracy=0.67]
Validation - epoch 0: 100%|██████████| 25/25 [00:00<00:00, 119.38it/s, Loss=0.899, Accuracy=0.695]
Training - epoch 1: 100%|██████████| 100/100 [00:01<00:00, 53.77it/s, Loss=0.517, Accuracy=0.84]
Validation - epoch 1: 100%|██████████| 25/25 [00:00<00:00, 124.76it/s, Loss=0.562, Accuracy=0.828]
Training - epoch 2: 100%|██████████| 100/100 [00:01<00:00, 56.73it/s, Loss=0.358, Accuracy=0.886]
Validation - epoch 2: 100%|██████████| 25/25 [00:00<00:00, 122.96it/s, Loss=0.445, Accuracy=0.858]
Training - epoch 3: 100%|██████████| 100/100 [00:02<00:00, 49.40it/s, Loss=0.313, Accuracy=0.9] 
Validation - epoch 3: 100%|██████████| 25/25 [00:00<00:00, 121.05it/s, Loss=0.414, Accuracy=0.866]
Training - epoch 4: 100%|██████████| 100/100 [00:01<00:00, 53.05it/s, Loss=0.293, Accuracy=0.906]
Validation - epoch 4: 100%|██████████| 25/25 [00:00<00:00, 121.55it/s, Loss=0.404, Accuracy=0.862]
Training - epoch 5:

Fold 4 of 5


Training - epoch 0: 100%|██████████| 100/100 [00:02<00:00, 34.53it/s, Loss=1.16, Accuracy=0.671]
Validation - epoch 0: 100%|██████████| 25/25 [00:00<00:00, 94.92it/s, Loss=0.666, Accuracy=0.794]
Training - epoch 1: 100%|██████████| 100/100 [00:02<00:00, 42.56it/s, Loss=0.493, Accuracy=0.846]
Validation - epoch 1: 100%|██████████| 25/25 [00:00<00:00, 89.73it/s, Loss=0.641, Accuracy=0.811]
Training - epoch 2: 100%|██████████| 100/100 [00:02<00:00, 40.55it/s, Loss=0.367, Accuracy=0.885]
Validation - epoch 2: 100%|██████████| 25/25 [00:00<00:00, 80.92it/s, Loss=0.378, Accuracy=0.871]
Training - epoch 3: 100%|██████████| 100/100 [00:02<00:00, 40.06it/s, Loss=0.332, Accuracy=0.893]
Validation - epoch 3: 100%|██████████| 25/25 [00:00<00:00, 74.76it/s, Loss=0.361, Accuracy=0.884]
Training - epoch 4: 100%|██████████| 100/100 [00:02<00:00, 39.02it/s, Loss=0.3, Accuracy=0.901] 
Validation - epoch 4: 100%|██████████| 25/25 [00:00<00:00, 57.93it/s, Loss=0.352, Accuracy=0.885]
Training - epoch 5: 10

In [10]:
# Report the mean and spread of the per-fold validation accuracies.
mean_valid_acc = valid_accs.mean()
std_valid_acc = valid_accs.std()  # ndarray.std() default, i.e. population std (ddof=0)
print(f"Mean validation accuracy: {mean_valid_acc} +/- {std_valid_acc}")

Mean validation accuracy: 0.9114213490486144 +/- 0.019208156609532963
