In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# Below, the terrain heightmaps (*.png) are collected from that directory with glob.

import os
import glob
DATA_DIR = "/kaggle/input/procedural-environment-generation/dataset/dataset"
# glob returns files in arbitrary, filesystem-dependent order; sorting makes dataset
# indices and the preview cells below reproducible from run to run.
filenames = sorted(glob.glob(os.path.join(DATA_DIR, "*.png")))
if not filenames:
    # Fail here with a clear message instead of a StopIteration in a later cell.
    raise FileNotFoundError(f"No .png files found in {DATA_DIR!r}; is the dataset attached?")
print(len(filenames))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
pip install torch-summary

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torchsummary import summary
import matplotlib.pyplot as plt
from PIL import Image as I
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as transforms
from tqdm import tqdm

In [None]:
#Loading images
def load_image(paths):
    """Lazily load single-channel images as min-max normalised float32 arrays.

    Args:
        paths: iterable of image file paths readable by PIL (imported as ``I``).

    Yields:
        np.ndarray of shape (1, H, W) and dtype float32, scaled so the image's own
        minimum maps to 0 and its maximum to 1. A constant image (max == min)
        yields all zeros instead of the NaNs a plain division would produce.

    Raises:
        ValueError: if a decoded image is not 2-D (e.g. an RGB file).
    """
    for path in paths:
        # `with` closes the file handle as soon as the pixels are read instead of
        # leaving it to the garbage collector (matters when iterating many files).
        with I.open(path) as pil_img:
            img = np.array(pil_img)
        if img.ndim != 2:
            raise ValueError(
                f"expected a single-channel (2-D) image, got shape {img.shape} from {path!r}"
            )
        img = img.astype(np.float64)
        lo, hi = img.min(), img.max()
        span = hi - lo
        # Per-image min-max scaling; guard against division by zero for flat images.
        img_normal = (img - lo) / span if span > 0 else np.zeros_like(img)
        # Add a leading channel axis: (H, W) -> (1, H, W).
        yield img_normal.astype(np.float32)[np.newaxis, :, :]
        
class Terrains:
    """Map-style dataset over terrain heightmap image files.

    Provides ``__len__`` and ``__getitem__`` so it can be passed directly to a
    ``DataLoader``. Each item is the first (and only) array produced by
    ``load_image`` for the requested path.
    """

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, id=None):
        # `id` keeps its original (builtin-shadowing) name for keyword callers.
        # Samplers may pass a tensor index; turn it into plain Python first.
        key = id.tolist() if torch.is_tensor(id) else id
        image_iter = load_image([self.paths[key]])
        return next(image_iter)
# Sanity check: shape of the first loaded image, expected to be (1, H, W).
next(load_image(filenames)).shape


In [None]:
# Training configuration.
batch_size = 45
SEED = 42  # fixed seed so weight init and DataLoader shuffling are reproducible

torch.manual_seed(SEED)
np.random.seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    # NOTE(review): dev1/dev2 are not referenced by the cells visible in this notebook.
    dev1 = 'cuda:0'
    dev2 = 'cuda:1'

In [None]:
# Preview the first 25 heightmaps on a 5x5 grid.
fig, ax = plt.subplots(5, 5, figsize=(14, 14))
images = load_image(filenames)
# zip pulls from ax.flat first, so at most 25 images are loaded and fewer files
# than axes simply leaves the remaining axes empty.
for axis, image in zip(ax.flat, images):
    axis.imshow(image[0], cmap='gray')  # image is (1, H, W); show the single channel
    axis.axis('off')

In [None]:
# Map-style dataset over every terrain PNG path collected in the first cell.
train_set=Terrains(filenames)

In [None]:
# Shuffled mini-batches of terrain images, each batch a (batch_size, 1, H, W) tensor.
# NOTE(review): the trailing "var=0.0572" differs from `variance` set in a later cell (0.0633); confirm which is current.
Batches = DataLoader(dataset = train_set, batch_size = batch_size, shuffle = True) #var=0.0572

