In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

class VDVAE(nn.Module):
    def __init__(self, latent_dim=256):
        super(VDVAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),     # 128x128
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),    # 64x64
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),   # 32x32
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 16x16
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),  # 8x8
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, 4, stride=2, padding=1),  # 4x4
            nn.LeakyReLU(0.2)
        )

        # Bottleneck
        self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_var = nn.Linear(512 * 4 * 4, latent_dim)

        # Decoder input
        self.decoder_input = nn.Linear(latent_dim, 512 * 4 * 4)

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1),  # 8x8
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # 16x16
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # 32x32
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),   # 64x64
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),    # 128x128
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),     # 256x256
            nn.Tanh()
        )

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(x.size(0), 512, 4, 4)
        return self.decoder(x)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

try:
    # Setup
    folder_path = r"C:\Users\sOrOush\SoroushProjects\00_scratchpad\images\processed_images"
    files = [f for f in os.listdir(folder_path)
             if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    # Parameters
    batch_size = 16
    device = torch.device("cuda")

    # Transform for 256x256 images
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    print("1. Loading images...")
    images = []
    for file in files:
        img_path = os.path.join(folder_path, file)
        img = Image.open(img_path).convert('RGB')
        img_tensor = transform(img)
        images.append(img_tensor)

    all_images = torch.stack(images)
    dataset = TensorDataset(all_images)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"Loaded {len(all_images)} images")

    # Create and train model
    model = VDVAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    print("2. Training...")
    num_epochs = 1000
    for epoch in range(num_epochs):
        total_loss = 0
        total_recon_loss = 0
        total_kl_loss = 0

        for batch_idx, (data,) in enumerate(dataloader):
            data = data.to(device)
            optimizer.zero_grad()

            # Forward pass
            recon_batch, mu, log_var = model(data)

            # Loss calculation
            recon_loss = F.mse_loss(recon_batch, data, reduction='sum')
            kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
            loss = recon_loss + 0.01 * kl_loss

            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_kl_loss += kl_loss.item()

        if (epoch + 1) % 10 == 0:
            avg_loss = total_loss / len(dataloader.dataset)
            avg_recon = total_recon_loss / len(dataloader.dataset)
            avg_kl = total_kl_loss / len(dataloader.dataset)
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Avg Loss: {avg_loss:.4f}, '
                  f'Avg Recon Loss: {avg_recon:.4f}, '
                  f'Avg KL Loss: {avg_kl:.4f}')

    print("3. Generating and saving reconstructions...")
    model.eval()
    output_dir = 'vae_output_highres'
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        # Iterate through the dataloader until at least 20 pairs are saved
        pair_count = 0
        for data_batch in dataloader:
            test_batch = data_batch[0].to(device)
            reconstructions = model(test_batch)[0]

            batch_size = test_batch.size(0)
            for i in range(batch_size):
                try:
                    # Save original
                    orig_img = test_batch[i].cpu().numpy().transpose(1, 2, 0)
                    orig_img = ((orig_img + 1) * 127.5).astype(np.uint8)
                    orig_pil = Image.fromarray(orig_img)
                    orig_path = os.path.join(output_dir, f'original_{pair_count}.png')
                    orig_pil.save(orig_path)
                    print(f"Saved original image {pair_count} to {orig_path}")

                    # Save reconstruction
                    recon_img = reconstructions[i].cpu().numpy().transpose(1, 2, 0)
                    recon_img = ((recon_img + 1) * 127.5).astype(np.uint8)
                    recon_pil = Image.fromarray(recon_img)
                    recon_path = os.path.join(output_dir, f'reconstructed_{pair_count}.png')
                    recon_pil.save(recon_path)
                    print(f"Saved reconstructed image {pair_count} to {recon_path}")

                    pair_count += 1

                    # Stop after saving 20 pairs
                    if pair_count >= 20:
                        break

                except Exception as e:
                    print(f"Error saving image {pair_count}: {e}")

            if pair_count >= 20:
                break

    print("4. Process completed!")

    # print("3. Generating and saving reconstructions...")
    # model.eval()
    # output_dir = 'vae_output_highres'
    # os.makedirs(output_dir, exist_ok=True)

    # with torch.no_grad():
    #     # Get first batch
    #     test_batch = next(iter(dataloader))[0].to(device)
    #     reconstructions = model(test_batch)[0]

    #     # Save first 5 pairs
    #     for i in range(5):
    #         try:
    #             # Save original
    #             orig_img = test_batch[i].cpu().numpy().transpose(1, 2, 0)
    #             orig_img = ((orig_img + 1) * 127.5).astype(np.uint8)
    #             orig_pil = Image.fromarray(orig_img)
    #             orig_path = os.path.join(output_dir, f'original_{i}.png')
    #             orig_pil.save(orig_path)
    #             print(f"Saved original image {i} to {orig_path}")

    #             # Save reconstruction
    #             recon_img = reconstructions[i].cpu().numpy().transpose(1, 2, 0)
    #             recon_img = ((recon_img + 1) * 127.5).astype(np.uint8)
    #             recon_pil = Image.fromarray(recon_img)
    #             recon_path = os.path.join(output_dir, f'reconstructed_{i}.png')
    #             recon_pil.save(recon_path)
    #             print(f"Saved reconstructed image {i} to {recon_path}")

    #         except Exception as e:
    #             print(f"Error saving image {i}: {e}")

    # print("4. Process completed!")

    # Clean up
    del model, reconstructions, test_batch
    torch.cuda.empty_cache()

