In [6]:
from src.datasets import stellar_data
from src.models import vanilla_stellar
import torch
import numpy as np
import argparse

In [7]:
torch.manual_seed(42)
sample_idx_permutations = []
n_folds = 5
n_train = 100
for i in range(n_folds):
    idx_perm = torch.randperm(125).tolist()
    sample_idx_permutations.append({"train" : idx_perm[:n_train], "valid" : idx_perm[n_train:]})

In [12]:
cfg_reduced = argparse.Namespace(**{
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'input_dim': 41,
    'hid_dim': 512, # originally 128
    'num_classes': 14,
    'lr': 1e-3,
})

In [13]:
n_epochs = 20

In [14]:
valid_accs = np.zeros(n_folds)
for k in range(n_folds):
    print(f"Fold {k} of {n_folds}")
    train_dataloader = stellar_data.StellarDataloader(
        filename='stellar_graph_dataset_train.pt',
        batch_size=512,
        shuffle=True,
        graphs_idx=sample_idx_permutations[k]["train"],
        test=False
    )
    valid_dataloader = stellar_data.StellarDataloader(
        filename='stellar_graph_dataset_train.pt',
        batch_size=512,
        shuffle=False,
        graphs_idx=sample_idx_permutations[k]["valid"],
        test=False
    )

    vanilla_stellar_reduced = vanilla_stellar.VanillaStellarReduced(cfg_reduced)
    valid_acc = vanilla_stellar_reduced.train(
        train_loader=train_dataloader,
        valid_loader=valid_dataloader,
        epochs=n_epochs,
        return_valid_acc=True)
    valid_accs[k] = valid_acc

Fold 0 of 5


