In [1]:
import lightning as L
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader
from lightning import Trainer
from icecream import ic


cuda = torch.cuda.is_available()
if cuda:
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [2]:
train_x = torch.load('data/train_x.pt').to(torch.float32)
train_y = torch.load('data/train_y.pt').to(torch.float32)
test_x  =  torch.load('data/test_x.pt').to(torch.float32)
test_y  =  torch.load('data/test_y.pt').to(torch.float32)

train_dataset = TensorDataset(train_x, train_y)
test_dataset = TensorDataset(test_x, test_y)

train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=True, num_workers=0, pin_memory=True)

train_x.shape, test_x.shape, train_y.shape, test_y.shape

(torch.Size([100000, 100]),
 torch.Size([10000, 100]),
 torch.Size([100000]),
 torch.Size([10000]))

In [14]:
class MyModel(L.LightningModule):
    def __init__(self, partition_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(partition_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.loss_fn = nn.MSELoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=5e-3)


In [15]:
model = MyModel(train_x.shape[-1])
trainer = Trainer(max_epochs=100, accelerator="auto", devices=1)
trainer.test(model, dataloaders=test_loader)
# trainer.fit(model, train_dataloaders=train_loader)
trainer.fit(model, train_dataloaders=test_loader)


You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 116.86it/s]


  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | model   | Sequential | 6.5 K  | train
1 | loss_fn | MSELoss    | 0      | train
-----------------------------------------------
6.5 K     Trainable params
0         Non-trainable params
6.5 K     Total params
0.026     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss            8.14783000946045
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Epoch 99: 100%|██████████| 10/10 [00:00<00:00, 28.27it/s, v_num=23, train_loss=1.050]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 10/10 [00:00<00:00, 27.94it/s, v_num=23, train_loss=1.050]


In [16]:
trainer.test(model, dataloaders=test_loader)

Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 108.23it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss            1.050553560256958
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.050553560256958}]