In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.autograd import Variable
import torchvision.transforms as T

import matplotlib.pyplot as plt
import os
from PIL import Image
import torch.utils.data as dutils
from tqdm import tqdm
import torchvision.utils as vutils




In [18]:


# Define the VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder used for centre-patch inpainting.

    The encoder is one stride-1 and four stride-2 3x3 convolutions (each with
    BatchNorm + ReLU), producing a 512-channel feature map. Two linear heads
    map the flattened features to the latent mean ``mu`` and standard
    deviation ``sigma`` (``sigma = exp(fc_sigma(h))`` so it is always > 0).
    The decoder maps a latent vector back to a 3-channel image through four
    stride-2 transposed convolutions and a final Sigmoid, so outputs lie in
    (0, 1).

    Args:
        image_size: side length of the square RGB input. The decoder output
            side is ``(image_size // 16) * 16``, so a multiple of 16 is
            assumed for input and output to match.
        latent_dim: dimensionality of the latent vector ``z``.
    """

    def __init__(self, image_size=256, latent_dim=100):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )

        # Spatial side of the encoder output. A 3x3, stride-2, padding-1
        # convolution maps a side n to ceil(n / 2); there are four of them.
        # (For image_size=256 this is 16, matching the previous hard-coded size.)
        enc_side = image_size
        for _ in range(4):
            enc_side = (enc_side + 1) // 2
        enc_features = 512 * enc_side * enc_side

        # Latent space
        self.fc_mu = nn.Linear(enc_features, latent_dim)
        self.fc_sigma = nn.Linear(enc_features, latent_dim)

        # Decoder: four stride-2 upsamplings take dec_side back to dec_side * 16.
        dec_side = image_size // 16
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512 * dec_side * dec_side),
            nn.ReLU(),
            nn.Unflatten(1, (512, dec_side, dec_side)),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Encode a batch ``x`` (B, 3, H, W) into ``(mu, sigma)``, each (B, latent_dim)."""
        h = self.encoder(x)
        h = torch.flatten(h, 1)
        mu = self.fc_mu(h)
        # exp keeps the standard deviation strictly positive.
        sigma = torch.exp(self.fc_sigma(h))
        return mu, sigma

    def reparameterize(self, mu, sigma):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        ``eps`` is created on the same device and dtype as ``mu``, so the model
        works on CPU, CUDA or MPS and after being moved with ``.to(...)``.
        """
        eps = torch.randn_like(mu)
        return mu + sigma * eps

    def decode(self, z):
        """Decode latent vectors ``z`` (B, latent_dim) into images in (0, 1)."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstructed_x, mu, sigma)`` for an input batch ``x``."""
        mu, sigma = self.encode(x)
        z = self.reparameterize(mu, sigma)
        reconstructed_x = self.decode(z)
        return reconstructed_x, mu, sigma


In [32]:
class Dataset(dutils.Dataset):
    """Flat-folder image dataset for the inpainting VAE.

    Lists the image files directly inside ``images/train/`` (or
    ``images/test/`` when ``Test`` is true) and returns each one as an RGB
    PIL image, optionally passed through ``transform``.

    NOTE: this class shadows ``torch.utils.data.Dataset`` imported at the top
    of the notebook under the same name.

    Args:
        transform: optional callable applied to the RGB PIL image.
        Test: if True read from the test folder, otherwise the train folder.
        folder_path: optional directory that overrides the ``Test``-based
            default.
    """

    # Extensions treated as images; anything else (e.g. .DS_Store) is skipped.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, transform=None, Test=False, folder_path=None):
        self.transform = transform
        if folder_path is None:
            folder_path = 'images/test/' if Test else 'images/train/'
        self.folder_path = folder_path
        # Sorted so index -> file is deterministic (os.listdir order is arbitrary);
        # subdirectories and non-image files are excluded.
        self.images = sorted(
            name for name in os.listdir(folder_path)
            if name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(folder_path, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        path = os.path.join(self.folder_path, self.images[index])
        with Image.open(path) as raw:
            # convert() loads the pixels before the file is closed (PIL opens
            # lazily) and guarantees 3 channels, as the VAE's first Conv2d expects.
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img



In [33]:
# Training-time augmentation: random horizontal and vertical flips, then
# conversion of the PIL image to a tensor.
transform = T.Compose(
    [
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.ToTensor(),
    ]
)


In [34]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [35]:
# Training images get the random-flip augmentation; test images are only
# converted to tensors so evaluation sees each image exactly as stored.
train_ds = Dataset(transform=transform)
test_ds = Dataset(transform=T.ToTensor(), Test=True)
train = dutils.DataLoader(train_ds, batch_size=16, shuffle=True)
# shuffle=True keeps next(iter(test)) returning a different random batch per call.
test = dutils.DataLoader(test_ds, batch_size=16, shuffle=True)



In [6]:
# Hyperparameters
# -- data / model geometry --
image_size = 256      # side length (pixels) of the square input images
PATCH_SIZE = 32       # side of the square centre hole the VAE must inpaint
latent_dim = 512      # size of the VAE latent vector
# -- optimisation --
batch_size = 16       # NOTE(review): the DataLoaders hard-code 16 instead of using this
learning_rate = 1e-4  # Adam step size
epochs = 100          # passes over the training set

In [21]:
import torch
from torchsummary import summary


# A separate VAE instance used only for inspection (layer summary / graph plot).
model = VAE(image_size=image_size, latent_dim=latent_dim)

# torchsummary only understands "cuda" or "cpu", so choose between those two
# here without overwriting the notebook-wide `device` selected earlier.
summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(summary_device)

# Dummy batch of one image shaped like the real training inputs.
dummy_input = torch.randn(1, 3, image_size, image_size, device=summary_device)
input_size = tuple(dummy_input.size()[1:])
# Print the model summary
summary(model, input_size=input_size, device=summary_device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 256, 256]             896
       BatchNorm2d-2         [-1, 32, 256, 256]              64
              ReLU-3         [-1, 32, 256, 256]               0
            Conv2d-4         [-1, 64, 128, 128]          18,496
       BatchNorm2d-5         [-1, 64, 128, 128]             128
              ReLU-6         [-1, 64, 128, 128]               0
            Conv2d-7          [-1, 128, 64, 64]          73,856
       BatchNorm2d-8          [-1, 128, 64, 64]             256
              ReLU-9          [-1, 128, 64, 64]               0
           Conv2d-10          [-1, 256, 32, 32]         295,168
      BatchNorm2d-11          [-1, 256, 32, 32]             512
             ReLU-12          [-1, 256, 32, 32]               0
           Conv2d-13          [-1, 512, 16, 16]       1,180,160
      BatchNorm2d-14          [-1, 512,

In [22]:
from torchviz import make_dot

# Trace one forward pass on the dummy input and render its autograd graph,
# labelling parameter nodes by name, as vae.png in the current directory.
graph_outputs = model(dummy_input)
named_params = dict(model.named_parameters())
dot = make_dot(graph_outputs, params=named_params)
dot.format = 'png'
dot.render(filename='vae', directory='./')

'vae.png'

In [None]:
# Top-left corner of the centred square hole. Dim 2 is rows and dim 3 is
# columns; the hole is square and centred, so both offsets are equal.
fix_x_offset = max(0, (image_size - PATCH_SIZE) // 2)
fix_y_offset = fix_x_offset
hole_rows = slice(fix_x_offset, fix_x_offset + PATCH_SIZE)
hole_cols = slice(fix_y_offset, fix_y_offset + PATCH_SIZE)

os.makedirs('test', exist_ok=True)

# NOTE(review): `model` is the VAE built for the summary cell; the network
# trained in this notebook is `vae`.
# eval(): BatchNorm must use its running statistics and not update them here.
model.eval()
model_device = next(model.parameters()).device

for i in range(5):
    # Assumes each test batch is (B, 3, H, W), as produced by the DataLoader.
    img_og = next(iter(test))
    images = img_og.clone()
    images[:, :, hole_rows, hole_cols] = 0

    with torch.no_grad():
        # forward() takes only the image batch and returns (recon, mu, sigma).
        reconstructed, _, _ = model(images.to(model_device))

    # Paste only the predicted patch back into the masked (CPU) images.
    images[:, :, hole_rows, hole_cols] = reconstructed[:, :, hole_rows, hole_cols].cpu()

    # Original (left) next to inpainted (right) for every sample in the batch.
    img = torch.cat([img_og, images], dim=3)
    img = T.ToPILImage()(vutils.make_grid(img, normalize=True, padding=2, nrow=4))
    img.save(f'test/test_32_{i+1}.jpg')

In [10]:
vae = VAE(image_size=image_size, latent_dim=latent_dim).to(device)
# Summed BCE (divided by batch size below) puts the reconstruction term on the
# same per-sample scale as the summed KL term; a pixel-mean BCE would be
# swamped by a KL summed over batch x latent_dim.
criterion = nn.BCELoss(reduction='sum')
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

# Fixed, centred square hole of side PATCH_SIZE (dim 2 = rows, dim 3 = cols).
x_offset = max(0, (image_size - PATCH_SIZE) // 2)
y_offset = x_offset
hole_size = PATCH_SIZE
hole = (slice(None), slice(None),
        slice(x_offset, x_offset + hole_size),
        slice(y_offset, y_offset + hole_size))

# A fixed batch whose inpainting is saved after every epoch to track progress.
fixed_images = next(iter(train))[:16].to(device)

os.makedirs('progressvae', exist_ok=True)
os.makedirs('models', exist_ok=True)

for epoch in range(epochs):
    vae.train()
    for image_batch in tqdm(train):
        image_batch = image_batch.to(device)

        # Blank out the centre hole; the model has to reconstruct it.
        inputs = image_batch.clone()
        inputs[hole] = 0

        optimizer.zero_grad()
        outputs, mu, sigma = vae(inputs)

        batch_n = image_batch.size(0)
        # Reconstruction loss only on the hole region, per sample.
        loss_reconstruction = criterion(outputs[hole], image_batch[hole]) / batch_n

        # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (sigma^2 + mu^2 - 1) - log(sigma),
        # summed over latent dimensions and averaged over the batch.
        kl = (0.5 * (sigma ** 2 + mu ** 2 - 1) - torch.log(sigma + 1e-8)).sum() / batch_n

        loss = loss_reconstruction + kl

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    # Visualize progress. eval() stops BatchNorm from updating its running
    # statistics with this preview batch; a copy keeps fixed_images pristine.
    vae.eval()
    preview = fixed_images.clone()
    preview[hole] = 0.0
    with torch.no_grad():
        predicted, _, _ = vae(preview)
    preview[hole] = predicted[hole]
    img = T.ToPILImage()(vutils.make_grid(preview.to('cpu'), normalize=True, padding=2, nrow=4))
    img.save(f'progressvae/epoch_{epoch}.jpg')


vae = vae.to('cpu')
torch.save(vae, 'models/vae_inpainting_model.pkl')

KeyboardInterrupt: 

In [24]:
# The checkpoint is a pickled nn.Module (torch.save(vae, ...)), so
# weights_only=False is required on PyTorch >= 2.6 (parameter exists since 1.13).
# Unpickling can execute arbitrary code: only load checkpoints you created.
# map_location='cpu' remaps any tensors stored on another device (e.g. CUDA).
vae = torch.load('models/vae_inpainting_model.pkl', map_location='cpu', weights_only=False)
vae.eval()

In [38]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error
import numpy as np
ssims = list()
mses = list()

# Same centred square hole as in training (dim 2 = rows, dim 3 = cols).
fix_x_offset = max(0, (image_size - PATCH_SIZE) // 2)
fix_y_offset = fix_x_offset
hole = (slice(None), slice(None),
        slice(fix_x_offset, fix_x_offset + PATCH_SIZE),
        slice(fix_y_offset, fix_y_offset + PATCH_SIZE))

# Inference mode so BatchNorm uses running statistics.
vae.eval()
model_device = next(vae.parameters()).device

# One pass over the full test set, so every test image is scored exactly once
# (instead of re-sampling random batches via next(iter(test))).
for img_og in test:
    images = img_og.clone()
    images[hole] = 0

    with torch.no_grad():
        reconstructed, _, _ = vae(images.to(model_device))

    images[hole] = reconstructed[hole].cpu()

    for inpainted, original in zip(images.numpy(), img_og.numpy()):
        # Pixel values come from ToTensor, i.e. [0, 1], hence data_range=1.0.
        # NOTE: metrics span the whole image, so the untouched pixels outside
        # the hole dilute both scores.
        ssims.append(ssim(inpainted, original, channel_axis=0, data_range=1.0))
        mses.append(mean_squared_error(inpainted, original))

np.mean(ssims), np.mean(mses)



AttributeError: module 'PIL.ImageFile' has no attribute 'PyEncoder'