# Train

In [2]:
import wandb
from paperswithtopic.config import load_config
from paperswithtopic.run import run

In [None]:
# Base configuration for the single (non-sweep) training run below:
# start from the project defaults, then override the data/embedding flags
# and select the model by name.
cfg = load_config()
cfg.use_saved, cfg.pre_embed, cfg.use_bert_embed = True, False, False

cfg.model_name = 'bertforclassification'

In [2]:
# Authenticate with wandb and open a run for the single training run below.
wandb.login()
# Name the run after the configured model instead of repeating the literal,
# so the run name cannot drift from cfg.model_name.
wandb.init(project='paperswithtopic', name=cfg.model_name)
# Log the full training config with the run, the same way run_sweep does for
# sweep trials, so manual and sweep runs are comparable in the dashboard.
wandb.config.update(cfg)

In [None]:
run(cfg)
# Close the run opened by wandb.init above so it is marked finished and its
# data is fully uploaded before the sweep cells start their own runs.
# (No-op if no run is active.)
wandb.finish()

# Hyperparameter grid search with wandb.sweep

In [None]:
# Grid sweep over model-size hyperparameters; every (hidden_dim, n_layers,
# n_heads) combination is tried and runs are ranked by `valid_auc`.
sweep_config = {
    'name': 'bertforclassification_re',
    'method': 'grid',
    'metric': {'name': 'valid_auc', 'goal': 'maximize'},
    'parameters': {
        param: {'values': choices}
        for param, choices in (
            ('hidden_dim', [128, 256, 512]),
            ('n_layers', [3, 4]),
            ('n_heads', [8, 16, 32]),
        )
    },
}
sweep_id = wandb.sweep(sweep_config, project='paperswithtopic')

In [None]:
def run_sweep():
    """Run one training trial for the active wandb sweep.

    Intended to be passed to ``wandb.agent``, which calls it once per grid
    point. It builds the base config, overlays the hyperparameters the sweep
    selected for this trial (``wandb.config``), names the wandb run after
    those values, logs the merged config back to wandb and trains via
    ``run(cfg)``.
    """
    with wandb.init():

        cfg = load_config()

        cfg.use_saved = True
        cfg.pre_embed = False
        cfg.use_bert_embed = False
        cfg.model_name = 'bert'
        cfg.learning_rate = 0.0001

        # Values chosen by the sweep for this trial override the defaults
        # above. Copy them into a plain dict before merging into cfg.
        cfg.update(dict(wandb.config))

        # Encode every swept hyperparameter in the run name. With only LR
        # (fixed here) and DIM, the grid's n_layers/n_heads variants shared
        # the same name and were indistinguishable in the dashboard.
        wandb.run.name = (
            f'SWEEP_LR{cfg.learning_rate}_DIM{cfg.hidden_dim}'
            f'_LAYERS{cfg.n_layers}_HEADS{cfg.n_heads}'
        )
        wandb.config.update(cfg)

        run(cfg)

        # NOTE(review): `clear_output` is not imported in any visible cell;
        # add `from IPython.display import clear_output` to the imports cell,
        # otherwise this raises NameError after every trial and the run is
        # recorded as failed.
        clear_output()

In [None]:
# Start an in-process agent that pulls grid points from the sweep created
# above and calls run_sweep for each one.
wandb.agent(sweep_id=sweep_id, function=run_sweep)