In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory.
# Collect every terrain PNG from the dataset folder. glob returns matches in
# filesystem-dependent order, so sort them for a reproducible ordering
# (preview grid, first-sample checks).

import os
import glob

DATA_DIR = "/kaggle/input/procedural-environment-generation/dataset/dataset/"
filenames = sorted(glob.glob(os.path.join(DATA_DIR, "*.png")))
print(len(filenames))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
pip install torch-summary

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torchsummary import summary
import matplotlib.pyplot as plt
from PIL import Image as I
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as transforms

In [None]:

def reparametrize(mu, logvar):
    """Draw z ~ N(mu, exp(logvar)) with the reparameterisation trick.

    Computes ``mu + std * eps`` with ``std = exp(logvar / 2)`` and
    ``eps ~ N(0, I)``. The randomness lives in ``eps``, so gradients flow
    back into both ``mu`` and ``logvar``.

    Args:
        mu: mean tensor.
        logvar: log-variance tensor, broadcast-compatible with ``mu``
            (the models in this notebook pass equal-shaped slices).

    Returns:
        A sample tensor with the broadcast shape of ``mu`` and ``logvar``.
    """
    std = torch.exp(0.5 * logvar)
    # randn_like matches std's shape/dtype/device. The old
    # torch.autograd.Variable wrapper is deprecated and unnecessary.
    eps = torch.randn_like(std)
    return mu + std * eps


class View(nn.Module):
    """Reshape layer for use inside ``nn.Sequential``.

    Args:
        size: target shape handed verbatim to ``Tensor.view`` (may contain -1).
    """

    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, tensor):
        target_shape = self.size
        reshaped = tensor.view(target_shape)
        return reshaped


class BetaVAE_H(nn.Module):
    """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017).

    Convolutional VAE. The encoder maps an image to ``2 * z_dim`` values,
    which ``forward`` splits into ``mu`` and ``logvar``. The decoder maps a
    ``z_dim`` latent back to an ``nc``-channel image.

    The per-layer shape comments assume a 64x64 input. The final
    ``Conv2d(64, 256, 4, 1)`` only collapses the feature map to 1x1 at that
    size. NOTE(review): other input sizes are not checked here; confirm
    before feeding larger images.

    Args:
        z_dim: latent dimensionality.
        nc: number of image channels for both input and reconstruction.
    """

    def __init__(self, z_dim=10, nc=3):
        super(BetaVAE_H, self).__init__()
        self.z_dim = z_dim
        self.nc = nc
        # Layer order matters: nn.Sequential indices become state_dict keys.
        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),          # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),          # B,  64,  4,  4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),            # B, 256,  1,  1
            nn.ReLU(True),
            View((-1, 256*1*1)),                 # B, 256
            nn.Linear(256, z_dim*2),             # B, z_dim*2
        )
        # The last layer has no output activation: reconstructions are unbounded.
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),               # B, 256
            View((-1, 256, 1, 1)),               # B, 256,  1,  1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),      # B,  64,  4,  4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1), # B,  64,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1),  # B, nc, 64, 64
        )

        self.weight_init()

    def weight_init(self):
        """Apply ``kaiming_init`` to every layer of every registered sub-module.

        The registered sub-modules are ``encoder`` and ``decoder``.
        ``kaiming_init`` is a module-level function resolved at call time;
        it decides per layer type which layers to touch.
        """
        for block in self._modules:
            for m in self._modules[block]:
                kaiming_init(m)

    def forward(self, x):
        """Encode ``x``, sample ``z`` via ``reparametrize``, then decode.

        Returns:
            (x_recon, mu, logvar): the raw decoder output plus the posterior
            mean and log-variance, each of width ``z_dim``.
        """
        distributions = self._encode(x)
        # The first z_dim columns are the mean; the remaining z_dim are the log-variance.
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z)

        return x_recon, mu, logvar

    def _encode(self, x):
        """Return the encoder output (``mu`` and ``logvar`` concatenated on dim 1)."""
        return self.encoder(x)

    def _decode(self, z):
        """Map latent codes back to image space."""
        return self.decoder(z)


