### Weird Behaviour of PyTorch Lightning Module

I based this on the Transformer translation example provided by PyTorch Lightning. The LightningModule only trains when I follow the convention used in that example, which is to pass an already-initialized Seq2Seq model into the LightningModule through its constructor.
When I instead initialize the Seq2Seq model inside the constructor, the model does not update at all during training. The loss and metrics stay flat.

This notebook is an attempt at a minimal reproducible example using the latest versions of PyTorch (2.0.1) and PyTorch Lightning (2.0.4) at the time of writing.
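PyTorch itself should treat the two construction styles the same way. When any `torch.nn.Module` is assigned to an attribute inside `__init__`, it is registered as a submodule, no matter where the instance was created. `LightningModule` subclasses `torch.nn.Module`, so the same rule applies there.

The following is a minimal plain-PyTorch sketch of the two patterns. The wrapper names `PassedIn` and `BuiltInside` are hypothetical, and a `torch.nn.Linear` stands in for the Seq2Seq model. It checks that both wrappers expose the same parameters:

```python
import torch


class PassedIn(torch.nn.Module):
    """Pattern from the example: the model is created outside and passed in."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model  # attribute assignment registers the submodule


class BuiltInside(torch.nn.Module):
    """Alternative pattern: the model is created inside the constructor."""

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(1, 1)  # registered in exactly the same way


passed_in = PassedIn(torch.nn.Linear(1, 1))
built_inside = BuiltInside()

# Both wrappers report the same registered parameter names.
print([name for name, _ in passed_in.named_parameters()])     # ['model.weight', 'model.bias']
print([name for name, _ in built_inside.named_parameters()])  # ['model.weight', 'model.bias']
```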

In [1]:
import torch
import pytorch_lightning as pl

In [2]:
class PytorchRegressorModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)
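`torch.nn.Linear(1, 1)` holds a 1x1 weight and a 1-element bias. That accounts for the "2 Trainable params" in both model summaries printed by `trainer.fit`. A quick check on the same layer:

```python
import torch

# Same layer as PytorchRegressorModule.linear: one input feature, one output.
linear = torch.nn.Linear(1, 1)

for name, param in linear.named_parameters():
    print(name, tuple(param.shape))
# weight (1, 1)
# bias (1,)

n_trainable = sum(p.numel() for p in linear.parameters() if p.requires_grad)
print(n_trainable)  # 2
```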

In [3]:
class DivideBy2Dataset(torch.utils.data.Dataset):

    def __init__(self, x):
        self.x = x

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.x[idx] / 2

In [25]:
class LightningRegressionTask(pl.LightningModule):

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-2)
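A Lightning-free baseline helps separate "the model cannot fit this data" from "Lightning is not updating it". The loop below uses the same loss (`MSELoss`), optimizer (`Adam`) and learning rate (`1e-2`) as `training_step` and `configure_optimizers` above.

This is a sketch under several assumptions:

- `TensorDataset` replaces `DivideBy2Dataset`.
- A bare `torch.nn.Linear(1, 1)` replaces `PytorchRegressorModule`.
- The parameters start from a fixed point (w = b = 1) rather than a random one.
- The 100 epochs are an arbitrary choice.

```python
import torch

# Same data as the notebook: 640 ones, target x / 2 = 0.5, batches of 32.
x = torch.ones((32 * 20, 1), dtype=torch.float32)
y = x / 2
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y), batch_size=32
)

# Stand-in for PytorchRegressorModule, started from a deliberately bad point:
# w = 1 and b = 1, so the initial prediction is 2.0 instead of 0.5.
model = torch.nn.Linear(1, 1)
with torch.no_grad():
    model.weight.fill_(1.0)
    model.bias.fill_(1.0)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # as in configure_optimizers

losses = []
for epoch in range(100):
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)  # same computation as training_step
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

print(f"first loss {losses[0]:.4f}")  # first loss 2.2500, i.e. (2.0 - 0.5) ** 2
print(f"last loss {losses[-1]:.6f}")
print(model(torch.ones(1, 1)).item())  # should end up close to 0.5
```

If this loop converges but the Lightning run does not, the problem is more likely in how the module is wired into Lightning than in the model or data.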

In [26]:
x = torch.ones((32 * 20, 1), dtype=torch.float32, device='cpu')
dataset = DivideBy2Dataset(x)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
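With 32 * 20 = 640 samples and `batch_size=32`, one epoch is 20 batches. Every input is 1.0, so every target is 0.5. The sketch below is a sanity check on that setup; it copies `DivideBy2Dataset` so it runs on its own:

```python
import torch


class DivideBy2Dataset(torch.utils.data.Dataset):
    # Copy of the notebook's dataset so this snippet is self-contained.
    def __init__(self, x):
        self.x = x

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.x[idx] / 2


x = torch.ones((32 * 20, 1), dtype=torch.float32)
dataloader = torch.utils.data.DataLoader(DivideBy2Dataset(x), batch_size=32)

print(len(dataloader))  # 20 batches per epoch
xb, yb = next(iter(dataloader))
print(xb.shape, yb.shape)  # torch.Size([32, 1]) torch.Size([32, 1])
print(torch.all(yb == 0.5).item())  # True
```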

In [27]:
pytorch_model = PytorchRegressorModule()
lightning_model = LightningRegressionTask(pytorch_model)

In [28]:
trainer = pl.Trainer(accelerator='cpu')

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [29]:
trainer.fit(lightning_model, dataloader)


  | Name    | Type                   | Params
---------------------------------------------------
0 | model   | PytorchRegressorModule | 2     
1 | loss_fn | MSELoss                | 0     
---------------------------------------------------
2         Trainable params
0         Non-trainable params
2         Total params
0.000     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [30]:
list(lightning_model.model.parameters())

[Parameter containing:
 tensor([[0.1580]], requires_grad=True),
 Parameter containing:
 tensor([0.3420], requires_grad=True)]
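These values indicate that this first run did learn. Every input is 1, so the prediction is w * 1 + b = 0.1580 + 0.3420 = 0.5000, which is exactly the target x / 2.

Because the inputs never vary, the data only pins down the sum w + b. How that sum is split between weight and bias depends on the initialization. Rebuilding the layer from the printed (rounded) values confirms the prediction:

```python
import torch

# Rebuild the layer with the printed values (rounded to 4 decimals).
linear = torch.nn.Linear(1, 1)
with torch.no_grad():
    linear.weight.fill_(0.1580)
    linear.bias.fill_(0.3420)

prediction = linear(torch.ones(1, 1)).item()
print(round(prediction, 4))  # 0.5
```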

In [16]:
class LightningRegressionTaskAlt(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = PytorchRegressorModule()
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)
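The two classes differ in one more place besides where the model is built: `configure_optimizers`. `LightningRegressionTask` passes `self.model.parameters()` to Adam, while `LightningRegressionTaskAlt` passes `self.parameters()`. `MSELoss` has no parameters, so both iterators should yield the same two tensors.

The sketch below checks this with a plain `torch.nn.Module` stand-in. The class name `Task` is hypothetical and only mirrors the `__init__` of `LightningRegressionTaskAlt`:

```python
import torch


class Task(torch.nn.Module):
    # Plain-PyTorch stand-in for LightningRegressionTaskAlt.__init__.
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(1, 1)
        self.loss_fn = torch.nn.MSELoss()  # contributes no parameters


task = Task()
outer = list(task.parameters())        # what LightningRegressionTaskAlt optimizes
inner = list(task.model.parameters())  # what LightningRegressionTask optimizes

# Same tensor objects, in the same order.
print(len(outer), len(inner))  # 2 2
print(all(a is b for a, b in zip(outer, inner)))  # True
```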

In [17]:
lightning_model_alt = LightningRegressionTaskAlt()

In [18]:
trainer = pl.Trainer(accelerator='cpu')

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [19]:
trainer.fit(lightning_model_alt, dataloader)


  | Name    | Type                   | Params
---------------------------------------------------
0 | model   | PytorchRegressorModule | 2     
1 | loss_fn | MSELoss                | 0     
---------------------------------------------------
2         Trainable params
0         Non-trainable params
2         Total params
0.000     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
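One direct way to test the "does not update at all" claim is to snapshot the parameters before training and compare them afterwards. Below is a minimal sketch with two hypothetical helpers, `snapshot_params` and `changed_params`. A single manual Adam step on a plain `torch.nn.Linear` stands in for `trainer.fit`:

```python
import torch


def snapshot_params(module: torch.nn.Module) -> dict:
    """Detached copies of every parameter, keyed by name."""
    return {name: p.detach().clone() for name, p in module.named_parameters()}


def changed_params(module: torch.nn.Module, before: dict) -> list:
    """Names of parameters whose values differ from the snapshot."""
    return [
        name
        for name, p in module.named_parameters()
        if not torch.equal(p.detach(), before[name])
    ]


# Stand-in for trainer.fit: one optimizer step from a fixed starting point
# (w = b = 1), so the gradient is guaranteed to be non-zero.
model = torch.nn.Linear(1, 1)
with torch.no_grad():
    model.weight.fill_(1.0)
    model.bias.fill_(1.0)
before = snapshot_params(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
x = torch.ones(32, 1)
loss = torch.nn.functional.mse_loss(model(x), x / 2)
loss.backward()
optimizer.step()

print(changed_params(model, before))  # ['weight', 'bias']
```

In the notebook, the same check would be:

1. `before = snapshot_params(lightning_model_alt)`
2. `trainer.fit(lightning_model_alt, dataloader)`
3. `changed_params(lightning_model_alt, before)`

An empty list from the last call would confirm that no parameter was updated.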
