In [42]:
import  os 
from typing import Optional
import torch
from pytorch_lightning import LightningDataModule,LightningModule,cli_lightning_logo
from pytorch_lightning.cli import LightningCLI
from torchvision.datasets import MNIST
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
import  pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader,random_split

In [3]:
if _TORCHVISION_AVAILABLE:
    from torchvision import transforms

In [5]:
class Backbone(torch.nn.Module):
    def __init__(self,hidden_dim=128):
        super().__init__()
        self.l1=torch.nn.Linear(28*28,hidden_dim)
        self.l2=torch.nn.Linear(hidden_dim,10)
    def forward(self,x):
        x=x.view(x.size(0),-1)
        x=torch.relu(self.l1(x))
        return torch.relu(self.l2(x))
    

In [29]:
from typing import Any


class LitClassifier(LightningModule):
    def __init__(self,backbone:Optional[Backbone]=None,learning_rate:float=1e-4):
        super().__init__()
        self.save_hyperparameters(ignore=["backbone"])
        if backbone is None:
            backbone=Backbone()
        self.backbone=backbone
    def forward(self,x):
        return self.backbone(x)
    def training_step(self,batch,batch_idx):
        x,y=batch
        y_hat=self(x)
        loss=F.cross_entropy(y_hat,y)
        self.log("train_loss",loss,on_epoch=True,prog_bar=True)
        return loss
    def validation_step(self, batch,batch_idx):
        x,y=batch
        y_hat=self(x)
        loss=F.cross_entropy(y_hat,y)
        self.log("val_loss",loss,on_epoch=True,prog_bar=True)
        return loss
    def test_step(self, batch,batch_idx):
        x,y=batch
        y_hat=self(x)
        loss=F.cross_entropy(y_hat,y)
        self.log("test_loss",loss,on_epoch=True,prog_bar=True)
    def predict_step(self,batch,batch_idx,dataloader_idx=None):
        x,_=batch
        return self(x)
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(),lr=self.hparams.learning_rate)
model=LitClassifier()

In [48]:
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS


class MyDataModule(LightningDataModule):
    def __init__(self,batch_size:int=32):
        super().__init__()
        dataset=MNIST(os.getcwd(),train=True,download=False,transform=transforms.ToTensor())
        self.mnist_test=MNIST(os.getcwd(),train=False,download=False,transform=transforms.ToTensor())
        self.mnist_train,self.mnist_val=random_split(
            dataset,[55000,5000],generator=torch.Generator().manual_seed(42)
        )
        self.batch_size=batch_size
    def train_dataloader(self):
        return DataLoader(self.mnist_train,batch_size=self.batch_size)
    def val_dataloader(self):
        return DataLoader(self.mnist_val,batch_size=self.batch_size)
    def test_dataloader(self):
        return DataLoader(self.mnist_test,batch_size=self.batch_size)
    def predict_dataloader(self):
        return DataLoader(self.mnist_test,batch_size=self.batch_size)
mydata=MyDataModule(128)

In [47]:
trainer=pl.Trainer(max_epochs=5)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [49]:
trainer.fit(model=model,train_dataloaders=mydata.train_dataloader())

/root/miniconda3/envs/dsi_exp/lib/python3.11/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name     | Type     | Params
--------------------------------------
0 | backbone | Backbone | 101 K 
--------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
/root/miniconda3/envs/dsi_exp/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Epoch 4: 100%|██████████| 430/430 [00:05<00:00, 79.55it/s, v_num=11, train_loss_step=0.270, train_loss_epoch=0.246] 

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 430/430 [00:05<00:00, 79.46it/s, v_num=11, train_loss_step=0.270, train_loss_epoch=0.246]