Training - epoch 0: 100%|██████████| 1/1 [00:12<00:00, 12.18s/it, Loss=2.42, Accuracy=0.314]
Validation - epoch 0: 100%|██████████| 1/1 [00:01<00:00,  1.03s/it, Loss=2.67, Accuracy=0.548]
Training - epoch 1: 100%|██████████| 1/1 [00:08<00:00,  8.58s/it, Loss=2.93, Accuracy=0.505]
Validation - epoch 1: 100%|██████████| 1/1 [00:00<00:00,  1.37it/s, Loss=2.41, Accuracy=0.548]
Training - epoch 2: 100%|██████████| 1/1 [00:08<00:00,  8.09s/it, Loss=2.63, Accuracy=0.505]
Validation - epoch 2: 100%|██████████| 1/1 [00:00<00:00,  1.41it/s, Loss=1.92, Accuracy=0.548]
Training - epoch 3: 100%|██████████| 1/1 [00:09<00:00,  9.19s/it, Loss=2.07, Accuracy=0.505]
Validation - epoch 3: 100%|██████████| 1/1 [00:00<00:00,  1.36it/s, Loss=1.87, Accuracy=0.55]
Training - epoch 4: 100%|██████████| 1/1 [00:08<00:00,  8.32s/it, Loss=1.95, Accuracy=0.506]
Validation - epoch 4: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s, Loss=2.04, Accuracy=0.383]
Training - epoch 5: 100%|██████████| 1/1 [00:08<00:00,  8.23s

Fold 1 of 5


Training - epoch 0: 100%|██████████| 1/1 [00:09<00:00,  9.74s/it, Loss=2.93, Accuracy=0.0126]
Validation - epoch 0: 100%|██████████| 1/1 [00:01<00:00,  1.01s/it, Loss=2.58, Accuracy=0.482]
Training - epoch 1: 100%|██████████| 1/1 [00:09<00:00,  9.25s/it, Loss=2.39, Accuracy=0.521]
Validation - epoch 1: 100%|██████████| 1/1 [00:00<00:00,  1.11it/s, Loss=2.43, Accuracy=0.482]
Training - epoch 2: 100%|██████████| 1/1 [00:09<00:00,  9.31s/it, Loss=2.26, Accuracy=0.521]
Validation - epoch 2: 100%|██████████| 1/1 [00:00<00:00,  1.21it/s, Loss=2.02, Accuracy=0.482]
Training - epoch 3: 100%|██████████| 1/1 [00:08<00:00,  8.14s/it, Loss=1.9, Accuracy=0.521]
Validation - epoch 3: 100%|██████████| 1/1 [00:00<00:00,  1.26it/s, Loss=1.91, Accuracy=0.482]
Training - epoch 4: 100%|██████████| 1/1 [00:07<00:00,  7.99s/it, Loss=1.83, Accuracy=0.522]
Validation - epoch 4: 100%|██████████| 1/1 [00:00<00:00,  1.21it/s, Loss=1.97, Accuracy=0.484]
Training - epoch 5: 100%|██████████| 1/1 [00:09<00:00,  9.40

Fold 2 of 5


Training - epoch 0: 100%|██████████| 1/1 [00:09<00:00,  9.82s/it, Loss=2.63, Accuracy=0.0278]
Validation - epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.34it/s, Loss=1.96, Accuracy=0.633]
Training - epoch 1: 100%|██████████| 1/1 [00:09<00:00,  9.71s/it, Loss=2.74, Accuracy=0.482]
Validation - epoch 1: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s, Loss=1.73, Accuracy=0.633]
Training - epoch 2: 100%|██████████| 1/1 [00:09<00:00,  9.60s/it, Loss=2.41, Accuracy=0.482]
Validation - epoch 2: 100%|██████████| 1/1 [00:00<00:00,  1.07it/s, Loss=1.54, Accuracy=0.633]
Training - epoch 3: 100%|██████████| 1/1 [00:08<00:00,  8.94s/it, Loss=2.01, Accuracy=0.482]
Validation - epoch 3: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s, Loss=1.77, Accuracy=0.634]
Training - epoch 4: 100%|██████████| 1/1 [00:08<00:00,  8.35s/it, Loss=2.07, Accuracy=0.484]
Validation - epoch 4: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s, Loss=1.91, Accuracy=0.655]
Training - epoch 5: 100%|██████████| 1/1 [00:08<00:00,  8.6

Fold 3 of 5


Training - epoch 0: 100%|██████████| 1/1 [00:07<00:00,  7.87s/it, Loss=2.35, Accuracy=0.525]
Validation - epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.39it/s, Loss=3.8, Accuracy=0.366]
Training - epoch 1: 100%|██████████| 1/1 [00:08<00:00,  8.07s/it, Loss=2.67, Accuracy=0.551]
Validation - epoch 1: 100%|██████████| 1/1 [00:00<00:00,  1.25it/s, Loss=3.56, Accuracy=0.366]
Training - epoch 2: 100%|██████████| 1/1 [00:08<00:00,  8.94s/it, Loss=2.48, Accuracy=0.551]
Validation - epoch 2: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s, Loss=2.88, Accuracy=0.366]
Training - epoch 3: 100%|██████████| 1/1 [00:09<00:00,  9.59s/it, Loss=2.01, Accuracy=0.551]
Validation - epoch 3: 100%|██████████| 1/1 [00:00<00:00,  1.27it/s, Loss=2.43, Accuracy=0.366]
Training - epoch 4: 100%|██████████| 1/1 [00:08<00:00,  8.30s/it, Loss=1.79, Accuracy=0.551]
Validation - epoch 4: 100%|██████████| 1/1 [00:00<00:00,  1.12it/s, Loss=2.37, Accuracy=0.371]
Training - epoch 5: 100%|██████████| 1/1 [00:08<00:00,  8.32s

Fold 4 of 5


Training - epoch 0: 100%|██████████| 1/1 [00:08<00:00,  8.58s/it, Loss=2.63, Accuracy=0.051]
Validation - epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s, Loss=2.5, Accuracy=0.54]
Training - epoch 1: 100%|██████████| 1/1 [00:08<00:00,  8.69s/it, Loss=2.68, Accuracy=0.507]
Validation - epoch 1: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s, Loss=2.21, Accuracy=0.54]
Training - epoch 2: 100%|██████████| 1/1 [00:09<00:00,  9.05s/it, Loss=2.32, Accuracy=0.507]
Validation - epoch 2: 100%|██████████| 1/1 [00:00<00:00,  1.24it/s, Loss=1.91, Accuracy=0.54]
Training - epoch 3: 100%|██████████| 1/1 [00:08<00:00,  8.43s/it, Loss=1.93, Accuracy=0.507]
Validation - epoch 3: 100%|██████████| 1/1 [00:00<00:00,  1.34it/s, Loss=2.08, Accuracy=0.553]
Training - epoch 4: 100%|██████████| 1/1 [00:08<00:00,  8.84s/it, Loss=2.04, Accuracy=0.518]
Validation - epoch 4: 100%|██████████| 1/1 [00:00<00:00,  1.44it/s, Loss=2.13, Accuracy=0.488]
Training - epoch 5: 100%|██████████| 1/1 [00:08<00:00,  8.63s/it

In [15]:
print(f"Mean validation accuracy: {valid_accs.mean()} +/- {valid_accs.std()}")

Mean validation accuracy: 0.5245523750782013 +/- 0.08812878765431796