In [None]:
# Number of images in the dataset behind `Batches`.
len(Batches.dataset)

In [None]:
# Mean intensity of the dataset behind `Batches`, averaged over images.
# With the Terrains dataset every image is min-max normalised, so this lies in [0, 1].
n_images = len(Batches.dataset)  # was hard-coded as 84422
mean = sum(np.mean(Batches.dataset[i]) for i in range(n_images)) / n_images
print(mean)

In [None]:
# Pixel-level variance of the dataset behind `Batches`, accumulated in one pass as
# E[x^2] - E[x]^2. (The previous version printed the *sum* of per-image variances,
# which is neither a variance nor comparable across dataset sizes.)
dataset = Batches.dataset
pixel_sum = 0.0
pixel_sq_sum = 0.0
pixel_count = 0
for i in range(len(dataset)):
    img = np.asarray(dataset[i], dtype=np.float64)  # float64 accumulators limit round-off
    pixel_sum += img.sum()
    pixel_sq_sum += np.square(img).sum()
    pixel_count += img.size
var = pixel_sq_sum / pixel_count - (pixel_sum / pixel_count) ** 2
print(var)

In [None]:
def mean_dat(data_loader):
    """Per-channel mean over every image yielded by ``data_loader``.

    Each batch may be either an image tensor of shape (B, C, H, W), as a loader
    over ``Terrains`` yields, or an ``(images, labels)`` pair/list, as labelled
    torchvision datasets yield.

    Args:
        data_loader: any iterable of batches (typically a ``DataLoader``).

    Returns:
        Tensor of shape (C,) with the mean of each channel. Batch means are
        weighted by batch size, so a short final batch is not over-counted.

    Raises:
        ValueError: if ``data_loader`` yields no images.
    """
    total_samples = 0
    channel_sum = None
    for batch in data_loader:
        # Labelled loaders yield (images, labels); Terrains loaders yield images only.
        images = batch[0] if isinstance(batch, (tuple, list)) else batch
        n = images.size(0)
        total_samples += n
        # Mean over batch, height and width, scaled back up by the batch size.
        batch_sum = torch.mean(images, dim=(0, 2, 3)) * n
        channel_sum = batch_sum if channel_sum is None else channel_sum + batch_sum

    if total_samples == 0:
        raise ValueError("data_loader yielded no images")
    return channel_sum / total_samples

In [None]:
# Output directory for per-epoch preview images and model checkpoints (created if missing).
CP_dir = 'CP_VQ_VAE'
os.makedirs(CP_dir, exist_ok=True)

In [None]:

class GroupNorm(nn.Module):
    """Thin wrapper around ``nn.GroupNorm`` with 32 groups, eps=1e-6 and learnable affine.

    The wrapped layer lives under the attribute ``gn``, so state-dict keys are
    ``gn.weight`` / ``gn.bias``. ``in_channels`` must be divisible by 32.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.gn = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)

    def forward(self, x):
        normalized = self.gn(x)
        return normalized

class NonLocalBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Every spatial position attends to every position of the same feature map.
    Input and output shapes are identical: (B, in_channels, H, W).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = GroupNorm(in_channels)
        # 1x1 convolutions act as per-position linear projections.
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)
        batch, chans, height, width = query.shape

        # (B, C, H*W): one column per spatial position.
        query, key, value = (t.reshape(batch, chans, height * width) for t in (query, key, value))

        # scores[b, i, j] = <query_i, key_j> / sqrt(C); softmax over attended positions j.
        scores = torch.einsum('bci,bcj->bij', query, key) * (int(chans) ** (-0.5))
        weights = F.softmax(scores, dim=2)

        # attended[b, c, i] = sum_j weights[b, i, j] * value[b, c, j]
        attended = torch.einsum('bij,bcj->bci', weights, value)
        attended = attended.reshape(batch, chans, height, width)

        return x + self.proj_out(attended)
class UpSampleBlock(nn.Module):
    """Doubles the spatial size (nearest-neighbour) and then applies a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
