In [8]:
%load_ext autoreload
%autoreload 2
import sys
if '..' not in sys.path: sys.path.append("..")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
import os
import time

import torch

from src.agent import Agent
from src.files.file import create_directory
from src.utils.decorator import close_on_finish_decorator

In [15]:
import logging
# Raise the threshold of the top-level "pytorch_lightning" logger to ERROR so
# lower-severity records it receives are dropped instead of cluttering output.
logging.getLogger(name="pytorch_lightning").setLevel(level=logging.ERROR)

In [16]:
def run_me(name, i_limit=5, base_config='base', cooldown=10):
    """Train ``name`` fold by fold and collect the metrics logged after each fold.

    Args:
        name: model name passed to ``Agent``.
        i_limit: maximum number of folds (iterations) to run. ``-1`` -- or any
            value < 1 -- means "no cap": keep going until
            ``agent.dataloader.next_fold()`` returns a falsy value.
        base_config: base configuration name passed to ``Agent``.
        cooldown: seconds to pause between consecutive folds (not after the last).

    Returns:
        list of per-fold snapshots of ``agent.trainer.logged_metrics``. Each
        snapshot is also appended, one line per fold, to a ``logged_metrics``
        file in the parent directory of the trainer logger's ``log_dir``.
    """
    torch.cuda.empty_cache()
    agent = Agent(name, base_config=base_config)
    logged_metrics = []

    try:
        # Denominator for the progress line; the config is only read when uncapped.
        total = i_limit if i_limit > 0 else agent._config['dataloader']['args']['split']['folds']
        i = 0
        while True:
            print(f"Running model: {name}, Iteration: {i:2}/{total:2}, date: {agent.creation_date}")
            agent.run()
            agent.save_hparams()

            # Snapshot the mapping so the stored value is decoupled from the trainer object.
            metrics = dict(agent.trainer.logged_metrics)
            metrics_file = os.path.join(os.path.dirname(agent.trainer.logger.log_dir), "logged_metrics")
            with open(metrics_file, 'a') as f:
                # Newline-terminated so every fold ends up on its own line of the appended file.
                f.write(f"{metrics}\n")
            logged_metrics.append(metrics)

            i += 1
            # Check the cap before next_fold() so the dataloader is not advanced needlessly.
            if 0 < i_limit <= i or not agent.dataloader.next_fold():
                break

            # NOTE(review): the reason for this pause is not visible here -- confirm it is still needed.
            time.sleep(cooldown)
            torch.cuda.empty_cache()
    finally:
        # Runs on success, on errors, and on KeyboardInterrupt (which `except Exception` misses).
        torch.cuda.empty_cache()

    print(f"{name}: {logged_metrics}")
    return logged_metrics

In [None]:
models = [
    "debug",
]
# Run the "debug" model with the 'base' config over every fold
# (i_limit=-1 disables run_me's iteration cap).
for name in models:
    run_me(name, i_limit=-1, base_config='base')

Running model: debug, Iteration:  0/10, date: 20210618205727
Running model: debug, Iteration:  1/10, date: 20210618205727
Running model: debug, Iteration:  2/10, date: 20210618205727
Running model: debug, Iteration:  3/10, date: 20210618205727
Running model: debug, Iteration:  4/10, date: 20210618205727
Running model: debug, Iteration:  5/10, date: 20210618205727
Running model: debug, Iteration:  6/10, date: 20210618205727
Running model: debug, Iteration:  7/10, date: 20210618205727
Epoch 4 [24/57] {'loss': '1.03'}}

In [None]:
models = [
    "customCNN",
    "resnet18_brew2",
    "vgg11_brew",
    "vgg16",
    "resnet50",
]
# Train each model with the 'base' config over every fold
# (i_limit=-1 disables run_me's iteration cap).
for name in models:
    run_me(name, i_limit=-1, base_config='base')

In [None]:
models = [
    "customCNN",
    "resnet18_brew2",
    "vgg11_brew",
    "vgg16",
    "resnet50",
]
# Train each model with the 'augmented' config over every fold
# (i_limit=-1 disables run_me's iteration cap).
for name in models:
    run_me(name, i_limit=-1, base_config='augmented')

In [None]:
models = [
    "customCNN",
    "resnet18_brew2",
    "vgg11_brew",
    "vgg16",
    "resnet50",
]
# Train each model with the 'no_augmented' config over every fold
# (i_limit=-1 disables run_me's iteration cap).
for name in models:
    run_me(name, i_limit=-1, base_config='no_augmented')

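# Evaluate on the test data

Each model is rebuilt from a saved checkpoint with the 'base' config and run through `trainer.test`. Fill in the checkpoint path for every model in the cell below before running it.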

In [None]:
# (model name, checkpoint path) pairs to evaluate with the 'base' config.
# TODO: fill in each checkpoint path -- None marks a path that has not been set yet.
checkpoints = [
    ("resnet18_brew2", None),
    ("customCNN", None),
    ("vgg11_brew", None),
    ("vgg16", None),
    ("resnet50", None),
]

# Fail fast with a readable message rather than an unpacking error (the old
# 1-tuples could never unpack into `name, path`) or a crash mid-evaluation.
missing_paths = [entry[0] for entry in checkpoints if len(entry) < 2 or not entry[1]]
if missing_paths:
    raise ValueError(f"No checkpoint path set for: {', '.join(missing_paths)}")

test_results = {}
for name, path in checkpoints:
    agent = Agent(name, base_config='base', checkpoint_path=path)
    agent.load_model()
    agent.load_trainer()
    # Keep whatever trainer.test returns so every model's result is visible after the loop.
    test_results[name] = agent.trainer.test(agent.model, datamodule=agent.dataloader)
test_results