In [1]:
import torch
import pytorch_lightning as pl
from model import rnn
import os
from dataset import MnistDataModule
import default_config as config
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.profilers import PyTorchProfiler

# Let PyTorch trade some float32 matmul precision for speed ("medium" setting).
torch.set_float32_matmul_precision("medium")

# NOTE(review): torch.get_num_threads() reports the size of PyTorch's intra-op
# thread pool, which is not necessarily the number of CPUs on the machine.
num_threads = torch.get_num_threads()
print("Number of CPUs available: ", num_threads)

Number of CPUs available:  12


In [2]:
# Extra RNN cell options, handed to the model through its `kwargs_dict` argument.
# To pass no extra options, use `kwargs_dict = {}` instead.
kwargs_dict = dict(
    Wr_identity=False,
    learn_tau=True,
    dt_tau_max_y=0.03,
    dt_tau_max_a=0.01,
    dt_tau_max_b=0.1,
)

In [3]:
# Overrides of the default config for pixel-by-pixel sequential MNIST.

# Input: one value per time step over 784 steps
# (presumably the 28 x 28 pixels of an unresized image; confirm in MnistDataModule).
config.INPUT_SIZE = 1
config.SEQUENCE_LENGTH = 784
config.RESIZE = 1.0

# Model size and optimiser step.
config.HIDDEN_SIZE = 256
config.LEARNING_RATE = 0.01

# Hardware. To run on CPU instead, replace the assignment below with:
# config.ACCELERATOR = "cpu"
# config.DEVICES = 1
config.ACCELERATOR = "gpu"

In [6]:
# --- Run identity ---------------------------------------------------------
# When True, the training cell resumes from a saved checkpoint instead of
# starting a fresh fit.
start_from_checkpoint = False
folder_name = "tb_logs_pixel"
model_name = "sMNIST_pixel_by_pixel_test"  # alternative run name: "sMNIST_normal"

# TensorBoard logs are written under <folder_name>/<model_name>/version_<k>/.
logger = TensorBoardLogger(folder_name, name=model_name)

# Optional profiler: uncomment and pass `profiler=profiler` to the Trainer below.
# profiler = PyTorchProfiler(
#     on_trace_ready=torch.profiler.tensorboard_trace_handler(f"{folder_name}/profiler0"),
#     schedule=torch.profiler.schedule(skip_first=10, wait=1, warmup=1, active=20),
# )

# --- Data -----------------------------------------------------------------
dm = MnistDataModule(
    data_dir=config.DATA_DIR,
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
    resize=config.RESIZE,
    permuted=config.PERMUTED,
)

# --- Model ----------------------------------------------------------------
# NOTE(review): no random seed is set in the visible notebook cells, so results
# may differ between runs; consider seeding before the model is built.
model_hparams = {
    "input_size": config.INPUT_SIZE,
    "hidden_size": config.HIDDEN_SIZE,
    "seq_length": config.SEQUENCE_LENGTH,
    "num_classes": config.NUM_CLASSES,
    "learning_rate": config.LEARNING_RATE,
    "scheduler_change_step": config.SCHEDULER_CHANGE_STEP,
    "scheduler_gamma": config.SCHEDULER_GAMMA,
    "kwargs_dict": kwargs_dict,
}
model = rnn(**model_hparams)

# --- Trainer --------------------------------------------------------------
trainer_callbacks = [
    # Log the learning rate once per epoch.
    LearningRateMonitor(logging_interval='epoch'),
    # Save a checkpoint every epoch and keep all of them (save_top_k=-1).
    ModelCheckpoint(save_top_k=-1, every_n_epochs=1),
]
trainer = pl.Trainer(
    accelerator=config.ACCELERATOR,
    devices=config.DEVICES,
    precision=config.PRECISION,
    logger=logger,
    callbacks=trainer_callbacks,
    profiler=None,
    min_epochs=1,
    max_epochs=2,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
def _checkpoint_epoch(filename):
    """Return the epoch number encoded in a checkpoint filename, or None.

    Expects names containing ``epoch=<int>-``, e.g. ``epoch=1-step=1876.ckpt``.
    Names without a parsable epoch (such as ``last.ckpt`` or unrelated files)
    yield None so that callers can skip them instead of crashing.
    """
    if "epoch=" not in filename:
        return None
    try:
        return int(filename.split('epoch=')[1].split('-')[0])
    except ValueError:
        return None


if start_from_checkpoint:
    # Version directory of the earlier run to resume from.
    version = 1
    checkpoint_folder = f'{folder_name}/{model_name}/version_{version}/checkpoints/'

    # Map each usable checkpoint file to its epoch, keeping os.listdir order so
    # that ties resolve to the first file listed.
    epoch_by_file = {}
    for file in os.listdir(checkpoint_folder):
        epoch = _checkpoint_epoch(file)
        if epoch is not None:
            epoch_by_file[file] = epoch

    if not epoch_by_file:
        raise FileNotFoundError(
            f"No 'epoch=<n>-...' checkpoint files found in {checkpoint_folder}"
        )

    # Resume from the checkpoint with the highest epoch number.
    latest_file = max(epoch_by_file, key=epoch_by_file.get)
    checkpoint_path = os.path.join(checkpoint_folder, latest_file)
    trainer.fit(model, dm, ckpt_path=checkpoint_path)
else:
    trainer.fit(model, dm)
# Optional follow-ups:
# trainer.validate(model, dm)
# trainer.test(model, dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params
------------------------------------------------
0 | org      | rnnCell            | 395 K 
1 | fc       | Linear             | 2.6 K 
2 | loss_fn  | CrossEntropyLoss   | 0     
3 | accuracy | MulticlassAccuracy | 0     
4 | f1_score | MulticlassF1Score  | 0     
------------------------------------------------
397 K     Trainable params
512       Non-trainable params
397 K     Total params
1.591     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [8]:
# Run the test loop on the fitted model; the returned list of metric dicts is
# the cell's displayed value.
trainer.test(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.7231000065803528
         test_f1            0.7231000065803528
        test_loss           0.8231350779533386
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.7231000065803528,
  'test_f1': 0.7231000065803528,
  'test_loss': 0.8231350779533386}]