class BetaVAE_B(BetaVAE_H):
    """Model proposed in understanding beta-VAE paper(Burgess et al, arxiv:1804.03599, 2018).

    Same interface as ``BetaVAE_H``: ``forward`` returns
    ``(x_recon, mu, logvar)``. The architecture differs: a 32-channel conv
    stack plus two 256-unit fully connected layers in both the encoder and
    the decoder. ``weight_init``, ``_encode`` and ``_decode`` are inherited
    unchanged from ``BetaVAE_H``.

    The layer shapes are designed for 64x64 inputs. The encoder flattens a
    32x4x4 feature map, and the decoder emits ``nc`` x 64 x 64.
    NOTE(review): for larger inputs (e.g. 256x256), ``View((-1, 32*4*4))``
    folds the extra spatial positions into the batch dimension. ``forward``'s
    ``.view(x.size())`` then stitches the pieces back together, so the model
    runs but does not learn one latent vector per image.

    Args:
        z_dim: latent dimensionality.
        nc: number of image channels (defaults to 1, i.e. grayscale).
    """

    def __init__(self, z_dim=10, nc=1):
        # Skip BetaVAE_H.__init__. It would build and Kaiming-initialise a
        # full 3-channel BetaVAE_H encoder/decoder only to have them replaced
        # below. nn.Module.__init__ is all the bookkeeping needed.
        super(BetaVAE_H, self).__init__()
        self.nc = nc
        self.z_dim = z_dim

        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32,  8,  8
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32,  4,  4
            nn.ReLU(True),
            View((-1, 32*4*4)),                  # B, 512
            nn.Linear(32*4*4, 256),              # B, 256
            nn.ReLU(True),
            nn.Linear(256, 256),                 # B, 256
            nn.ReLU(True),
            nn.Linear(256, z_dim*2),             # B, z_dim*2
        )

        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),               # B, 256
            nn.ReLU(True),
            nn.Linear(256, 256),                 # B, 256
            nn.ReLU(True),
            nn.Linear(256, 32*4*4),              # B, 512
            nn.ReLU(True),
            View((-1, 32, 4, 4)),                # B,  32,  4,  4
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1), # B,  nc, 64, 64
        )
        self.weight_init()

    def forward(self, x):
        """Encode ``x``, sample ``z`` via ``reparametrize``, then decode.

        Unlike ``BetaVAE_H.forward``, the reconstruction is reshaped to
        ``x.size()`` before it is returned.

        Returns:
            (x_recon, mu, logvar)
        """
        distributions = self._encode(x)
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z).view(x.size())

        return x_recon, mu, logvar


def kaiming_init(m):
    """Initialise a single layer in place.

    * ``nn.Linear`` / ``nn.Conv2d``: Kaiming-normal weights, zero bias.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weight 1, bias 0.
    * Anything else (activations, ``View``, ``nn.ConvTranspose2d``, ...)
      is left untouched.

    Uses the in-place ``init.kaiming_normal_``. The un-suffixed
    ``init.kaiming_normal`` is a deprecated alias that warns on every call.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # BatchNorm with affine=False has weight/bias set to None.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)


def normal_init(m, mean, std):
    """Initialise a single layer in place from N(mean, std**2).

    * ``nn.Linear`` / ``nn.Conv2d``: weights ~ N(mean, std**2), bias zeroed.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weight 1, bias 0.
    * Anything else is left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.normal_(m.weight, mean, std)
        # Test the Parameter itself. For a bias-free layer m.bias is None,
        # so the old `m.bias.data` check raised AttributeError.
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        # BatchNorm with affine=False has weight/bias set to None.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)



In [None]:
# class VAE(nn.Module):
#     def __init__(self):
#         super(VAE, self).__init__()

#         self.encoder = nn.Sequential(
#             nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), #64,128 128
#             nn.ReLU(),
#             nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), #128, 64 64
#             nn.ReLU(),
#             nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), #256 32 32
#             nn.ReLU(),
#             nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), #256 16 16
#             nn.ReLU(),
#             nn.Flatten()
#         )

