In [1]:
%load_ext autoreload
%autoreload 2
import sys
# Put the parent directory on the import path (once) so the `src.*` imports
# used by the following cells can be resolved from this notebook's location.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
import time
import torch
from src.agent import Agent
from src.files.file import create_directory
from src.utils.decorator import close_on_finish_decorator

In [3]:
import logging
# Raise the threshold of the top-level "pytorch_lightning" logger to ERROR so
# that its lower-severity messages do not flood the cell outputs.
logging.getLogger("pytorch_lightning").setLevel(level=logging.ERROR)

In [4]:
def run_me(name, i_limit=5, base_config='base', sleep_seconds=10):
    """Run the model ``name`` fold after fold and collect the logged metrics.

    A new ``Agent`` is created for ``name``. Every iteration calls
    ``agent.run()`` and ``agent.save_hparams()``, appends the trainer's
    ``logged_metrics`` as one line to a ``logged_metrics`` file located in the
    parent directory of the logger's ``log_dir``, and finally asks the
    dataloader for the next fold.

    Args:
        name: Model name handed to ``Agent``.
        i_limit: Maximum number of iterations. ``-1`` disables the cap, so the
            loop only ends once ``agent.dataloader.next_fold()`` returns a
            falsy value.
        base_config: Base configuration name handed to ``Agent``.
        sleep_seconds: Pause after each iteration, in seconds. Defaults to the
            previously hard-coded 10; values <= 0 skip the pause.

    Returns:
        list: ``agent.trainer.logged_metrics`` of every completed iteration, in
        order (previously the function returned nothing).

    Raises:
        Whatever the agent, trainer or file I/O raises is propagated unchanged.
        The CUDA cache is released before the error leaves this function.
    """
    import os  # local import so this cell does not depend on the import cell

    torch.cuda.empty_cache()
    agent = Agent(name, base_config=base_config)
    logged_metrics = []

    try:
        # Denominator of the progress line: the explicit cap if there is one,
        # otherwise the fold count from the agent's config. It does not change
        # inside the loop, so it is looked up once.
        total = i_limit if i_limit != -1 else agent._config['dataloader']['args']['split']['folds']

        i = 0
        while True:
            print(f"Running model: {name}, Iteration: {i:2}/{total:2}, date: {agent.creation_date}")
            agent.run()
            agent.save_hparams()

            # The file lives next to the per-run log directory, i.e. in the
            # parent of `log_dir`. os.path handles the platform's separator,
            # unlike splitting on "/".
            metrics_file = os.path.join(
                os.path.dirname(agent.trainer.logger.log_dir), "logged_metrics"
            )
            with open(metrics_file, 'a') as f:
                # One record per line; without the newline successive records
                # were glued together and could not be told apart.
                f.write(f"{agent.trainer.logged_metrics}\n")

            logged_metrics.append(agent.trainer.logged_metrics)

            if sleep_seconds > 0:
                time.sleep(sleep_seconds)
            i += 1

            # The fold is always advanced first (same order as before), then
            # the optional iteration cap is applied.
            has_next_fold = agent.dataloader.next_fold()
            limit_reached = i_limit != -1 and i == i_limit
            if not has_next_fold or limit_reached:
                break
            torch.cuda.empty_cache()
    finally:
        # Runs on success, on errors and on KeyboardInterrupt alike, so an
        # interrupted training does not leave cached GPU memory behind.
        torch.cuda.empty_cache()

    print(f"{name}: {logged_metrics}")
    return logged_metrics

In [5]:
# Run every architecture with the 'base' configuration; i_limit=-1 removes the
# iteration cap in run_me.
models = [
    "resnet18_brew2",
    "customCNN",
    "vgg11_brew",
    "vgg16",
    "resnet50",
]
for name in models:
    run_me(name, i_limit=-1, base_config='base')