except Exception as e:
    print(f"Error occurred: {e}")
    import traceback
    traceback.print_exc()

1. Loading images...
Loaded 100 images
2. Training...
Epoch [10/1000], Avg Loss: 37619.8031, Avg Recon Loss: 37616.0484, Avg KL Loss: 375.4036
Epoch [20/1000], Avg Loss: 28992.0812, Avg Recon Loss: 28982.5747, Avg KL Loss: 950.6728
Epoch [30/1000], Avg Loss: 24842.4440, Avg Recon Loss: 24821.7477, Avg KL Loss: 2069.6466
Epoch [40/1000], Avg Loss: 21714.9682, Avg Recon Loss: 21670.1356, Avg KL Loss: 4483.2349
Epoch [50/1000], Avg Loss: 18911.2373, Avg Recon Loss: 18846.3570, Avg KL Loss: 6488.0650
Epoch [60/1000], Avg Loss: 17201.5645, Avg Recon Loss: 17123.7127, Avg KL Loss: 7785.1626
Epoch [70/1000], Avg Loss: 15284.3010, Avg Recon Loss: 15192.6783, Avg KL Loss: 9162.2698
Epoch [80/1000], Avg Loss: 14147.1512, Avg Recon Loss: 14052.5613, Avg KL Loss: 9458.9909
Epoch [90/1000], Avg Loss: 13396.8391, Avg Recon Loss: 13299.1347, Avg KL Loss: 9770.4260
Epoch [100/1000], Avg Loss: 12303.8392, Avg Recon Loss: 12198.5276, Avg KL Loss: 10531.1629
Epoch [110/1000], Avg Loss: 11676.8538, Avg Re

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset

# VDVAE Definition
class VDVAE(nn.Module):
    def __init__(self, latent_dim=256):
        super(VDVAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_var = nn.Linear(512 * 4 * 4, latent_dim)
        self.decoder_input = nn.Linear(latent_dim, 512 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc_mu(x), self.fc_var(x)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(x.size(0), 512, 4, 4)
        return self.decoder(x)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

# Load images from folder
def load_images(folder_path, image_size=(256, 256)):
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,) * 3, (0.5,) * 3)
    ])
    images = []
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
            path = os.path.join(folder_path, filename)
            img = Image.open(path).convert('RGB')
            images.append(transform(img))
    return torch.stack(images)

# Save image tensor as PNG
def save_image(tensor, path):
    img = tensor.cpu().numpy().transpose(1, 2, 0)
    img = ((img + 1) * 127.5).astype(np.uint8)
    Image.fromarray(img).save(path)

