In [None]:
import torch
import torch.nn as nn  
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torchvision.utils import save_image


In [None]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = datasets.ImageFolder(root="C:/Users/Arun/pytorch/datasxts/celebA",  transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)


In [None]:
from encoder import Encoder

# Define a simple args class to hold the parameters
class Args:
    def __init__(self, image_channels, latent_dim):
        self.image_channels = image_channels
        self.latent_dim = latent_dim

# Initialize the arguments
args = Args(image_channels=3, latent_dim=64)  # For example, using RGB images and a latent dimension of 64

# Instantiate the Encoder
encoder = Encoder(args)

# Print the model architecture (optional)
print(encoder)

# Create a random input tensor with the shape (batch_size, image_channels, height, width)
batch_size = 4
image_height = 256
image_width = 256
input_tensor = torch.randn(batch_size, args.image_channels, image_height, image_width)

# Forward pass through the encoder
output = encoder(input_tensor)

# Print the output shape
print("Output shape:", output.shape)


Encoder(
  (model): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ResBlock(
      (activation): SiLU()
      (blocks): Sequential(
        (0): GroupNorm(32, 128, eps=1e-06, affine=True)
        (1): SiLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): GroupNorm(32, 128, eps=1e-06, affine=True)
        (4): SiLU()
        (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (residual): Identity()
    )
    (2): ResBlock(
      (activation): SiLU()
      (blocks): Sequential(
        (0): GroupNorm(32, 128, eps=1e-06, affine=True)
        (1): SiLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): GroupNorm(32, 128, eps=1e-06, affine=True)
        (4): SiLU()
        (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (residual): Identity()
    )
    (3): DownsampleBlock(
     

In [None]:
from decoder import Decoder


# Define a simple args class to hold the parameters
class Args:
    def __init__(self, image_channels, latent_dim):
        self.image_channels = image_channels
        self.latent_dim = latent_dim

# Initialize the arguments
args = Args(image_channels=3, latent_dim=64)  # For example, using RGB images and a latent dimension of 64

# Instantiate the Decoder
decoder = Decoder(args)

# Print the model architecture (optional)
print(decoder)

# Create a random input tensor with the shape (batch_size, latent_dim, height, width)
batch_size = 4
latent_height = 16
latent_width = 16
input_tensor = torch.randn(batch_size, args.latent_dim, latent_height, latent_width)

# Forward pass through the decoder
output = decoder(input_tensor)

# Print the output shape
print("Output shape:", output.shape)


Decoder(
  (model): Sequential(
    (0): Conv2d(64, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ResBlock(
      (activation): SiLU()
      (blocks): Sequential(
        (0): GroupNorm(32, 512, eps=1e-06, affine=True)
        (1): SiLU()
        (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): GroupNorm(32, 512, eps=1e-06, affine=True)
        (4): SiLU()
        (5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (residual): Identity()
    )
    (2): NonLocalBlock(
      (gn): GroupNorm(32, 512, eps=1e-06, affine=True)
      (theta): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      (phi): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      (g): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      (output_conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    )
    (3): ResBlock(
      (activation): SiLU()
      (blocks): Sequential(
        (0): GroupNorm(32, 512, ep

In [None]:
from encoder import Encoder
from decoder import Decoder
from vqema import VectorQuantizerEMA
class Args:
    def __init__(self, image_channels, latent_dim):
        self.image_channels = image_channels
        self.latent_dim = latent_dim

# Initialize the arguments
args = Args(image_channels=3, latent_dim=64)
class VQVAE(nn.Module):
    def __init__(self, in_channels, num_embeddings, embedding_dim, commitment_cost, decay):
        super(VQVAE, self).__init__()
        self.encoder = Encoder(args=args)
        self.vq = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
        self.decoder = Decoder(args=args)

    def forward(self, x):
        z_e = self.encoder(x)
        vq_loss, z_q, _ = self.vq(z_e)
        x_recon = self.decoder(z_q)
        recon_loss = F.mse_loss(x_recon, x)
        loss = recon_loss + vq_loss
        return loss, x_recon

In [None]:
from torchsummary import summary

in_channels = 3
num_embeddings = 512
embedding_dim = 64
commitment_cost = 0.25
decay = 0.99
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize models

vqvae = VQVAE(in_channels, num_embeddings, embedding_dim, commitment_cost, decay).to(device)
en = Encoder(args=args).to(device)

summary(en, (3, 256, 256))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1        [-1, 128, 256, 256]           3,584
          Identity-2        [-1, 128, 256, 256]               0
         GroupNorm-3        [-1, 128, 256, 256]             256
              SiLU-4        [-1, 128, 256, 256]               0
              SiLU-5        [-1, 128, 256, 256]               0
            Conv2d-6        [-1, 128, 256, 256]         147,584
         GroupNorm-7        [-1, 128, 256, 256]             256
              SiLU-8        [-1, 128, 256, 256]               0
              SiLU-9        [-1, 128, 256, 256]               0
           Conv2d-10        [-1, 128, 256, 256]         147,584
         ResBlock-11        [-1, 128, 256, 256]               0
         Identity-12        [-1, 128, 256, 256]               0
        GroupNorm-13        [-1, 128, 256, 256]             256
             SiLU-14        [-1, 128, 2

In [None]:
print('Number of params in Model: {}'.format(
    sum(p.data.nelement() for p in vqvae.parameters() if p.requires_grad),
))

Number of params in Model: 62573379


In [8]:
from tqdm import tqdm
import os


# Hyperparameters
in_channels = 3
num_embeddings = 512
embedding_dim = 64
commitment_cost = 0.25
decay = 0.99
epochs = 20
learning_rate = 1e-3
batch_size = 64

model_dir = "checkpoints"
sample_dir = "samples"

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate model
vqvae = VQVAE(in_channels, num_embeddings, embedding_dim, commitment_cost, decay).to(device)

# Optimizer
optimizer = optim.Adam(vqvae.parameters(), lr=learning_rate)

# Ensure directories exist
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

def find_latest_checkpoint(checkpoints_dir):
    checkpoints = [f for f in os.listdir(checkpoints_dir) if f.endswith('.pth')]
    if not checkpoints:
        return None
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('_')[2].split('.')[0]))
    return os.path.join(checkpoints_dir, latest_checkpoint)

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    vqvae.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    return start_epoch

# Load the latest checkpoint if available
start_epoch = 0
latest_checkpoint = find_latest_checkpoint(model_dir)
if latest_checkpoint:
    start_epoch = load_checkpoint(latest_checkpoint)
    print(f'Resuming training from epoch {start_epoch}')
else:
    print(f'Training from Scratch')

# Training loop
for epoch in range(start_epoch, epochs):
    train_loss = 0
    with tqdm(enumerate(train_loader, start=1), total=len(train_loader)) as t:
        for batch_idx, (data, _) in t:
            print(batch_idx)
            data = data.to(device)
            
            optimizer.zero_grad()
            loss, _ = vqvae(data)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
            t.set_description(f'Epoch [{epoch+1}/{epochs}]')
            t.set_postfix({'Training loss': f'{train_loss/(batch_idx+1):.3f}', 'Batch loss': f'{loss:.3f}'})
            

    train_loss /= len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{epochs} Average Loss: {train_loss}')

    # Save model checkpoint
    if (epoch + 1) % 5 == 0:
        checkpoint_path = os.path.join(model_dir, f'vqvae_epoch_{epoch+1}.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': vqvae.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch+1}")

    # Save some reconstructed images for visualization
    vqvae.eval()
    with torch.no_grad():
        sample = next(iter(train_loader))[0].to(device)
        _, recon_sample = vqvae(sample)
        recon_sample = recon_sample.cpu()
        save_image(recon_sample, os.path.join(sample_dir, f'epoch_{epoch+1}_reconstructions.png'))
        save_image(sample.cpu(), os.path.join(sample_dir, f'epoch_{epoch+1}_originals.png'))


Training from Scratch


  0%|          | 0/25325 [00:00<?, ?it/s]

1


Epoch [1/20]:   0%|          | 1/25325 [00:16<114:59:43, 16.35s/it, Training loss=0.607, Batch loss=1.214]

2


Epoch [1/20]:   0%|          | 2/25325 [00:28<95:58:56, 13.65s/it, Training loss=1.394, Batch loss=2.970] 

3


Epoch [1/20]:   0%|          | 3/25325 [00:40<90:58:28, 12.93s/it, Training loss=1.580, Batch loss=2.137]

4


Epoch [1/20]:   0%|          | 4/25325 [00:52<87:53:44, 12.50s/it, Training loss=2.088, Batch loss=4.121]

5


Epoch [1/20]:   0%|          | 5/25325 [01:03<86:20:29, 12.28s/it, Training loss=2.067, Batch loss=1.962]

6


Epoch [1/20]:   0%|          | 6/25325 [01:15<85:50:33, 12.21s/it, Training loss=1.969, Batch loss=1.377]

7


Epoch [1/20]:   0%|          | 7/25325 [01:28<86:08:29, 12.25s/it, Training loss=2.004, Batch loss=2.254]

8


Epoch [1/20]:   0%|          | 8/25325 [01:40<86:51:09, 12.35s/it, Training loss=1.934, Batch loss=1.374]

9


Epoch [1/20]:   0%|          | 9/25325 [01:52<85:39:21, 12.18s/it, Training loss=1.798, Batch loss=0.572]

10


Epoch [1/20]:   0%|          | 10/25325 [02:04<85:03:03, 12.09s/it, Training loss=1.699, Batch loss=0.705]

11


Epoch [1/20]:   0%|          | 11/25325 [02:16<84:24:17, 12.00s/it, Training loss=1.781, Batch loss=2.685]

12


Epoch [1/20]:   0%|          | 12/25325 [02:28<83:56:01, 11.94s/it, Training loss=2.206, Batch loss=7.312]

13


Epoch [1/20]:   0%|          | 13/25325 [02:40<83:42:54, 11.91s/it, Training loss=2.623, Batch loss=8.036]

14


Epoch [1/20]:   0%|          | 14/25325 [02:51<83:31:09, 11.88s/it, Training loss=2.902, Batch loss=6.820]

15


Epoch [1/20]:   0%|          | 15/25325 [03:04<84:11:29, 11.98s/it, Training loss=2.953, Batch loss=3.719]

16


Epoch [1/20]:   0%|          | 16/25325 [03:21<96:38:43, 13.75s/it, Training loss=2.916, Batch loss=2.319]

17


Epoch [1/20]:   0%|          | 17/25325 [03:38<101:51:35, 14.49s/it, Training loss=2.847, Batch loss=1.668]

18


Epoch [1/20]:   0%|          | 18/25325 [03:54<105:24:47, 15.00s/it, Training loss=2.769, Batch loss=1.370]

19


Epoch [1/20]:   0%|          | 19/25325 [04:11<110:26:52, 15.71s/it, Training loss=2.698, Batch loss=1.338]

20


Epoch [1/20]:   0%|          | 20/25325 [04:28<112:28:15, 16.00s/it, Training loss=2.619, Batch loss=1.052]

21


Epoch [1/20]:   0%|          | 21/25325 [04:44<112:45:48, 16.04s/it, Training loss=2.538, Batch loss=0.829]

22


Epoch [1/20]:   0%|          | 22/25325 [04:59<111:34:04, 15.87s/it, Training loss=2.472, Batch loss=1.019]

23


Epoch [1/20]:   0%|          | 23/25325 [05:16<112:52:41, 16.06s/it, Training loss=2.410, Batch loss=0.989]

24


Epoch [1/20]:   0%|          | 24/25325 [05:32<112:24:13, 15.99s/it, Training loss=2.353, Batch loss=0.994]

25


Epoch [1/20]:   0%|          | 25/25325 [05:49<115:44:27, 16.47s/it, Training loss=2.314, Batch loss=1.343]

26


Epoch [1/20]:   0%|          | 26/25325 [06:05<114:43:49, 16.33s/it, Training loss=2.291, Batch loss=1.690]

27


Epoch [1/20]:   0%|          | 27/25325 [06:21<114:01:27, 16.23s/it, Training loss=2.267, Batch loss=1.612]

28


Epoch [1/20]:   0%|          | 28/25325 [06:38<115:51:00, 16.49s/it, Training loss=2.244, Batch loss=1.603]

29


Epoch [1/20]:   0%|          | 29/25325 [06:57<119:25:18, 17.00s/it, Training loss=2.217, Batch loss=1.436]

30


Epoch [1/20]:   0%|          | 30/25325 [07:13<119:10:00, 16.96s/it, Training loss=2.184, Batch loss=1.192]

31


Epoch [1/20]:   0%|          | 31/25325 [07:31<120:42:48, 17.18s/it, Training loss=2.147, Batch loss=0.986]

32


Epoch [1/20]:   0%|          | 32/25325 [07:48<120:18:10, 17.12s/it, Training loss=2.103, Batch loss=0.717]

33


Epoch [1/20]:   0%|          | 33/25325 [08:07<124:18:22, 17.69s/it, Training loss=2.057, Batch loss=0.525]

34


Epoch [1/20]:   0%|          | 34/25325 [08:29<132:39:57, 18.88s/it, Training loss=2.011, Batch loss=0.448]

35


Epoch [1/20]:   0%|          | 35/25325 [08:46<129:25:36, 18.42s/it, Training loss=1.967, Batch loss=0.441]

36


Epoch [1/20]:   0%|          | 36/25325 [09:02<124:44:09, 17.76s/it, Training loss=1.926, Batch loss=0.441]

37


Epoch [1/20]:   0%|          | 37/25325 [09:20<123:22:16, 17.56s/it, Training loss=1.888, Batch loss=0.467]

38


Epoch [1/20]:   0%|          | 38/25325 [09:36<120:33:42, 17.16s/it, Training loss=1.854, Batch loss=0.586]

39


Epoch [1/20]:   0%|          | 39/25325 [09:53<120:07:26, 17.10s/it, Training loss=1.819, Batch loss=0.451]

40


Epoch [1/20]:   0%|          | 40/25325 [10:10<120:37:12, 17.17s/it, Training loss=1.784, Batch loss=0.375]

41


Epoch [1/20]:   0%|          | 41/25325 [10:27<120:20:49, 17.14s/it, Training loss=1.750, Batch loss=0.368]

42


Epoch [1/20]:   0%|          | 42/25325 [10:47<126:28:58, 18.01s/it, Training loss=1.717, Batch loss=0.317]

43


Epoch [1/20]:   0%|          | 43/25325 [11:05<125:08:00, 17.82s/it, Training loss=1.686, Batch loss=0.337]

44


Epoch [1/20]:   0%|          | 44/25325 [11:23<126:38:14, 18.03s/it, Training loss=1.657, Batch loss=0.400]

45


Epoch [1/20]:   0%|          | 45/25325 [11:44<131:59:11, 18.80s/it, Training loss=1.634, Batch loss=0.613]

46


Epoch [1/20]:   0%|          | 46/25325 [12:01<129:19:44, 18.42s/it, Training loss=1.618, Batch loss=0.845]

47


Epoch [1/20]:   0%|          | 47/25325 [12:20<129:44:30, 18.48s/it, Training loss=1.609, Batch loss=1.216]

48


Epoch [1/20]:   0%|          | 48/25325 [12:40<133:52:47, 19.07s/it, Training loss=1.608, Batch loss=1.531]

49
