Feat/integration-pytorch-lightning #480

Merged: 5 commits merged into main from feat/integration-lightning on Apr 28, 2024

Conversation

@Zeyi-Lin Zeyi-Lin (Member) commented on Apr 18, 2024

Description

Integration with the PyTorch Lightning library. Test code:

from swanlab.integration.pytorch_lightning import SwanLabLogger

import importlib.util
import os

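# prefer the unified "lightning" package; fall back to the standalone "pytorch_lightning" distribution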
if importlib.util.find_spec("lightning"):
    import lightning.pytorch as pl
elif importlib.util.find_spec("pytorch_lightning"):  # noqa F401
    import pytorch_lightning as pl
else:
    raise RuntimeError(
        "This contrib module requires PyTorch Lightning to be installed. "
        "Please install it with command: \n pip install pytorch-lightning \n"
        "or \n pip install lightning"
    )
from torch import nn, optim, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))


# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # self.log routes the metric to the Trainer's logger (SwanLabLogger here)
        self.log("train_loss", loss)
        return loss
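
    def validation_step(self, batch, batch_idx):
        # added so the val_loader passed to trainer.fit() drives a validation
        # loop; without a validation_step, Lightning runs no validation at all
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("val_loss", loss)
        return loss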

    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # as in training_step, self.log routes the metric to the Trainer's logger
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)


# setup data
dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])
test_dataset = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())

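# note: batch_size is left at the DataLoader default of 1, which is slow but fine for a smoke test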
train_loader = utils.data.DataLoader(train_dataset)
val_loader = utils.data.DataLoader(val_dataset)
test_loader = utils.data.DataLoader(test_dataset)

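# configure the SwanLab logger; cloud=False keeps the run local
# (no sync to the SwanLab cloud)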
swanlab_logger = SwanLabLogger(
    project="swanlab_example",
    experiment_name="example_experiment",
    cloud=False,
)

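# limit_train_batches=100 caps each epoch at 100 batches so the test finishes quickly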
trainer = pl.Trainer(limit_train_batches=100, max_epochs=5, logger=swanlab_logger)


trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(dataloaders=test_loader)
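
With cloud=False, the logged metrics stay on the local machine; the run can then be inspected with SwanLab's local dashboard (assuming the usual SwanLab local-mode workflow, which is outside the scope of this test script).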

Closes: #478 ([REQUEST] Integrate PyTorch Lightning)

@Zeyi-Lin Zeyi-Lin self-assigned this Apr 18, 2024
@Zeyi-Lin Zeyi-Lin added this to the Integration milestone Apr 18, 2024
@Zeyi-Lin Zeyi-Lin added the 💪 enhancement New feature or request label Apr 18, 2024
@SAKURA-CAT SAKURA-CAT merged commit 1584f95 into main on Apr 28, 2024
@SAKURA-CAT SAKURA-CAT deleted the feat/integration-lightning branch on Apr 28, 2024 at 06:26