# Train function
def train(model, dataloader, optimizer, device, epochs=100):
    model.train()
    for epoch in range(epochs):
        total_loss, total_recon, total_kl = 0, 0, 0
        for data, in dataloader:
            data = data.to(device)
            optimizer.zero_grad()
            recon, mu, log_var = model(data)
            recon_loss = F.mse_loss(recon, data, reduction='sum')
            kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
            loss = recon_loss + 0.01 * kl_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_recon += recon_loss.item()
            total_kl += kl_loss.item()
        if (epoch + 1) % 10 == 0:
            n = len(dataloader.dataset)
            print(f"Epoch {epoch+1}/{epochs} | "
                  f"Loss: {total_loss/n:.2f} | "
                  f"Recon: {total_recon/n:.2f} | "
                  f"KL: {total_kl/n:.2f}")

# Main
def main():
    folder_path = r"C:\Users\sOrOush\SoroushProjects\00_scratchpad\images\processed_images"
    output_dir = "vae_output_highres"
    model_path = "vae_minimal.pth"
    batch_size = 16
    epochs = 2

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Loading images...")
    images = load_images(folder_path)
    dataloader = DataLoader(TensorDataset(images), batch_size=batch_size, shuffle=True)
    print(f"{len(images)} images loaded.")

    model = VDVAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    print("Starting training...")
    train(model, dataloader, optimizer, device, epochs=epochs)

    print(f"Saving model to {model_path}...")
    torch.save(model.state_dict(), model_path)

    print(f"Saving reconstructions to '{output_dir}'...")
    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    with torch.no_grad():
        pair_count = 0
        for data, in dataloader:
            data = data.to(device)
            recon = model(data)[0]
            for i in range(data.size(0)):
                save_image(data[i], os.path.join(output_dir, f"original_{pair_count}.png"))
                save_image(recon[i], os.path.join(output_dir, f"reconstructed_{pair_count}.png"))
                pair_count += 1
                if pair_count >= 20:
                    break
            if pair_count >= 20:
                break
    print("Done!")

if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        import traceback
        print("An error occurred:")
        traceback.print_exc()


Loading images...
6827 images loaded.
Starting training...
Saving model to vae_minimal.pth...
Saving reconstructions to 'vae_output_highres'...
Done!


In [2]:
pip install torchviz

Collecting torchviz
  Downloading torchviz-0.0.3-py3-none-any.whl.metadata (2.1 kB)
Collecting graphviz (from torchviz)
  Downloading graphviz-0.20.3-py3-none-any.whl.metadata (12 kB)
Downloading torchviz-0.0.3-py3-none-any.whl (5.7 kB)
Downloading graphviz-0.20.3-py3-none-any.whl (47 kB)
Installing collected packages: graphviz, torchviz
Successfully installed graphviz-0.20.3 torchviz-0.0.3
Note: you may need to restart the kernel to use updated packages.


In [2]:
 from torchviz import make_dot

model = VDVAE()
dummy_input = torch.randn(1, 3, 256, 256)
output, mu, logvar = model(dummy_input)
dot = make_dot(output, params=dict(model.named_parameters()))
dot.format = 'png'
dot.render('vae_architecture')


ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH

In [3]:
import graphviz
print(graphviz.backend.viewing._viewers['windows'])


AttributeError: module 'graphviz.backend.viewing' has no attribute '_viewers'

In [8]:

# pick GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
from torchviz import make_dot

# ... after you’ve built & moved your model to device:
model = VDVAE().to(device)
model.eval()

# create a dummy input the same size as your training images:
dummy = torch.randn(1, 3, 256, 256).to(device)

# forward-pass through the network:
with torch.no_grad():
    recon, mu, log_var = model(dummy)

# build the graph; here we visualize the reconstruction branch:
dot = make_dot(recon, params=dict(model.named_parameters()),
                show_attrs=False,    # set to True to inspect saved tensors
                show_saved=False)    # set to True to see backward nodes too

# choose output format and render
dot.format = 'png'
dot.render('vdvae_graph')


ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH