### Introduction

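This notebook trains the steerable architecture on every channel subset: it registers a Weights & Biases sweep, runs that sweep for each subset, and saves the trained models to disk.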

In [None]:
# Refresh the apt package index so the install in the following cell resolves current package versions.
!apt-get update

In [None]:
# Install system-level build tooling (compilers, Fortran, cmake, pkg-config, download tools).
# -y answers apt's confirmation prompt automatically so the cell runs unattended.
# NOTE(review): presumably required to build native extensions used by the project's dependencies — confirm which package needs them.
!apt-get install build-essential libatomic1 gfortran perl wget m4 cmake pkg-config curl -y

In [None]:
import wandb
import utilities.metadata as metadata
from dataset.dataset import PlanetaryDataset
import os
import torch
import wandb
from utilities.metadata import CHANNEL_SUBSETS
from utilities.training import TRAIN_TRANSFORM, run_sweeps_for_channel_subsets

### Wandb initialization

In [None]:
# Authenticate with Weights & Biases WITHOUT embedding the API key in the
# notebook source. The key is read from the WANDB_API_KEY environment variable
# (set it via your platform's secret store before running this cell). If the
# variable is unset, key=None lets wandb fall back to its own credential
# lookup (a previously stored login or an interactive prompt).
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Register the hyperparameter sweep described by metadata.SWEEP_CONFIG under
# the "eq_colab_2" project; the returned id is handed to the sweep runner in
# the training cell.
sweep_id = wandb.sweep(metadata.SWEEP_CONFIG, project="eq_colab_2")
wandb.init()

# Configure the CUDA caching allocator BEFORE touching torch.cuda: the
# allocator settings are expected to be read when the allocator is first
# initialised, so the variable must be in place ahead of any CUDA call.
# NOTE(review): if CUDA was already initialised earlier in this kernel the
# setting may not take effect until the kernel is restarted — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Release any cached, unused GPU memory held by PyTorch before training starts.
torch.cuda.empty_cache()

### Dataset

In [None]:
# Run-level knobs a reader may want to tune before launching the sweeps.

# Amount of the dataset handed to the sweep runner in the training cell.
# NOTE(review): presumably a fraction where 1 means "use everything" — confirm
# against run_sweeps_for_channel_subsets.
percentage_of_dataset: int = 1

# Number of training epochs.
# NOTE(review): not referenced by any other cell visible in this notebook —
# confirm whether the training helper reads it from elsewhere.
epochs: int = 10

In [None]:
# Locations of the training data inside the Kaggle input mount.
train_data_dir = "/kaggle/input/gsoc-protoplanetary-disks/Train_Clean"
train_info_csv = "/kaggle/input/gsoc-protoplanetary-disks/train_info_cleaned.csv"

# Build the dataset with the shared training transform.
# NOTE(review): channels is passed as an empty list here; presumably the sweep
# runner selects the channels for each subset later — confirm.
dataset = PlanetaryDataset(
    csv_file=train_info_csv,
    data_dir=train_data_dir,
    transform=TRAIN_TRANSFORM,
    channels=[],
)

### Training and evaluation

In [None]:
# Container for the sweep results. The return value of the runner is discarded
# and the saving cell iterates over this dict afterwards, so the runner is
# expected to fill it in place (NOTE(review): confirm in utilities.training).
trained_models = dict()

# Run the registered sweep once per channel subset.
run_sweeps_for_channel_subsets(
    percentage_of_dataset,
    trained_models,
    CHANNEL_SUBSETS,
    sweep_id,
    dataset,
)

### Save models to file

In [None]:
# Write one checkpoint per entry of trained_models into save_dir.
save_dir = "best_saved_models"
os.makedirs(save_dir, exist_ok=True)  # no error if the directory already exists

for channel_key, entry in trained_models.items():
    # The dict key is an iterable of channel identifiers; join them into a
    # filename-friendly tag such as "0_1_2".
    channel_tag = "_".join(str(channel) for channel in channel_key)
    model_file_path = os.path.join(save_dir, f"model_channels_{channel_tag}.pt")

    # The first element of each entry is the object whose state_dict is saved
    # (presumably the trained model — confirm what the remaining elements hold).
    model = entry[0]
    torch.save(model.state_dict(), model_file_path)
    print(f"Model saved to {model_file_path}")

In [None]:
# Close the Weights & Biases run that was opened by wandb.init() in the initialization cell.
wandb.finish()