In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import optuna
from hydra import compose, initialize
from hydra.utils import instantiate
import wandb
import torch
import time
from omegaconf import OmegaConf
from utils.io_utils import get_resume_file, hydra_setup, fix_seed, model_to_dict, opt_to_dict, get_model_file
from run import train, test
from prettytable import PrettyTable
import pickle

In [None]:
def objective(trial):
    # Hydra context manager
    with initialize(config_path="conf"):
        # Compose the configuration
        cfg = compose(config_name="main", overrides=[
            f"exp.name=optuna_trial_{trial.number}",
            f"optimizer_cls.lr={trial.suggest_loguniform('lr', 1e-7, 1e-4)}",
        ])

        # Print the configuration for debugging
        print(OmegaConf.to_yaml(cfg))

        # # Your existing model initialization and training logic
        # train_loader, val_loader, model = initialize_dataset_model(cfg)
        # model = train(train_loader, val_loader, model, cfg)

        # # Evaluate the model and return the metric you are interested in
        # acc_mean, _ = test(cfg, model, 'val')  # Assuming 'val' is your validation split
        # return acc_mean

In [4]:
# Hydra context manager
with initialize(config_path="conf"):
    dataset = 'swissprot_no_backbone' 
    method = 'protoformer'
    num_sub_support = 5
    n_layer = 2
    n_head = 2
    contrastive_coef = 1
    norm_first = False
    weight_decay = 0.01
    dropout = 0.1
    contrastive_loss = 'original'

    # Compose the configuration with overrides
    cfg = compose(config_name="main", overrides=[
        f"model={method}", 
        f"method={method}",
        f"dataset={dataset}",  # Example dataset
        f"lr={0.00001}",  # Example learning rate
        f"weight_decay={weight_decay}",
        f"method.cls.n_sub_support={num_sub_support}",
        f"method.cls.n_layer={n_layer}",
        f"method.cls.n_head={n_head}",
        f"method.cls.contrastive_coef={contrastive_coef}",
        f"method.cls.dropout={dropout}",
        f"method.cls.norm_first={norm_first}",
        f"method.cls.contrastive_loss={contrastive_loss}",
        f"exp.name=optuna_trial_{1}",
        f"method.cls.ffn_dim={512}",
        # Add more overrides if necessary
    ])

    # Print the configuration for debugging
    print(OmegaConf.to_yaml(cfg))

dataset:
  type: classification
  simple_cls:
    _target_: datasets.prot.swissprot.SPSimpleDataset
  set_cls:
    n_way: ${n_way}
    n_support: ${n_shot}
    n_query: ${n_query}
    _target_: datasets.prot.swissprot.SPSetDataset
    embed_dir: ${dataset.embed_dir}
  name: swissprot_no_bacbone
  embed_dir: embeds
eval_split:
- train
- val
- test
backbone:
  _target_: backbones.id.Id
train_classes: 59
n_way: 5
n_shot: 5
n_query: 15
method:
  name: protoformer
  train_batch: null
  val_batch: null
  fast_weight: false
  start_epoch: 0
  eval_type: set
  stop_epoch: 60
  type: meta
  cls:
    n_way: ${n_way}
    n_support: ${n_shot}
    _target_: methods.protoformer.ProtoFormer
    n_layer: 2
    n_head: 2
    contrastive_coef: 1
    n_sub_support: 5
    ffn_dim: 512
    dropout: 0.1
    norm_first: false
    contrastive_loss: original
model: protoformer
mode: train
exp:
  name: optuna_trial_1
  save_freq: 10
  resume: false
  seed: 42
  val_freq: 1
optimizer: Adam
lr: 1.0e-05
weight_dec

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path="conf"):


In [None]:
def tune(dataset="swissprot_no_backbone"):

    with initialize(config_path="conf/dataset", version_base=None):
        cfg = compose(config_name=dataset)
        
    train_dataset = instantiate(cfg.dataset.set_cls, mode='train')
    val_dataset = instantiate(cfg.dataset.set_cls, mode='val')

    train_loader = train_dataset.get_data_loader()
    val_loader = val_dataset.get_data_loader()

    results = []

    def objective(trial):
        with initialize(config_path="conf"):
            # Compose the configuration with trial-specific overrides
            cfg = compose(config_name="main", overrides=[
                "model=protoformer",
                "method=protoformer",
                f"dataset={dataset}",  # Fixed dataset
                f"optimizer_cls.lr={trial.suggest_float('lr', 1e-7, 1e-4)}",
                f"optimizer_cls.weight_decay={trial.suggest_float('weight_decay', 1e-5, 1e-3)}",
                f"method.cls.n_sub_support={trial.suggest_int('n_sub_support', 2, 4)}",
                f"method.cls.n_layer={trial.suggest_int('n_layer', 1, 3)}",
                f"method.cls.n_head={trial.suggest_categorical('n_head', [1, 2, 3, 4, 5, 8])}",
                f"method.cls.contrastive_coef={trial.suggest_float('contrastive_coef', 0.1, 2.0)}",
                f"method.cls.dropout={trial.suggest_float('dropout', 0.0, 0.5)}",
                f"method.cls.norm_first={trial.suggest_categorical('norm_first', [True, False])}",
                f"method.cls.contrastive_loss=original", #TODO: you can modify this
                f"exp.name=optuna_trial_{trial.number}",
            ])

            fix_seed(cfg.exp.seed)

            print(OmegaConf.to_yaml(cfg))

            # Initialize model and backbone for this trial
            backbone = instantiate(cfg.backbone, x_dim=train_dataset.dim)
            model = instantiate(cfg.method.cls, backbone=backbone)

            if torch.cuda.is_available():
                model = model.cuda()

            model = train(train_loader, val_loader, model, cfg)

            acc_mean, acc_std = test(cfg, model, split='val')

            results.append([trial.number, acc_mean, acc_std])

            return acc_mean  # or any other metric you want to optimize

    # Run Optuna study
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=2)

    # Output the optimization results
    best_trial = study.best_trial
    print(f"Best Trial: {best_trial.number}")
    print(f"Best Value: {best_trial.value}")
    print(f"Best Parameters: {best_trial.params}")

    # Save the study
    optuna_studies_file = f"{dataset}_studies.pkl"
    with open(optuna_studies_file, "wb") as f:
        pickle.dump(study, f)

    # Log results to WandB
    table = wandb.Table(data=results, columns=["trial", "acc_mean", "acc_std"])
    wandb.log({"eval_results": table})

    # Display results in a pretty table
    display_table = PrettyTable(["trial", "acc_mean", "acc_std"])
    for row in results:
        display_table.add_row(row)
    print(display_table)

In [None]:
tune()

In [None]:
# with initialize(config_path="conf/dataset", version_base=None):
#     cfg = compose(config_name="swissprot_no_backbone")
#     print(OmegaConf.to_yaml(cfg))