Import packages

In [4]:
import lightning.pytorch as pl 
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import torch.nn.functional as F
from torch import optim
from models.deq import FullModel,FullModel2
from helpers import SwissGridDataModule


pl.seed_everything(42, workers=True)

Global seed set to 42


42

In [5]:
class Model(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x_hat = self.forward(x)
        loss = F.mse_loss(x_hat, y)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss
    
    def forward(self,x):
        return self.model(x)
    
    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        x_hat = self.forward(x)
        val_loss = F.mse_loss(x_hat, y)
        self.log("val_loss", val_loss)
        
    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        x_hat = self.forward(x)
        test_loss = F.mse_loss(x_hat, y)
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    

In [6]:
model=Model(FullModel(3,20))
tb_logger = pl_loggers.TensorBoardLogger(save_dir="./")
trainer = pl.Trainer(log_every_n_steps=10,max_epochs=100,logger=tb_logger,callbacks=[EarlyStopping(monitor='val_loss')],deterministic=True)
data_module=SwissGridDataModule()
trainer.fit(model, datamodule=data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name  | Type      | Params
------------------------------------
0 | model | FullModel | 500   
------------------------------------
500       Trainable params
0         Non-trainable params
500       Total params
0.002     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 1:  32%|███▏      | 9/28 [00:02<00:06,  3.13it/s, v_num=39] 

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [10]:
trainer.test(ckpt_path="best",datamodule=data_module)

Restoring states from the checkpoint path at ./lightning_logs/version_8/checkpoints/epoch=75-step=2128.ckpt
Loaded model weights from the checkpoint at ./lightning_logs/version_8/checkpoints/epoch=75-step=2128.ckpt
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 9/9 [00:01<00:00,  5.70it/s]


[{'test_loss': 0.39482155442237854}]

In [12]:
model2=Model(FullModel2(3,20))
tb_logger = pl_loggers.TensorBoardLogger(save_dir="./")
trainer = pl.Trainer(log_every_n_steps=10,max_epochs=100,logger=tb_logger,callbacks=[EarlyStopping(monitor='val_loss')],)
data_module=SwissGridDataModule()
trainer.fit(model2, datamodule=data_module)
trainer.test(ckpt_path="best",datamodule=data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | FullModel2 | 500   
-------------------------------------
500       Trainable params
0         Non-trainable params
500       Total params
0.002     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 99: 100%|██████████| 28/28 [00:06<00:00,  4.33it/s, v_num=9]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 28/28 [00:06<00:00,  4.33it/s, v_num=9]


Restoring states from the checkpoint path at ./lightning_logs/version_9/checkpoints/epoch=99-step=2800.ckpt
Loaded model weights from the checkpoint at ./lightning_logs/version_9/checkpoints/epoch=99-step=2800.ckpt
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 9/9 [00:01<00:00,  6.97it/s] 


[{'test_loss': 0.5923712253570557}]