#         self.fc_mu = nn.Linear(512*17*17, 64)
#         self.fc_logvar = nn.Linear(512 * 17 * 17, 64)

#         self.decoder = nn.Sequential(
#             nn.Linear(64,512 * 17 * 17),
#             nn.ReLU(),
#             nn.Unflatten(1, (512,17,17)),
#             nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=1), #256 16 16
#             nn.ReLU(),
#             nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=1),  #128 32 32
#             nn.ReLU(),
#             nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=1),  #64 64 64
#             nn.ReLU(),
#             nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2, padding=1), #32 128 128
            
#             nn.Sigmoid()
#         )

#     def reparameterize(self, mu, logvar):
#         std = torch.exp(0.5 * logvar)
#         eps = torch.randn_like(std)
#         return mu + eps * std

#     def forward(self, x):
#         siz=x.size()
#         x = self.encoder(x)
#         mu = self.fc_mu(x)
#         logvar = self.fc_logvar(x)
#         z = self.reparameterize(mu, logvar)
# #         z = z.view(-1, 512, 1, 1)
#         x_reconstructed = self.decoder(z).view([64,1,254,254])
#         return x_reconstructed, mu, logvar


In [None]:
# 64-dimensional latent, 1-channel (grayscale) beta-VAE. `device` is resolved
# here too, so this cell does not depend on a later cell having run first.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BetaVAE_B(64).to(device)
summary(model)

In [None]:
#Loading images
def load_image(paths):
    """Lazily load single-channel images and min-max normalise each one.

    Args:
        paths: iterable of image file paths.

    Yields:
        A float32 array of shape ``(1, H, W)`` scaled to [0, 1]. A constant
        image (max == min) yields all zeros instead of NaNs from 0/0.

    Raises:
        ValueError: if an image does not decode to a 2-D array
            (e.g. RGB/RGBA files).
    """
    for path in paths:
        # The context manager closes the file handle as soon as pixels are read.
        with I.open(path) as im:
            img = np.asarray(im, dtype=np.float64)
        if img.ndim != 2:
            raise ValueError(
                f"{path}: expected a single-channel image, got array shape {img.shape}"
            )
        lo, hi = img.min(), img.max()
        span = hi - lo
        img_normal = (img - lo) / span if span > 0 else np.zeros_like(img)
        # Add a leading channel axis -> (1, H, W), the layout Conv2d expects.
        yield img_normal.astype(np.float32)[np.newaxis]
        
class Terrains(Dataset):
    """Map-style dataset over terrain image files.

    Each item is read lazily from disk via ``load_image``, which returns a
    min-max normalised float32 ``(1, H, W)`` array for single-channel images.

    Args:
        paths: sequence of image file paths; item ``k`` is ``paths[k]``.
    """

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, id=None):
        # `id` (which shadows the builtin) is kept for interface compatibility.
        # Accept a tensor index as well as a plain int.
        if torch.is_tensor(id):
            id = id.tolist()
        return next(load_image([self.paths[id]]))


In [None]:
batch_size = 64
train_set = Terrains(filenames)
# torch.cuda.is_available() already returns a bool; comparing to True is redundant.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
# Preview up to the first 25 terrains in a 5x5 grid. zip stops early if fewer files exist.
fig, ax = plt.subplots(5, 5, figsize=(14, 14))
for axis, image in zip(ax.flat, load_image(filenames[:25])):
    # load_image yields (1, H, W); plot the single channel as a 2-D array.
    axis.imshow(image[0], cmap='gray')
for axis in ax.flat:
    axis.axis('off')
plt.show()

In [None]:
# Shuffled mini-batches of normalised terrain images for training.
Batches = DataLoader(train_set, batch_size=batch_size, shuffle=True)

In [None]:
# Checkpoints and per-epoch sample images are written here (relative to the
# current working directory). Re-running this cell is harmless.
CP_dir = 'CP_VAE'
os.makedirs(name=CP_dir, exist_ok=True)

