In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F
from pathlib import Path


In [None]:
class ImageNetDataset(Dataset):
    """Dataset that loads image files from disk and returns them transformed.

    Args:
        paths: Indexable, sized collection of image file paths.
        transform: Optional callable applied to each decoded RGB ``PIL.Image``.
            When omitted, images are resized to 256x256 and converted with
            ``transforms.ToTensor()``, which is what this class always did
            before the argument existed.
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        if transform is None:
            # Build the pipeline once here instead of on every __getitem__ call.
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()
            ])
        self.transform = transform

    def __getitem__(self, idx):
        """Return the transformed image stored at ``self.paths[idx]``."""
        path = self.paths[idx]
        # The context manager releases the file handle right away; convert()
        # loads the pixel data into a new image that outlives the closed file.
        with Image.open(path) as handle:
            image = handle.convert('RGB')
        return self.transform(image)

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)


# Collect every .jpg found beneath the training folder (searched recursively).
train_folder_path = Path('C:\\Users\\user\\Training Models\\')
train_files = [*train_folder_path.rglob('*.jpg')]

# NOTE: despite the "test_" prefix these wrap the training images; the names are
# kept because the training loop iterates over ``test_data_loader``.
test_dataset = ImageNetDataset(paths=train_files)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=True)

In [None]:
class ResidualLayer(nn.Module):
    """Residual unit: 3x3 conv -> ReLU -> 1x1 conv -> ReLU, added onto the input.

    Both convolutions keep the channel count and use stride 1 (the 3x3 one with
    padding 1), so the branch output has the same shape as the input.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        spatial_conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        pointwise_conv = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        # Attribute name and layer order are fixed so that state_dict keys stay
        # "conv.0.*" and "conv.2.*".
        self.conv = nn.Sequential(spatial_conv, nn.ReLU(), pointwise_conv, nn.ReLU())

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.conv(x)
        return x + branch
    
    
class ResidualBlock(nn.Module):
    """Stack of ``num_blocks`` ResidualLayer modules applied one after another.

    Args:
        channels: Channel count passed to every ResidualLayer.
        num_blocks: Number of ResidualLayer modules in the stack.
    """

    def __init__(self, channels, num_blocks):
        super().__init__()
        layers = []
        for _ in range(num_blocks):
            layers.append(ResidualLayer(channels))
        self.res_blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through every residual layer in order."""
        out = self.res_blocks(x)
        return out
    

class Encoder(nn.Module):
    """Strided-convolution encoder followed by a residual stack.

    Every downsampling stage is a 4x4 convolution with stride 2 and padding 1
    followed by ReLU, which halves the height and width (rounding down).
    ``stride_factor`` selects how many stages are used:

    * ``2``: two stages, ``in_channel -> out_channel // 2 -> out_channel``.
    * ``4``: the two stages above plus one ``out_channel -> out_channel`` stage.

    Args:
        in_channel: Channel count of the input tensor.
        out_channel: Channel count of the encoder output.
        stride_factor: Layout selector, either 2 or 4 (see above).
        num_blocks: Number of residual layers appended after the convolutions.

    Raises:
        ValueError: If ``stride_factor`` is neither 2 nor 4. Such a value used
            to build an encoder without any convolution, which only failed
            later with a channel mismatch inside ``forward``.
    """

    def __init__(self, in_channel, out_channel, stride_factor, num_blocks):
        super().__init__()

        if stride_factor not in (2, 4):
            raise ValueError(
                f"stride_factor must be 2 or 4, got {stride_factor!r}"
            )

        # Stages shared by both layouts; the order is kept so that the
        # positions inside ``self.blocks`` (and the state_dict keys) do not move.
        blocks = [
            nn.Conv2d(in_channel, out_channel // 2, (4,4), 2, 1),
            nn.ReLU(),
            nn.Conv2d(out_channel // 2, out_channel, (4,4), 2, 1),
            nn.ReLU(),
        ]

        if stride_factor == 4:
            blocks += [
                nn.Conv2d(out_channel, out_channel, (4,4), 2, 1),
                nn.ReLU(),
            ]

        blocks.append(ResidualBlock(out_channel, num_blocks))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the downsampling convolutions and then the residual stack."""
        return self.blocks(x)
    
