In [23]:
import warnings
import torch

#import src.callbacks as clb
import src.configuration as C
import src.models as models
import src.utils as utils

from catalyst.dl import SupervisedRunner

from pathlib import Path

In [24]:
# Environment sanity check: report whether PyTorch can see a CUDA device
# before committing to the (long) training run further down.
cuda_available = torch.cuda.is_available()
print(cuda_available)

True


In [25]:
# Load the experiment configuration used throughout this notebook.
CONFIG_PATH = './configs/000_ResNet34.yml'
config = utils.load_config(CONFIG_PATH)

# Build the network described by the config purely to inspect its layers.
# No `.to(device)` here: the training cell builds its own copy per fold.
model = models.get_model(config)
print(model)


ResNet(
  (pre): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (layer1): Sequential(
    (0): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (right): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tr

In [None]:
# Silence library warnings so the per-epoch training log stays readable.
# NOTE(review): this is a global catch-all filter; it also hides genuine
# deprecation/runtime warnings. Consider narrowing it by category or module.
warnings.filterwarnings("ignore")

# Config path is hard-coded here instead of being parsed from the command line.
CONFIG_PATH = './configs/000_ResNet34.yml'
config = utils.load_config(CONFIG_PATH)

global_params = config["globals"]

# All artefacts of this run (log file, per-fold Catalyst logdirs) go under output_dir.
output_dir = Path(global_params["output_dir"])
output_dir.mkdir(exist_ok=True, parents=True)
logger = utils.get_logger(output_dir / "output.log")

utils.set_seed(global_params["seed"])
device = C.get_device(global_params["device"])

df, datadir = C.get_metadata(config)
splitter = C.get_split(config)

# Cross-validation: train a fresh model for every fold listed in globals["folds"].
for i, (trn_idx, val_idx) in enumerate(
        splitter.split(df, y=df["primary_label"])):
    if i not in global_params["folds"]:
        continue
    logger.info("=" * 20)
    logger.info(f"Fold {i}")
    logger.info("=" * 20)

    # The splitter (scikit-learn-style API, presumably StratifiedKFold or
    # similar; TODO confirm in C.get_split) yields *positional* row indices.
    # Select with .iloc: .loc would interpret them as index labels and pick
    # the wrong rows (or raise KeyError) whenever df has no default RangeIndex.
    trn_df = df.iloc[trn_idx].reset_index(drop=True)
    val_df = df.iloc[val_idx].reset_index(drop=True)

    loaders = {
        phase: C.get_loader(df_, datadir, config, phase)
        for df_, phase in zip([trn_df, val_df], ["train", "valid"])
    }
    # Model and loss are moved to `device` manually; no engine is passed to
    # the runner below.
    model = models.get_model(config).to(device)
    criterion = C.get_criterion(config).to(device)
    optimizer = C.get_optimizer(model, config)
    scheduler = C.get_scheduler(optimizer, config)

    runner = SupervisedRunner(
        input_key=global_params["input_key"],
        target_key=global_params["input_target_key"])
    # NOTE(review): custom callbacks (clb.get_callbacks) and best-model metric
    # selection (globals main_metric / minimize_metric) are currently disabled.
    # To restore best-checkpoint selection, pass the Catalyst equivalents
    # (e.g. valid_loader / valid_metric / minimize_valid_metric); TODO confirm
    # the argument names against the installed Catalyst version.
    runner.train(
        model=model,
        criterion=criterion,
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=global_params["num_epochs"],
        verbose=True,
        logdir=output_dir / f"fold{i}",
        )

2022-03-11 15:22:11,216 - INFO - logger set up
2022-03-11 15:22:11,321 - INFO - Fold 0


HBox(children=(FloatProgress(value=0.0, description='1/10 * Epoch (train)', max=372.0, style=ProgressStyle(des…


train (1/10) loss: 0.03927118124971935 | loss/mean: 0.03927118124971935 | loss/std: 0.0456498446951412 | lr: 0.001 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='1/10 * Epoch (valid)', max=93.0, style=ProgressStyle(desc…


valid (1/10) loss: 0.03705415589837105 | loss/mean: 0.03705415589837105 | loss/std: 0.0073270180192717945 | lr: 0.001 | momentum: 0.9
* Epoch (1/10) lr: 0.0009755282581475768 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='2/10 * Epoch (train)', max=372.0, style=ProgressStyle(des…


train (2/10) loss: 0.030240078176621093 | loss/mean: 0.030240078176621093 | loss/std: 0.0022330228867710796 | lr: 0.0009755282581475768 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='2/10 * Epoch (valid)', max=93.0, style=ProgressStyle(desc…


valid (2/10) loss: 0.029012889432329318 | loss/mean: 0.029012889432329318 | loss/std: 0.007491275968130138 | lr: 0.0009755282581475768 | momentum: 0.9
* Epoch (2/10) lr: 0.0009045084971874736 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='3/10 * Epoch (train)', max=372.0, style=ProgressStyle(des…


train (3/10) loss: 0.02699777929933975 | loss/mean: 0.02699777929933975 | loss/std: 0.0023740333967774657 | lr: 0.0009045084971874736 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='3/10 * Epoch (valid)', max=93.0, style=ProgressStyle(desc…


valid (3/10) loss: 0.029286321568563298 | loss/mean: 0.029286321568563298 | loss/std: 0.00859627617083498 | lr: 0.0009045084971874736 | momentum: 0.9
* Epoch (3/10) lr: 0.0007938926261462366 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='4/10 * Epoch (train)', max=372.0, style=ProgressStyle(des…


train (4/10) loss: 0.024508349205516266 | loss/mean: 0.024508349205516266 | loss/std: 0.0024859479847625784 | lr: 0.0007938926261462366 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='4/10 * Epoch (valid)', max=93.0, style=ProgressStyle(desc…


valid (4/10) loss: 0.025162207259248295 | loss/mean: 0.025162207259248295 | loss/std: 0.008064945969845546 | lr: 0.0007938926261462366 | momentum: 0.9
* Epoch (4/10) lr: 0.0006545084971874737 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='5/10 * Epoch (train)', max=372.0, style=ProgressStyle(des…


train (5/10) loss: 0.022577901922279844 | loss/mean: 0.022577901922279844 | loss/std: 0.0026587588313367843 | lr: 0.0006545084971874737 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='5/10 * Epoch (valid)', max=93.0, style=ProgressStyle(desc…


valid (5/10) loss: 0.025397493182679927 | loss/mean: 0.025397493182679927 | loss/std: 0.009155119796383919 | lr: 0.0006545084971874737 | momentum: 0.9
* Epoch (5/10) lr: 0.0005 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='6/10 * Epoch (train)', max=372.0, style=ProgressStyle(des…


train (6/10) loss: 0.020773557551326205 | loss/mean: 0.020773557551326205 | loss/std: 0.002566016323051552 | lr: 0.0005 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='6/10 * Epoch (valid)', max=93.0, style=ProgressStyle(desc…


valid (6/10) loss: 0.02782814597785573 | loss/mean: 0.02782814597785573 | loss/std: 0.011172960368263228 | lr: 0.0005 | momentum: 0.9
* Epoch (6/10) lr: 0.00034549150281252633 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='7/10 * Epoch (train)', max=372.0, style=ProgressStyle(des…


train (7/10) loss: 0.0191893193899943 | loss/mean: 0.0191893193899943 | loss/std: 0.0024829580306757336 | lr: 0.00034549150281252633 | momentum: 0.9


HBox(children=(FloatProgress(value=0.0, description='7/10 * Epoch (valid)', max=93.0, style=ProgressStyle(desc…


valid (7/10) loss: 0.020212687034121022 | loss/mean: 0.020212687034121022 | loss/std: 0.007834591485356035 | lr: 0.00034549150281252633 | momentum: 0.9