In [None]:
# Shape of one sample as produced by load_image: (1, H, W) for a single-channel file.
next(load_image(filenames[:1])).shape

In [None]:
# Optimisation hyper-parameters.
lr = 3e-4
epochs = 200
criterion = nn.MSELoss()  # mean squared error over every element
beta1, beta2 = 0.9, 0.999  # Adam moment-decay coefficients
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(beta1, beta2))


In [None]:
#Train
# Each epoch: one pass over `Batches`, then decode one sample drawn from the
# prior N(0, I) as a progress snapshot, then save a checkpoint.

for epoch in range(epochs):
    # The snapshot below switches to eval mode; switch back before training.
    model.train()
    for i, data in enumerate(Batches, 0):
        inputs = data.to(device)

        # Train VAE
        optimizer.zero_grad()
        reconstructions, mean, log_var = model(inputs)

        # Put both terms on a per-image scale. `criterion` averages over every
        # element, so multiplying by the elements-per-image count gives the
        # summed error per image, averaged over the batch. The KL term is
        # averaged over the batch as well. A pixel-averaged reconstruction
        # term next to a batch-summed KL term would be smaller by a factor of
        # about batch_size * C*H*W, and the KL term would dominate training.
        batch_n = inputs.size(0)
        reconstruction_loss = criterion(reconstructions, inputs) * inputs[0].numel()
        kl_divergence = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp()) / batch_n
        loss = reconstruction_loss + kl_divergence
        loss.backward()
        optimizer.step()

        if (i + 1) % 64 == 0:  # Adjust the interval based on your needs
            print(f'Epoch [{epoch + 1}/{epochs}], Batch [{i + 1}/{len(Batches)}], Loss: {loss.item()}')

    # To generate, a VAE samples z ~ N(0, I) and decodes it. Pushing random
    # *images* through the full model (encoder + decoder) only reconstructs noise.
    model.eval()
    with torch.no_grad():
        z = torch.randn(1, model.z_dim, device=device)
        generate_image = model._decode(z)[0, 0].cpu()

    # Clamp to [0, 1], the range load_image normalises training images to, and
    # plot on that fixed scale so snapshots are comparable across epochs.
    generate_image_np = np.clip(generate_image.numpy(), 0.0, 1.0)
    image_name = f'generate_img_epoch_{epoch + 1}.png'
    image_path = os.path.join(CP_dir, image_name)
    plt.imsave(image_path, generate_image_np, cmap='gray', vmin=0.0, vmax=1.0)

    snap_fig, snap_ax = plt.subplots()
    snap_ax.imshow(generate_image_np, cmap='gray', vmin=0.0, vmax=1.0)
    snap_ax.set_title(f'Generated Image - Epoch {epoch + 1}')
    plt.show()
    plt.close(snap_fig)  # avoid accumulating one open figure per epoch

    # model checkpoints
    checkpoint_vae = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),  # plain float; no autograd graph in the checkpoint
    }
    torch.save(checkpoint_vae, os.path.join(CP_dir, f'vae_checkpoint_epoch_{epoch + 1}.pth'))

    

In [None]:
# Sample 64 latents from the prior N(0, I) and decode them (the encoder is not
# involved). Also resize the decoded batch to 256x256 with F.interpolate
# (default 'nearest' mode).
model.eval()
with torch.no_grad():
    z = torch.randn(64, model.z_dim, device=device)
    y = model._decode(z)
    y_ = F.interpolate(y, size=[256, 256])
generate_image = y[0][0].cpu()
generate_image_ = y_[0][0].cpu()
# Clamp to [0, 1], the range load_image normalises training images to.
generate_image_np = np.clip(generate_image.numpy(), 0.0, 1.0)
generate_image_np_ = np.clip(generate_image_.numpy(), 0.0, 1.0)
plt.imshow(generate_image_np, cmap='gray')
plt.title('Generated sample (decoder resolution)')
plt.show()

In [None]:
plt.imshow(generate_image_np_, cmap='gray')
plt.title('Generated sample resized to 256x256')
plt.show()