class DownSampleBlock(nn.Module):
    """Halves the spatial size with a stride-2, unpadded 3x3 conv.

    The input is first zero-padded by one pixel on the right and bottom only,
    so an even H x W input maps to exactly H/2 x W/2.
    """

    # (left, right, top, bottom) padding for F.pad over the last two dims.
    _PAD = (0, 1, 0, 1)

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        padded = F.pad(x, self._PAD, mode="constant", value=0)
        return self.conv(padded)

In [None]:
class Codebook(nn.Module):
    """Vector-quantisation codebook for a VQ-VAE.

    Each latent vector of the input is replaced by its nearest codebook entry under
    squared Euclidean distance, expanded as
    ||z_e - e||^2 = ||z_e||^2 - 2 z_e.e + ||e||^2 so all distances come from one matmul.

    Args:
        num_codebook_vectors: number of codebook entries K (default 256).
        latent_dim: size of each entry; must equal the channel count of the tensor
            passed to ``forward`` (default 64).
        beta: weight of the commitment term (default 0.25).
        legacy_loss: if True, reproduce this notebook's earlier loss, where ``beta``
            weighted the codebook term instead of the commitment term.
    """

    def __init__(self, num_codebook_vectors=256, latent_dim=64, beta=0.25, legacy_loss=False):
        super().__init__()
        self.num_codebook_vectors = num_codebook_vectors
        self.latent_dim = latent_dim
        self.beta = beta
        self.legacy_loss = legacy_loss

        self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors)

    def forward(self, z):
        """Quantise ``z``.

        Args:
            z: tensor of shape (B, latent_dim, H, W).

        Returns:
            z_q: quantised tensor of shape (B, latent_dim, H, W); gradients flow
                straight through to ``z``.
            min_encoding_indices: LongTensor of shape (B*H*W, 1) holding the chosen
                codebook index for each position, in (B, H, W) row-major order.
            loss: scalar VQ loss = codebook term + beta * commitment term.
        """
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.latent_dim)

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # Codebook term: pulls the embeddings towards the (frozen) encoder output.
        codebook_loss = torch.mean((z_q - z.detach()) ** 2)
        # Commitment term: keeps the encoder output close to its (frozen) code.
        commitment_loss = torch.mean((z_q.detach() - z) ** 2)
        if self.legacy_loss:
            loss = commitment_loss + self.beta * codebook_loss
        else:
            loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: the forward value is z_q, the backward pass copies
        # the gradient of z_q onto z (the argmin itself is not differentiable).
        z_q = z + (z_q - z).detach()

        z_q = z_q.permute(0, 3, 1, 2)

        return z_q, min_encoding_indices, loss