class Decoder(nn.Module):
    """Residual stack followed by transposed-convolution upsampling.

    The argument names mirror ``Encoder`` and therefore read "backwards" here:
    ``out_channel`` is the channel count of the tensor fed INTO the decoder,
    and ``in_channel`` is the channel count it produces when ``stride_factor``
    is 2.

    Every upsampling stage is a 2x2 transposed convolution with stride 2 and no
    padding, which doubles the height and width. ``stride_factor`` selects the
    layout that follows the residual stack:

    * ``2``: two stages, ``out_channel -> out_channel // 2 -> in_channel``.
      The final stage has no activation so the caller picks the output
      non-linearity. A trailing ReLU used to sit here; followed by the sigmoid
      applied in ``VQVAE2.forward`` it confined outputs to [0.5, 1), so pixel
      values below 0.5 could never be reconstructed.
    * ``4``: one ``out_channel -> out_channel`` stage followed by ReLU;
      ``in_channel`` is not used in this layout.

    Args:
        in_channel: Channel count produced by the ``stride_factor == 2`` layout.
        out_channel: Channel count of the decoder input.
        stride_factor: Layout selector, either 2 or 4 (see above).
        num_blocks: Number of residual layers placed before the upsampling.

    Raises:
        ValueError: If ``stride_factor`` is neither 2 nor 4. Such a value used
            to build a decoder consisting of the residual stack only.
    """

    def __init__(self, in_channel, out_channel, stride_factor, num_blocks):
        super().__init__()

        if stride_factor not in (2, 4):
            raise ValueError(
                f"stride_factor must be 2 or 4, got {stride_factor!r}"
            )

        # The residual stack stays at position 0 so state_dict keys are unchanged.
        blocks = [ResidualBlock(out_channel, num_blocks)]

        if stride_factor == 2:
            blocks += [
                nn.ConvTranspose2d(out_channel, out_channel // 2, (2,2), 2, 0),
                nn.ReLU(),
                # Linear output head: no ReLU after the last layer (it holds no
                # parameters, so dropping it does not affect saved weights).
                nn.ConvTranspose2d(out_channel // 2, in_channel, (2,2), 2, 0),
            ]
        else:
            blocks += [
                nn.ConvTranspose2d(out_channel, out_channel, (2,2), 2, 0),
                nn.ReLU(),
            ]

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the residual stack and then the upsampling layers."""
        return self.blocks(x)
    
class Quantize(nn.Module):
    """Vector quantiser: replaces each spatial feature vector by its nearest code.

    Args:
        num_embeddings: Number of vectors in the codebook.
        embedding_dim: Length of each codebook vector; the input to ``forward``
            needs this many channels for the distance computation to work.
        beta: Weight of the commitment term in the returned loss.
    """

    def __init__(self, num_embeddings, embedding_dim, beta):
        super().__init__()
        self.beta = beta
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

    def forward(self, x):
        """Quantise a (B, C, H, W) tensor against the codebook.

        Returns:
            A tuple ``(quant_out, quantize_losses)``. ``quant_out`` has the same
            shape as ``x`` and holds the selected codebook vectors, with
            gradients passed straight through to ``x``. ``quantize_losses`` is
            the scalar ``codebook_loss + beta * commitment_loss``.
        """
        B, C, H, W = x.shape
        codebook = self.embedding.weight

        # (B, C, H, W) -> (B, H*W, C): one row per spatial position.
        vectors = x.permute(0, 2, 3, 1).reshape(B, H * W, C)

        # Euclidean distance from every position to every code, per batch item.
        batched_codebook = codebook.unsqueeze(0).repeat(B, 1, 1)
        distances = torch.cdist(vectors, batched_codebook)
        nearest = distances.argmin(dim=-1).view(-1)
        selected = codebook.index_select(0, nearest)

        flat = vectors.reshape(-1, C)

        # Commitment term moves the encoder output towards the (frozen) code;
        # codebook term moves the code towards the (frozen) encoder output.
        commitment_loss = (flat - selected.detach()).pow(2).mean()
        codebook_loss = (flat.detach() - selected).pow(2).mean()
        quantize_losses = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: the value is the selected code, the
        # gradient flows to ``flat`` as if quantisation were the identity.
        straight_through = flat + (selected - flat).detach()

        quant_out = straight_through.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return quant_out, quantize_losses
    

class VQVAE2(nn.Module):
    """Two-level vector-quantised autoencoder.

    ``enc_1`` produces the finer feature map and ``enc_2`` the coarser one.
    The coarse map is quantised by ``quant_2``, decoded by ``dec_2`` and
    concatenated with the fine map before that is quantised by ``quant_1``.
    Both quantised maps are then concatenated and decoded by ``dec_1``.

    Args:
        num_embeddings: Codebook sizes; index 0 is used for ``quant_1`` and
            index 1 for ``quant_2``.
        embedding_dim: Codebook vector lengths, indexed the same way. Both
            entries must be equal because ``pre_quantization_conv`` and
            ``after_quantization_conv`` are sized from ``embedding_dim[0]`` yet
            are also applied around ``quant_2``.
        in_channel: Channel count of the input (and reconstructed) image.
        out_channel: Channel count of the encoder feature maps.
        num_blocks: Number of residual layers in each encoder and decoder.

    Raises:
        ValueError: If the two entries of ``embedding_dim`` differ. This used
            to surface only as a shape error inside ``forward``.
    """

    def __init__(self, num_embeddings = (512, 512), embedding_dim = (64, 64), in_channel = 3, out_channel = 128, num_blocks = 3):
        super().__init__()

        if embedding_dim[0] != embedding_dim[1]:
            raise ValueError(
                "embedding_dim entries must be equal because the pre/post "
                f"quantization convolutions are shared, got {embedding_dim!r}"
            )

        # ``act`` is not used by forward(); kept so the attribute still exists.
        self.act = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.beta = 0.25

        # Sub-modules are created in a fixed order so that parameter
        # initialisation under a given seed and state_dict layout stay stable.
        self.enc_1 = Encoder(in_channel, out_channel, 2, num_blocks)
        self.enc_2 = Encoder(in_channel, out_channel, 4, num_blocks)

        self.pre_quantization_conv = nn.Conv2d(out_channel, embedding_dim[0], kernel_size=1, stride=1, padding=0)
        self.pre_quantization_conv_cat = nn.Conv2d(out_channel * 2, embedding_dim[0], kernel_size=1, stride=1, padding=0)

        self.quant_1 = Quantize(num_embeddings[0], embedding_dim[0], self.beta)
        self.quant_2 = Quantize(num_embeddings[1], embedding_dim[1], self.beta)

        # Shared by both levels: applied to the output of quant_2 and of quant_1.
        self.after_quantization_conv = nn.Conv2d(embedding_dim[0], out_channel, kernel_size=1, stride=1, padding=0)

        self.dec_1 = Decoder(in_channel, out_channel * 2, 2, num_blocks)
        self.dec_2 = Decoder(out_channel, out_channel, 4, num_blocks)

        self.upsample_cat = nn.ConvTranspose2d(
            out_channel, out_channel, 4, stride=2, padding=1
        )

    def forward(self, x):
        """Encode, quantise on two levels and reconstruct ``x``.

        Args:
            x: Input batch laid out as (B, C, H, W).

        Returns:
            A tuple ``(output, quant_loss_2, quant_loss_1)``: the reconstruction
            after a sigmoid, the loss from ``quant_2`` and the loss from
            ``quant_1`` (in that order).
        """
        fine = self.enc_1(x)
        coarse = self.enc_2(x)

        # Coarse level: project, quantise, project back, upsample via dec_2.
        coarse = self.pre_quantization_conv(coarse)
        quant_2, quant_loss_2 = self.quant_2(coarse)
        quant_2 = self.after_quantization_conv(quant_2)
        coarse_decoded = self.dec_2(quant_2)

        # Fine level: condition on the decoded coarse map before quantising.
        fine = self.pre_quantization_conv_cat(torch.cat([fine, coarse_decoded], dim=1))
        quant_1, quant_loss_1 = self.quant_1(fine)
        quant_1 = self.after_quantization_conv(quant_1)

        # Merge both quantised maps (coarse one upsampled) and decode.
        merged = torch.cat([quant_1, self.upsample_cat(quant_2)], dim=1)
        output = self.sigmoid(self.dec_1(merged))

        return output, quant_loss_2, quant_loss_1

In [None]:
# Load the saved parameters and normalise their names.
# NOTE(review): torch.load unpickles arbitrary objects, and checkpoints written
# by the training loop in this notebook also pickle the whole model object, so
# only load files from a trusted source.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.
temp_dict = torch.load('C:\\Users\\user\\Training Models\\model_params.pt', map_location='cpu')["state_dict"]

# nn.DataParallel stores parameters as "module.<name>". Strip that prefix only
# where it is actually present; slicing 7 characters off every key
# unconditionally corrupts the names of a checkpoint saved without the wrapper.
data_parallel_prefix = 'module.'
new_dict = {
    (name[len(data_parallel_prefix):] if name.startswith(data_parallel_prefix) else name): tensor
    for name, tensor in temp_dict.items()
}

# To go the other way (loading into a DataParallel-wrapped model), prepend the
# prefix instead:
#   new_dict['module.' + a] = b

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and move its parameters onto the chosen device.
model = VQVAE2(
    num_embeddings=(512, 512),
    embedding_dim=(64, 64),
    in_channel=3,
    out_channel=128,
    num_blocks=3,
).to(device)

# Disabled: multi-GPU wrapping and restoring weights from a saved checkpoint.
# if torch.cuda.device_count() > 1:
#     model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

# load_dict = torch.load('/kaggle/working/model_params.pt')["state_dict"]
# new_load_dict = {}
# for a, b in load_dict.items():
#     if a[:6] == 'module':
#         new_load_dict[a[6:]] = b
#     else:
#         new_load_dict = load_dict
#         break

# model.load_state_dict(new_load_dict)

# Reconstruction error is summed (not averaged) over every element of the batch.
loss_fn = nn.MSELoss(reduction="sum")
optim = torch.optim.Adam(params=model.parameters(), lr=2e-4)

In [None]:
def _show_reconstructions(originals, reconstructions, max_images=5):
    """Plot up to ``max_images`` inputs (top row) above their reconstructions.

    The grid always has ``max_images`` columns; when the batch holds fewer
    images the remaining cells are left blank instead of raising IndexError.
    Both tensors are expected in (B, C, H, W) layout.
    """
    count = min(max_images, originals.shape[0], reconstructions.shape[0])
    fig, axes = plt.subplots(2, max_images, figsize=(20, 8), squeeze=False)
    for row, batch in enumerate((originals, reconstructions)):
        for col in range(max_images):
            ax = axes[row, col]
            ax.axis('off')
            if col < count:
                # (C, H, W) -> (H, W, C), the layout imshow expects.
                ax.imshow(batch[col].cpu().detach().permute(1, 2, 0).numpy())
    plt.show()
    # Release the figure so repeated previews do not accumulate in memory.
    plt.close(fig)


num_epochs = 100
losses = []
rec_losses = []

# Batch intervals for logging, previewing and checkpointing.
log_every = 100
preview_every = 500
save_every = 1000
# Both printed losses use the same divisor. The total loss used to be divided
# by 1e10, which made it print as 0.0 next to a reconstruction loss divided by 1e3.
log_scale = 1_000

for epoch in tqdm(range(1, num_epochs + 1)):
    # Nothing inside the loop switches to eval mode, so once per epoch suffices.
    model.train()

    for index, data in enumerate(test_data_loader):

        img = data.to(device)

        res, loss_enc2, loss_enc1 = model(img)
        loss_dec = loss_fn(res, img)
        # .sum() leaves scalar losses untouched and still reduces correctly if
        # the model returns one loss per replica (e.g. under nn.DataParallel).
        loss_item = (loss_enc2 + loss_enc1 + loss_dec).sum()

        optim.zero_grad()
        loss_item.backward()
        optim.step()

        if index > 0 and index % log_every == 0:
            rec_losses.append(loss_dec.sum().item())
            losses.append(loss_item.item())
            print(epoch, "rec loss:", round(rec_losses[-1] / log_scale, 4), "loss:", round(losses[-1] / log_scale, 4))

        if index % preview_every == 0:
            _show_reconstructions(img, res)

        if index > 0 and index % save_every == 0:
            # NOTE(review): 'model' pickles the whole module, which ties the
            # file to this class definition; 'state_dict' alone is enough to
            # restore the weights. Both keys are kept for existing consumers.
            state = {
                'model': model,
                'state_dict': model.state_dict()
            }
            torch.save(state, 'C:\\Users\\user\\Training Models\\new_model_vqvae2_params.pt')
            print('model saved')
    
    