In [1]:
import torch 
import pytorch_lightning as pl
from vq_vae import VQVAE
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from vq_loss import VQLoss
from torchvision.transforms import ToTensor

# Data configuration — tune here rather than in the DataLoader calls below.
DATA_DIR = 'data'
BATCH_SIZE = 32
NUM_WORKERS = 12

# Load MNIST; ToTensor converts each PIL image to a float tensor.
train_ds = MNIST(DATA_DIR, train=True, download=True, transform=ToTensor())
# NOTE(review): the official test split doubles as the validation set, and the
# checkpoint callback selects models on `val_loss` — consider holding out part
# of train_ds for validation so the test split stays untouched.
test_ds = MNIST(DATA_DIR, train=False, download=True, transform=ToTensor())

# Shuffle only the training data; validation order is fixed for comparable metrics.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [2]:
class VQVAE_trainer(pl.LightningModule):
    """LightningModule wrapping a ``VQVAE`` model and a ``VQLoss`` criterion.

    Training and validation share a single step (``_shared_step``). It runs the
    model, computes the loss and logs it as ``'train_loss'`` / ``'val_loss'``.

    Args:
        lr: Learning rate for the Adam optimizer (default ``1e-3``).
    """

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        # Store ``lr`` in ``self.hparams`` and in saved checkpoints so that
        # ``load_from_checkpoint`` can rebuild the module with the same settings.
        self.save_hyperparameters()
        # NOTE(review): the meaning of VQVAE's positional args (1, 36, 64) is not
        # visible here; presumably 1 is the input channel count for MNIST —
        # confirm against vq_vae.VQVAE.
        self.model = VQVAE(1, 36, 64)
        self.loss = VQLoss()

    def forward(self, x):
        """Run the wrapped VQ-VAE and return its output tuple unchanged."""
        return self.model(x)

    def _shared_step(self, batch, stage: str):
        """Compute and log the loss for one batch.

        Args:
            batch: ``(inputs, labels)`` pair; labels are ignored.
            stage: prefix for the logged metric name, e.g. ``'train'`` or ``'val'``.

        Returns:
            The first element of the ``VQLoss`` output, used as the scalar loss.
        """
        x, _ = batch
        x_hat, quantized, embedding_indices = self.model(x)
        # NOTE(review): the training cell's output ends in a RuntimeError raised
        # from F.mse_loss (size 7 vs 28 at dim 3), i.e. a latent-sized tensor is
        # being compared with the 28x28 input. The argument order below probably
        # does not match VQLoss's signature — verify against vq_loss.VQLoss.
        loss = self.loss(quantized, embedding_indices, x, x_hat)[0]
        self.log(f'{stage}_loss', loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook; logs ``'train_loss'``."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; logs ``'val_loss'`` (monitored by checkpointing)."""
        return self._shared_step(batch, 'val')

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [6]:
# Smoke test: push one training batch through a fresh model and check that the
# reconstruction has the input's shape before launching a full training run.
# NOTE(review): this only exercises forward(), not the loss — the loss call is
# where the training cell below fails.
model = VQVAE_trainer()

sample, _ = next(iter(train_loader))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sample = sample.to(device)
model = model.to(device)

# Inference only: no_grad avoids building an autograd graph for this check.
with torch.no_grad():
    res = model(sample)

# res[0] is the reconstruction (the value the training/validation steps unpack as x_hat).
assert res[0].shape == sample.shape, (
    f'reconstruction shape {tuple(res[0].shape)} != input shape {tuple(sample.shape)}'
)
res[0].shape



torch.Size([32, 1, 28, 28])


In [3]:
# Keep the 3 checkpoints with the lowest validation loss.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints',
    filename='vq_vae-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

MAX_EPOCHS = 50
# Match the smoke-test cell's device logic: use GPU 0 when CUDA is available,
# otherwise fall back to CPU instead of failing at Trainer construction.
use_gpu = torch.cuda.is_available()

model = VQVAE_trainer()

trainer = pl.Trainer(
    accelerator='gpu' if use_gpu else 'cpu',
    devices=[0] if use_gpu else 1,
    max_epochs=MAX_EPOCHS,
    enable_progress_bar=True,
    # The callback only runs if registered with the Trainer.
    callbacks=[checkpoint_callback],
)
trainer.fit(model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | VQVAE  | 9.4 M 
1 | loss  | VQLoss | 0     
---------------------------------
9.4 M     Trainable params
0         Non-trainable params
9.4 M     Total params
37.775    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,
  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (7) must match the size of tensor b (28) at non-singleton dimension 3