In [None]:
class VQVAE(nn.Module):
    """Convolutional VQ-VAE for single-channel images.

    Pipeline: encoder -> quant_conv (1x1) -> Codebook (vector quantisation)
    -> post_quant_conv (1x1) -> decoder.

    The encoder halves the spatial size four times (DownSampleBlock) and ends with
    64 channels, so a (B, 1, 256, 256) input becomes a (B, 64, 16, 16) latent; the
    decoder upsamples four times back to (B, 1, 256, 256). Inputs whose height and
    width are multiples of 16 come back at their original size.

    NOTE(review): the decoder has no activation between its convolutions (only
    Conv2d, UpSampleBlock and one NonLocalBlock), so apart from the attention
    softmax it is a linear map. Adding activations would shift the nn.Sequential
    indices and break loading of existing checkpoints, so it is left as is.
    """
    def __init__(self):
        super(VQVAE,self).__init__()
        # Encoder feature widths; the decoder walks the same list in reverse.
        channels=[64,128,128,256,256,512]
        hid_dim=256  # NOTE(review): unused
        # Tracked spatial resolution, assuming a 256x256 input; only used to decide
        # where self-attention is inserted.
        res=256
        
        layers_en=[nn.Conv2d(1,channels[0],3,1,1)]
        for i in range(len(channels)-1):
            ich=channels[i]
            och=channels[i+1]
            layers_en.append(nn.Conv2d(ich,och,3,1,1))
            layers_en.append(nn.ReLU())
            ich=och
            # res reaches 16 only on the last iteration (after four downsamples),
            # so attention follows the final 256->512 conv.
            if (res==16):
                layers_en.append(NonLocalBlock(ich))
            # Downsample after every stage except the last.
            if i != len(channels) - 2:
                layers_en.append(DownSampleBlock(channels[i+1]))
                res //= 2
        layers_en.append(nn.Sigmoid())
        # Project to the 64-dim latent space; 64 must match Codebook.latent_dim.
        layers_en.append(nn.Conv2d(channels[-1], 64, 3, 1, 1))
        block_in=channels[-1]
        layers_de=[nn.Conv2d(64, block_in, kernel_size=3, stride=1, padding=1),
                    NonLocalBlock(block_in)]

        # Mirror of the encoder: conv to the next narrower width, then upsample
        # (except after the last stage).
        for i in reversed(range(len(channels)-1)):
            block_out=channels[i]
            layers_de.append(nn.Conv2d(block_in,block_out,3,1,1))
            block_in=block_out
            if i!=0:
                layers_de.append(UpSampleBlock(block_in))
                res*=2  # bookkeeping only; not read after this loop
        layers_de.append(nn.Conv2d(block_in, 1, kernel_size=3, stride=1, padding=1))  
            
        self.encoder=nn.Sequential(*layers_en)  # latent: (B, 64, H/16, W/16), e.g. 64x16x16 for 256x256
        self.decoder=nn.Sequential(*layers_de)
        self.codebook=Codebook()
        self.quant_conv = nn.Conv2d(64, 64, 1)
        self.post_quant_conv = nn.Conv2d(64, 64, 1)
    def forward(self, imgs):
        """Encode, quantise and decode ``imgs``.

        Args:
            imgs: tensor of shape (B, 1, H, W).

        Returns:
            (decoded_images, codebook_indices, q_loss): the reconstruction (same
            shape as ``imgs`` when H and W are multiples of 16), the chosen codebook
            index per latent position with shape (B * H/16 * W/16, 1), and the scalar
            VQ loss from the codebook.
        """
        encoded_images = self.encoder(imgs)
        quantized_encoded_images = self.quant_conv(encoded_images)
        codebook_mapping, codebook_indices, q_loss = self.codebook(quantized_encoded_images)
        quantized_codebook_mapping = self.post_quant_conv(codebook_mapping)
        decoded_images = self.decoder(quantized_codebook_mapping)
        return decoded_images, codebook_indices, q_loss

        

In [None]:
# Build the model on the selected device and print a layer / parameter summary.
model=VQVAE().to(device)
summary(model)

In [None]:
# Sanity check: push a random batch through the (untrained) model and confirm the
# reconstruction has the same shape as the input.
sample_input = torch.randn(batch_size, 1, 256, 256, device=device)
with torch.no_grad():  # shape check only; no autograd graph needed
    y = model(sample_input)
assert y[0].shape == sample_input.shape, y[0].shape

# First reconstruction as a (256, 256) array, clipped to the [0, 1] display range.
generate_image_np = y[0][0].cpu().numpy().squeeze().clip(0.0, 1.0)
plt.imshow(generate_image_np, cmap='gray')
print(y[0].shape)

In [None]:
# Optimisation hyper-parameters.
lr = 1e-3
epochs = 70
beta1, beta2 = 0.0, 0.999  # beta1 = 0: Adam's first-moment estimate is just the current gradient
criterion = nn.MSELoss()

# One Adam optimiser over every model parameter (encoder, decoder, codebook and
# both 1x1 quant convs).
opt_vq = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    eps=1e-08,
    betas=(beta1, beta2),
)


