# Base NN Architectures

In [1]:
import os

import torch
from torch import nn

from scaling.models.model_factory import MODELS
from scaling.train import train

[rank: 0] Seed set to 42


In [2]:
def count_params(model):
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

@torch.no_grad()
def check_model(model, channels=12, seq_len=1024):
    """Smoke-test ``model`` with one forward pass on random data.

    The pass runs in eval mode so the check neither alters the model
    (e.g. BatchNorm running statistics) nor is affected by dropout. The
    train/eval mode the model had on entry is restored afterwards. The random
    input is created on the device of the model's first parameter (CPU if the
    model has no parameters).

    Args:
        model: ``torch.nn.Module`` called with a float tensor of shape
            ``(1, channels, seq_len)``.
        channels: Size of the second input dimension.
        seq_len: Size of the last input dimension.

    Returns:
        None. Any exception raised by the forward pass propagates.
    """
    was_training = model.training
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    try:
        _ = model(torch.rand(1, channels, seq_len, device=device))
    finally:
        # Restore the caller's mode even if the forward pass fails.
        model.train(was_training)

In [3]:
# Print a tab-separated table of trainable-parameter counts, one row per
# entry in MODELS, and smoke-test each model right after its row is printed.
print("MODEL NAME\t|\tPARAMETERS", "-" * 35, sep="\n")
for model_name, model_fn in MODELS.items():
    model = model_fn()
    params = count_params(model)
    print(f"{model_name}\t|\t{params:,}")
    check_model(model)

MODEL NAME	|	PARAMETERS
-----------------------------------
convnext_mini	|	13,383,098
convnext_tiny	|	26,787,770
convnext_small	|	48,132,026
convnext_base	|	85,458,842
convnext_large	|	192,036,698
resnet18	|	3,862,170
resnet50	|	16,012,442
resnet101	|	28,319,898
resnext18	|	12,867,482
resnext50	|	22,086,042
resnext101	|	79,676,826
vit_tiny	|	6,512,666
vit_small	|	25,616,922
vit_base	|	85,641,242
vit_large	|	303,054,362
getemed_small	|	952,770
getemed_base	|	3,160,770
getemed_large	|	13,558,210


# Dry run all models

In [4]:
# Metadata CSV passed to `train` as `meta_file_path`. The default is the
# original cluster location; set the ECG_METADATA_CSV environment variable to
# point the notebook at a different copy without editing this cell.
path = os.environ.get(
    "ECG_METADATA_CSV",
    "/sc-scratch/sc-scratch-gbm-radiomics/ecg/physionet_challenge/training_pt/metadata_v4.csv",
)
# Loss module passed to `train` as `loss_fn`.
loss_fn = nn.BCEWithLogitsLoss()

In [None]:
# Dry-run training once per entry in MODELS. With fast_dev_run=True the
# captured output reports a single batch with logging and checkpointing
# suppressed, so this only checks that each model trains end to end.
for key in MODELS:
    train(
        project="test",
        name="test_run_1",
        meta_file_path=path,
        fold=0,
        model_name=key,
        loss_fn=loss_fn,
        lr_decay_gamma=0.95,
        fast_dev_run=True,
    )

/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.99it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 141.49it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.64it/s]                 [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.64it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | ConvNeXt          | 26.8 M | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
26.8 M    Trainable params
0         Non-trainable params
26.8 M    Total params
107.151   Total estimated model par

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  5.14it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 107.91it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.15it/s]                 [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.14it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | ConvNeXt          | 48.1 M | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
48.1 M    Trainable params
0         Non-trainable params
48.1 M    Total params
192.528   Total estimated model par

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.98it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 70.57it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  2.56it/s]                [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  2.55it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | ConvNeXt          | 85.5 M | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
85.5 M    Trainable params
0         Non-trainable params
85.5 M    Total params
341.835   Total estimated model par

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.81it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 73.32it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  2.38it/s]                [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  2.37it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | ConvNeXt          | 192 M  | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
192 M     Trainable params
0         Non-trainable params
192 M     Total params
768.147   Total estimated model par

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  2.52it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 77.28it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.77it/s]                [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.77it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | ResNet            | 3.9 M  | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
15.449    Total estimated model par

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  4.50it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 144.95it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.16it/s]                 [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.15it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | ResNet            | 16.0 M | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
16.0 M    Trainable params
0         Non-trainable params
16.0 M    Total params
64.050    Total estimated model par

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  4.37it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 111.55it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.08it/s]                 [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.08it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | ResNet            | 28.3 M | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
28.3 M    Trainable params
0         Non-trainable params
28.3 M    Total params
113.280   Total estimated model par

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  4.84it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 69.90it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]                [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | ResNet            | 12.9 M | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
12.9 M    Trainable params
0         Non-trainable params
12.9 M    Total params
51.470    Total estimated model par

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  5.17it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 135.21it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.36it/s]                 [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.35it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | ResNet            | 22.1 M | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
22.1 M    Trainable params
0         Non-trainable params
22.1 M    Total params
88.344    Total estimated model par

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.52it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 100.79it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  2.57it/s]                 [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  2.57it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | ResNet            | 79.7 M | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
79.7 M    Trainable params
0         Non-trainable params
79.7 M    Total params
318.707   Total estimated model par

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  3.86it/s]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 65.63it/s][A
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  2.64it/s]                [A

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  2.63it/s]


/home/jabareen/.conda/envs/ecg/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jabareen/.conda/envs/ecg/lib/python3.12/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
