In [1]:
import comet_ml

from pathlib import Path
from argparse import Namespace

import torch as th
import pytorch_lightning as pl

from lightningmodule import MULTModelWarped
from datasets import load_impressionv2_dataset_all

import mutils

In [9]:
# Comet project whose experiments are re-evaluated below (presumably a
# learning-rate search, going by the name). Also used, with '-' replaced by
# '_', to build the checkpoint directory names.
project_name = "find-lr-1"

In [3]:
# Comet REST API client used to fetch and update the experiments. No API key
# is passed here, so comet_ml presumably resolves it from its config file or
# the COMET_API_KEY environment variable — confirm for your setup.
comet_api = comet_ml.api.API()

In [4]:
# Load the ImpressionV2 train/valid/test splits and wrap each in a DataLoader.
# `test_dl` is the one the evaluation below uses. DataLoader does not shuffle
# by default, so iteration order follows the dataset order.
[train_ds, valid_ds, test_ds], target_names = load_impressionv2_dataset_all()


def _loader(dataset, batch_size):
    """Return a DataLoader over `dataset` with pinned host memory."""
    return th.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=True)


train_dl = _loader(train_ds, batch_size=8)
valid_dl = _loader(valid_ds, batch_size=64)
test_dl = _loader(test_ds, batch_size=64)

In [10]:
# Comet workspace that owns `project_name`. The returned experiments are
# iterated below to re-evaluate each one's checkpoint.
comet_workspace = "transformer"
exps = comet_api.get(comet_workspace, project_name)

In [16]:
def test_experiment(exp):
    experiement_key = exp.id
    project_name_dir = project_name.replace('-', '_')
    checkpoint_dir = Path('logs')/'weights'/f'{experiement_key}_{project_name_dir}'/f'0_{experiement_key}'/'checkpoints'
    ckpts = list(checkpoint_dir.glob('*.ckpt'))
    assert len(ckpts) == 1, len(ckpts)
    ckpt = ckpts[0]
    ck = th.load(ckpt)
    
    hyp_params = Namespace(**ck['hyper_parameters'])
    defaults = {'loss_fnc': 'L2', 'project_dim': 30, 'weight_decay': 0.0, 'optim': 'Adam'}
    for k, v in defaults.items():
        if k not in hyp_params:
            setattr(hyp_params, k, v)
    
    trainer = pl.Trainer(gpus=1)
    model = MULTModelWarped(hyp_params, target_names=target_names, early_stopping=None)
    model.load_state_dict(ck['state_dict'])
    test_res = trainer.test(model, test_dataloaders=test_dl)
    exp.log_metrics(test_res[0])
    
    df = mutils.get_exp_csv(experiement_key)
    df = mutils.get_epoch_info(df)
    exp.log_metric('best_valid_1mae', df['valid_1mae'].max())
    
    meta = exp.get_metadata()
    exp.set_end_time(meta['endTimeMillis'])

In [18]:
# Evaluate every experiment. Keep going past failures, but record which ones
# failed and why, so they can be inspected instead of silently discarded.
errors = []          # ids of experiments whose evaluation raised
error_reasons = {}   # experiment id -> repr of the exception it raised
for exp in exps:
    try:
        test_experiment(exp)
    except Exception as err:  # not a bare except: lets KeyboardInterrupt through
        errors.append(exp.id)
        error_reasons[exp.id] = repr(err)

In [14]:
# Ids of the experiments whose evaluation raised in the loop above.
errors

['3deff4f7ab184f788174c2a6185b443e',
 'fd29b0dd4b76449d8813316a04abfa26',
 '9ee58058267848b7b0a44e2f556f0633',
 '994ff687063641d489f731605812f7ba',
 '79e89414d18447818acf1124826afd57',
 '820bf18ee5d04c05af0551fa26b4f191']