In [None]:
# Training loop: minimise reconstruction MSE + VQ loss; after every epoch save a
# reconstruction preview and a checkpoint to CP_dir.
for epoch in range(epochs):
    # The preview step below switches the model to eval mode; switch back each epoch.
    model.train()
    for i, data in enumerate(Batches):
        inputs = data.to(device)

        opt_vq.zero_grad()
        reconstructions, _, q_loss = model(inputs)
        reconstruction_loss = criterion(reconstructions, inputs)
        loss = reconstruction_loss + q_loss
        loss.backward()
        opt_vq.step()

        if (i + 1) % 50 == 0:  # logging interval, in batches
            print(f'Epoch [{epoch + 1}/{epochs}], Batch [{i + 1}/{len(Batches)}], Loss: {loss.item()}')

    # Preview: reconstruct the first image of the last training batch. This VQ-VAE
    # has no learned prior, so decoding random noise would only show noise.
    model.eval()
    with torch.no_grad():
        preview, _, _ = model(inputs[:1])
    preview_np = preview[0, 0].cpu().numpy().clip(0.0, 1.0)

    image_path = os.path.join(CP_dir, f'recon_img_epoch_{epoch + 1}.png')
    plt.imsave(image_path, preview_np, cmap='gray')
    fig, ax = plt.subplots()
    ax.imshow(preview_np, cmap='gray')
    ax.set_title(f'Reconstruction - Epoch {epoch + 1}')
    ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch

    # Store the loss as a plain Python float so the checkpoint carries no
    # device-bound tensor.
    checkpoint_vae = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt_vq.state_dict(),
        'loss': loss.item(),
    }
    torch.save(checkpoint_vae, os.path.join(CP_dir, f'vae_checkpoint_epoch_{epoch + 1}.pth'))


In [None]:
from torchvision import datasets

# These look like the commonly quoted CIFAR-10 per-channel RGB statistics; they are only
# referenced by the disabled Normalize step below.
mean = torch.tensor([0.491, 0.482, 0.447], device=device)
std = torch.tensor([0.247, 0.243, 0.262], device=device)

# ToTensor -> float CHW in [0, 1], then collapse RGB to one channel to match the
# model's single-channel input.
data_aug = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(mean=mean, std=std),
    transforms.Grayscale(),
])
train_set = datasets.CIFAR10(root='.', train=True, download=True, transform=data_aug)

# NOTE: rebinds `Batches` (previously the terrain loader); cells run after this
# one receive CIFAR-10 (images, labels) batches.
Batches = DataLoader(train_set, batch_size=128, shuffle=True)

In [None]:
# Number of batches in the current `Batches` loader (the CIFAR-10 loader when cells run in order).
print(len(Batches))

In [None]:
# Recorded data variance from an earlier run.
# NOTE(review): not referenced by later cells in this notebook; confirm which dataset it was computed on.
variance=0.06328692405746414

In [None]:

# Apply the terrain-trained model to one batch from `Batches` (grayscale CIFAR-10
# when cells run in order) and show 25 inputs next to their reconstructions.
def show_grid(images, title, rows=5, cols=5):
    """Plot the first rows*cols images of an (N, C, H, W) tensor on a grid of axes."""
    fig, axes = plt.subplots(rows, cols, figsize=(14, 14))
    for ax, img in zip(axes.flat, images):
        # (C, H, W) -> (H, W, C); drop a singleton channel so imshow applies the colormap.
        ax.imshow(img.detach().cpu().permute(1, 2, 0).squeeze(-1).numpy(), cmap='gray')
        ax.axis('off')
    fig.suptitle(title)


data, _ = next(iter(Batches))
inputs = data.to(device)
print(data.size())
show_grid(inputs, 'Inputs')

model.eval()
with torch.no_grad():  # inference only; no optimiser step happens here
    reconstructions, _, q_loss = model(inputs)
print(reconstructions[0].shape)
show_grid(reconstructions, 'Reconstructions')