Running model: resnet18_brew2, Iteration:  0/10, date: 20210620181026
Running model: resnet18_brew2, Iteration:  1/10, date: 20210620181026
Running model: resnet18_brew2, Iteration:  2/10, date: 20210620181026
Running model: resnet18_brew2, Iteration:  3/10, date: 20210620181026
Running model: resnet18_brew2, Iteration:  4/10, date: 20210620181026
Running model: resnet18_brew2, Iteration:  5/10, date: 20210620181026
Running model: resnet18_brew2, Iteration:  6/10, date: 20210620181026
Running model: resnet18_brew2, Iteration:  7/10, date: 20210620181026
Running model: resnet18_brew2, Iteration:  8/10, date: 20210620181026
Running model: resnet18_brew2, Iteration:  9/10, date: 20210620181026
resnet18_brew2: [{'epoch': tensor(30.), 'loss/train': tensor(0.9825, device='cuda:0'), 'loss/val': tensor(1.3711), 'AUROC/CN/val/': tensor(0.2348), 'AUROC/MCI/val/': tensor(0.4332), 'AUROC/AD/val/': tensor(0.4621), 'AUROC/val': tensor(0.3767), 'Accuracy/val': tensor(0.1071), 'Precision/val': tensor(

In [None]:
# Run every architecture with the 'augmented' configuration; i_limit=-1 removes
# the iteration cap in run_me.
models = [
    "resnet18_brew2",
    "customCNN",
    "vgg11_brew",
    "vgg16",
    "resnet50",
]
for name in models:
    run_me(name, i_limit=-1, base_config='augmented')

In [None]:
# Run every architecture with the 'no_augmented' configuration; i_limit=-1
# removes the iteration cap in run_me.
models = [
    "resnet18_brew2",
    "customCNN",
    "vgg11_brew",
    "vgg16",
    "resnet50",
]
for name in models:
    run_me(name, i_limit=-1, base_config='no_augmented')

# Evaluate on the test data

In [6]:
# Root of the log tree that holds the trained checkpoints. It is kept in one
# place so the cell can be pointed at another location by editing one line.
CHECKPOINT_ROOT = "/var/metrics/codetests/logs/tb/final2"

# (model name, checkpoint path). The paths are written relative to
# CHECKPOINT_ROOT and expanded to the same absolute paths as before.
checkpoints = [
    (name, f"{CHECKPOINT_ROOT}/{relative_path}")
    for name, relative_path in [
        ("resnet18_brew2", "resnet18_brew2/20210619024739/version_8/checkpoints/epoch=70-step=4046.ckpt"),
        ("customCNN", "customCNN/20210619002735/version_5/checkpoints/epoch=92-step=5300.ckpt"),
        ("vgg11_brew", "vgg11_brew/20210619063455/version_0/checkpoints/epoch=30-step=1766.ckpt"),
        ("vgg16", "vgg16/20210619101033/version_0/checkpoints/epoch=25-step=1481.ckpt"),
        ("resnet50", "resnet50/20210619155131/version_5/checkpoints/epoch=99-step=5699.ckpt"),
    ]
]

# Return value of each test run keyed by model name, so the results can be
# compared in later cells instead of only being read off the printed output.
# NOTE(review): per the Lightning docs `Trainer.test` returns a list of metric
# dicts, one per test dataloader - confirm for the installed version.
test_results = {}
for name, path in checkpoints:
    print(f"Running: {name} {path}")
    agent = Agent(name, base_config='base', checkpoint_path=path)
    agent.load_model()
    agent.load_trainer()
    test_results[name] = agent.trainer.test(agent.model, datamodule=agent.dataloader)

Running: resnet18_brew2 /var/metrics/codetests/logs/tb/final2/resnet18_brew2/20210619024739/version_8/checkpoints/epoch=70-step=4046.ckpt
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'AUROC/AD/test/': 0.8133333325386047,
 'AUROC/CN/test/': 0.776190459728241,
 'AUROC/MCI/test/': 0.3583333492279053,
 'AUROC/test': 0.6492857336997986,
 'Accuracy/test': 0.4516128897666931,
 'Precision/test': 0.4516128897666931,
 'Recall/test': 0.4516128897666931,
 'Specificity/test': 0.725806474685669,
 'loss/test': 1.038294792175293}
--------------------------------------------------------------------------------
Running: customCNN /var/metrics/codetests/logs/tb/final2/customCNN/20210619002735/version_5/checkpoints/epoch=92-step=5300.ckpt
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'AUROC/AD/test/': 0.9066666960716248,
 'AUROC/CN/test/': 0.9142857193946838,
 'AUROC/MCI/test/': 0.6