In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision.utils import save_image
from tqdm import tqdm
from torch import optim
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DeepSpeedStrategy

torch.set_float32_matmul_precision("medium")

In [2]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size, num_workers):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage):
        mnist_full = train_dataset = datasets.MNIST(
            root="dataset/", train=True, transform=transforms.ToTensor(), download=True
        )
        self.mnist_test = datasets.MNIST(
            root="dataset/", train=False, transform=transforms.ToTensor(), download=True
        )
        self.mnist_train, self.mnist_val = torch.utils.data.random_split(
            mnist_full, [55000, 5000]
        )

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=True,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.mnist_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=True,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=True,
            shuffle=False,
        )

module = MNISTDataModule(32, 20)
module.setup("fit")
print(len(module.mnist_train))
print(len(module.mnist_val))
print(len(module.mnist_test))

55000
5000
10000


In [3]:
class VAEpl(pl.LightningModule):
    def __init__(self, lr, input_dim=784, h_dim=200, z_dim=20):
        super().__init__()
        self.lr = lr
        self.loss_fn = nn.BCELoss(reduction="sum")
        self.input_dim = input_dim

        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma * epsilon
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, sigma

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(-1, self.input_dim)
        x_reconstructed, mu, sigma = self.forward(x)
        reconstruction_loss = self.loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
        loss = reconstruction_loss + kl_div
        self.log("train_loss", loss, sync_dist=True)
        
        if batch_idx % 100 == 0:
            x = x[:8]
            x_reconstructed = x_reconstructed[:8]
            grid = torchvision.utils.make_grid(x_reconstructed.view(-1, 1, 28, 28))
            self.logger.experiment.add_image("reconstructed", grid, self.global_step)
            grid = torchvision.utils.make_grid(x.view(-1, 1, 28, 28))
            self.logger.experiment.add_image("original", grid, self.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(-1, self.input_dim)
        x_reconstructed, mu, sigma = self.forward(x)
        reconstruction_loss = self.loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
        loss = reconstruction_loss + kl_div
        self.log("val_loss", loss, sync_dist=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(-1, self.input_dim)
        x_reconstructed, mu, sigma = self.forward(x)
        reconstruction_loss = self.loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
        loss = reconstruction_loss + kl_div
        self.log("test_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

batch_size = 8
x = torch.randn(batch_size, 28 * 28 * 1)
vae_pl = VAEpl(3e-4)
x_reconstructed, mu, sigma = vae_pl(x)
print(x_reconstructed.shape)

torch.Size([8, 784])


In [4]:
def inference(model, dataset, digit, num_examples=1):
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 784))
        encodings_digit.append((mu, sigma))

    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"generated_{digit}_ex{example}.png")

In [5]:
lr = 3e-4
batch_size = 32
num_workers = 20
model = VAEpl(lr)
module = MNISTDataModule(batch_size, num_workers)
logger = TensorBoardLogger("my_checkpoint", name="scheduler_autolr_vae_pl_model")

callbacks = [
             pl.callbacks.LearningRateMonitor(logging_interval="step"),
             pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min", save_last=True),
            ]

trainer = pl.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=1,
    logger=logger,
    precision=16,
    strategy='ddp_notebook'
)

trainer.fit(model, module) 

/home/maxim/anaconda3/lib/python3.11/site-packages/lightning_fabric/connector.py:565: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

/home/maxim/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [6]:
trainer.test(model, module)

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
[W socket.cpp:436] [c10d] The server socket has failed to bind to [::]:46465 (errno: 98 - Address already in use).
[W socket.cpp:436] [c10d] The server socket has failed to bind to 0.0.0.0:46465 (errno: 98 - Address already in use).
[E socket.cpp:472] [c10d] The server socket has failed to listen on any local network address.


ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/torch/multiprocessing/spawn.py", line 74, in _wrap
    fn(i, *args)
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 170, in _wrapping_function
    results = function(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 795, in _test_impl
    results = self._run(model, ckpt_path=ckpt_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 948, in _run
    self.strategy.setup_environment()
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/pytorch_lightning/strategies/ddp.py", line 146, in setup_environment
    self.setup_distributed()
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/pytorch_lightning/strategies/ddp.py", line 197, in setup_distributed
    _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/lightning_fabric/utilities/distributed.py", line 290, in _init_dist_connection
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 74, in wrapper
    func_return = func(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 1141, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/torch/distributed/rendezvous.py", line 241, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/torch/distributed/rendezvous.py", line 172, in _create_c10d_store
    return TCPStore(
           ^^^^^^^^^
RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:46465 (errno: 98 - Address already in use). The server socket has failed to bind to 0.0.0.0:46465 (errno: 98 - Address already in use).
