In [1]:
import os
import sys
import optuna
import shutil
import os.path as osp
import pytorch_lightning as pl

from pathlib import Path
from argparse import ArgumentParser, Namespace
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from datamodule import DataModule
from supervised import SupervisedLightningModule,inverseSupervisedLightningModule

import sys
sys.path.append('../')
from utils import run_cli, yaml_func

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def main(config_path) -> None:
    """Train one model per configured seed and archive the config beside each run's logs.

    The configuration is loaded with ``run_cli`` and is expected to provide the
    keys ``seeds``, ``ckpt_callback``, ``trainer_params``, ``hparams``
    (``inverse``, ``batch_size``), ``logger`` (``save_dir``, ``name``,
    ``version``) and ``dataset`` (``train_dir``, ``val_dir``).

    Args:
        config_path: Path of the YAML configuration file. It is passed to
            ``run_cli`` and copied into every run's log directory as
            ``config.yaml`` once training for that seed has returned.

    Raises:
        ValueError: If ``config['hparams']['inverse']`` is neither 0 nor 1.
    """
    config = run_cli(config_path=config_path)
    trainer_params = config['trainer_params']

    if trainer_params['default_root_dir'] == "None":
        # `__file__` is not defined inside a notebook kernel, so fall back to
        # the current working directory instead of raising NameError.
        script_path = globals().get('__file__')
        trainer_params['default_root_dir'] = (
            osp.dirname(script_path) if script_path is not None else os.getcwd())

    # Validate the model switch once, before any training starts, so that a bad
    # value fails with a clear message rather than an unbound `model` later.
    inverse_flag = config["hparams"]["inverse"]
    if inverse_flag == 0:
        model_cls = SupervisedLightningModule
    elif inverse_flag == 1:
        model_cls = inverseSupervisedLightningModule
    else:
        raise ValueError(
            f"config['hparams']['inverse'] must be 0 or 1, got {inverse_flag!r}")

    # 'callbacks' must not reach pl.Trainer through **kwargs because the
    # checkpoint callback is passed explicitly (duplicate keyword otherwise).
    # The raw spec stays untouched in `config` so every seed resolves it afresh.
    trainer_kwargs = dict(trainer_params)
    callbacks_spec = trainer_kwargs.pop('callbacks', None)

    for seed in config['seeds']:
        if seed is not None:
            pl.seed_everything(seed)

        ckpt_callback = ModelCheckpoint(
            filename='{epoch}-{val_acc:.2f}',
            **config['ckpt_callback'],
        )
        callbacks = [ckpt_callback]
        if callbacks_spec is not None:
            # NOTE(review): yaml_func is assumed to turn the YAML spec into
            # callback object(s) -- confirm against utils.yaml_func.
            resolved = yaml_func(callbacks_spec)
            if isinstance(resolved, (list, tuple)):
                callbacks.extend(resolved)
            else:
                callbacks.append(resolved)

        model = model_cls(config)

        logger = TensorBoardLogger(
            save_dir=config['logger']['save_dir'],
            name=config['logger']['name'] + f"-seed{seed}",
            version=config['logger']['version'],
        )

        trainer = pl.Trainer(**trainer_kwargs,
                             callbacks=callbacks,
                             logger=logger)
        imdm = DataModule(
            train_dir=config['dataset']['train_dir'],
            val_dir=config['dataset']['val_dir'],
            batch_size=config['hparams']['batch_size'])
        trainer.fit(model, datamodule=imdm)

        # Ask the logger for its directory instead of rebuilding the path by
        # hand: the two differ when `version` is None or an int ("version_N").
        run_dir = logger.log_dir
        os.makedirs(run_dir, exist_ok=True)
        shutil.copy(config_path, osp.join(run_dir, 'config.yaml'))

In [5]:
# Script entry point: run the multi-seed training with the repository's
# default configuration file (path is relative to the working directory).
if __name__ == '__main__':
    default_config_path = './configs/config.yaml'
    main(config_path=default_config_path)


Global seed set to 4
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | model   | ResNet9          | 6.6 M 
1 | ce_loss | CrossEntropyLoss | 0     
---------------------------------------------
6.6 M     Trainable params
0         Non-trainable params
6.6 M     Total params
26.486    Total estimated model params size (MB)


                                                                      

Global seed set to 4


Epoch 0:  83%|████████▎ | 196/236 [00:29<00:05,  6.71it/s, loss=1.85, v_num=dam|]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 0:  84%|████████▍ | 198/236 [00:29<00:05,  6.70it/s, loss=1.85, v_num=dam|]
Epoch 0:  85%|████████▌ | 201/236 [00:29<00:05,  6.76it/s, loss=1.85, v_num=dam|]
Validating:  12%|█▎        | 5/40 [00:00<00:03, 11.50it/s][A
Epoch 0:  86%|████████▋ | 204/236 [00:30<00:04,  6.83it/s, loss=1.85, v_num=dam|]
Epoch 0:  88%|████████▊ | 207/236 [00:30<00:04,  6.90it/s, loss=1.85, v_num=dam|]
Validating:  28%|██▊       | 11/40 [00:00<00:01, 16.89it/s][A
Epoch 0:  89%|████████▉ | 210/236 [00:30<00:03,  6.96it/s, loss=1.85, v_num=dam|]
Epoch 0:  90%|█████████ | 213/236 [00:30<00:03,  7.02it/s, loss=1.85, v_num=dam|]
Validating:  42%|████▎     | 17/40 [00:01<00:01, 18.70it/s][A
Epoch 0:  92%|█████████▏| 216/236 [00:30<00:02,  7.09it/s, loss=1.85, v_num=dam|]
Epoch 0:  93%|█████████▎| 219/236 [00:30<00:02,  7.15it/s, loss=1.85

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
