@mauvilsa I think we have a different issue than what #21246 fixes, where we want to load the state dict in a different LightningModule. E.g. in the following, `lr` will be saved as a hyperparameter in the checkpoint from `TrainingModule`, but `InferenceModule` does not take it as an argument:
```python
import torch
from lightning.pytorch import LightningModule


class TrainingModule(LightningModule):
    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()
        self.model = torch.nn.Linear(16, 2)
        self.lr = lr

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        dataset = torch.utils.data.TensorDataset(
            torch.rand(32, 16), torch.randint(0, 2, (32,))
        )
        return torch.utils.data.DataLoader(dataset, batch_size=8)


class InferenceModule(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Linear(16, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def predict_dataloader(self) -> torch.utils.data.DataLoader:
        dataset = torch.utils.data.TensorDataset(torch.rand(32, 16))
        return torch.utils.data.DataLoader(dataset, batch_size=128)

    def predict_step(
        self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = 0
    ) -> torch.Tensor:
        x = batch
        y_hat = self(x)
        return y_hat
```
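For context, once `lr` ends up in the checkpoint's `hyper_parameters` (e.g. when `TrainingModule` is run through `LightningCLI` or calls `self.save_hyperparameters()`), `InferenceModule.load_from_checkpoint` typically cannot be used directly, because the stored hyperparameters are forwarded to `InferenceModule.__init__` as keyword arguments. A minimal sketch of a manual workaround, assuming a hypothetical checkpoint path:

```python
import torch

# Hypothetical checkpoint path for illustration.
ckpt_path = "checkpoints/training.ckpt"

# InferenceModule.load_from_checkpoint(ckpt_path) would pass the stored
# hyperparameters (here, lr) to InferenceModule.__init__, which does not
# accept them, typically raising a TypeError.

# Loading the raw state dict bypasses the hyperparameters entirely:
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
module = InferenceModule()
module.load_state_dict(checkpoint["state_dict"])
module.eval()
```

This works here because both modules register the layer under the same attribute name (`self.model`), so the state dict keys line up; with different attribute names the keys would have to be remapped before calling `load_state_dict`.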
Originally posted by @ziw-liu in #21116 (comment)