In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F
from pathlib import Path


In [None]:
class ImageNetDataset(Dataset):
    """Unlabelled image dataset backed by a list of file paths.

    Item ``idx`` is the image at ``paths[idx]``, converted to RGB and passed
    through ``transform``. The default transform resizes to ``image_size`` and
    applies ``transforms.ToTensor()``. Only the image tensor is returned (no
    label).

    Args:
        paths: sequence of image file paths (str or Path).
        image_size: (height, width) used by the default resize transform.
        transform: optional callable applied to the RGB PIL image in place of
            the default resize + ToTensor pipeline.
    """

    def __init__(self, paths, image_size=(128, 128), transform=None):
        self.paths = paths
        # Build the pipeline once here instead of on every __getitem__ call.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize(image_size),
                transforms.ToTensor(),
            ])
        self.transform = transform

    def __getitem__(self, idx):
        # The context manager guarantees the file handle is closed even if
        # decoding fails; convert() returns an independent, fully loaded image.
        with Image.open(self.paths[idx]) as image:
            rgb = image.convert('RGB')
        return self.transform(rgb)

    def __len__(self):
        return len(self.paths)


# Collect every .jpg under the training folder (recursively) and wrap the files
# in a shuffling DataLoader. NOTE: despite the `test_` prefix, this loader is
# the one the training loop iterates over; the names are kept because that
# loop references `test_data_loader`.
train_folder_path = Path('C:\\Users\\user\\Training Models\\')
# sorted() makes the dataset's index -> file mapping independent of filesystem order.
train_files = sorted(train_folder_path.rglob('*.jpg'))
if not train_files:
    # Without this check DataLoader(shuffle=True) fails later with an opaque
    # "num_samples should be a positive integer value" error.
    raise FileNotFoundError(f"No .jpg files found under {train_folder_path}")

test_dataset = ImageNetDataset(train_files)
test_data_loader = DataLoader(test_dataset, batch_size=128, shuffle=True)

In [None]:
class VQVAE(nn.Module):
    """Vector-quantised autoencoder for 3-channel images.

    Encoder: two stride-2 convolutions (4x spatial downsampling, e.g.
    128x128 -> 32x32), two residual blocks, then a 1x1 projection to
    ``embedding_dim`` channels. Each spatial latent vector is replaced by its
    nearest codebook entry, with a straight-through gradient. Decoder: 1x1
    projection back to 128 channels, two residual blocks, then two stride-2
    transposed convolutions and a sigmoid. For inputs whose height and width
    are divisible by 4, the output has the same shape as the input.

    Args:
        num_embeddings: number of codebook vectors.
        embedding_dim: size of each codebook vector and channel count of the
            quantised latent.
        beta: weight of the commitment loss in the returned VQ loss.

    The attribute names below are the ``state_dict`` keys of saved
    checkpoints, so they must not be renamed.
    """

    def __init__(self, num_embeddings = 512, embedding_dim = 64, beta = 0.25):
        super().__init__()
        self.act = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        # Codebook: num_embeddings rows of size embedding_dim.
        self.embedding = nn.Embedding(num_embeddings = self.num_embeddings, embedding_dim = self.embedding_dim)

        self.enc_conv1 = nn.Conv2d(in_channels = 3, out_channels = 32, kernel_size = (4,4), stride = 2, padding = 1) # H/2
        self.enc_conv2 = nn.Conv2d(in_channels = 32, out_channels = 128, kernel_size = (4,4), stride = 2, padding = 1) # H/4

        self.enc_conv_res11 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = (3,3), stride = 1, padding = 1)
        self.enc_conv_res12 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = (1,1), stride = 1, padding = 0)

        self.enc_conv_res21 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = (3,3), stride = 1, padding = 1)
        self.enc_conv_res22 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = (1,1), stride = 1, padding = 0)

        # The latent channel count must equal the codebook vector size,
        # otherwise the distance computation cannot compare them.
        self.pre_quantization_conv = nn.Conv2d(128, self.embedding_dim, kernel_size=1, stride=1, padding = 0)

        self.after_quantization_conv = nn.Conv2d(self.embedding_dim, 128, kernel_size=1, stride=1, padding = 0)

        self.dec_conv1 = torch.nn.ConvTranspose2d(in_channels = 32, out_channels = 3, kernel_size = (2,2), stride=2, padding=0)
        self.dec_conv2 = torch.nn.ConvTranspose2d(in_channels = 128, out_channels = 32, kernel_size = (2,2), stride=2, padding=0)

        self.dec_conv_res11 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = (3,3), stride = 1, padding = 1)
        self.dec_conv_res12 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = (1,1), stride = 1, padding = 0)

        self.dec_conv_res21 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = (3,3), stride = 1, padding = 1)
        self.dec_conv_res22 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = (1,1), stride = 1, padding = 0)

        # Commitment loss weight.
        self.beta = beta

    def forward(self, x):
        """Encode, quantise and decode a batch of images.

        Args:
            x: image batch of shape (B, 3, H, W).

        Returns:
            (output, quantize_losses): the sigmoid-activated reconstruction and
            the scalar ``codebook_loss + beta * commitment_loss``.
        """
        latents = self._encode(x)
        quantized, quantize_losses = self._quantize(latents)
        output = self._decode(quantized)
        return output, quantize_losses

    def _encode(self, x):
        """Map (B, 3, H, W) images to (B, embedding_dim, H/4, W/4) latents."""
        x = self.act(self.enc_conv1(x))
        x = self.act(self.enc_conv2(x))
        # Residual blocks. No operation here is in-place, so the skip input
        # can be reused directly without cloning.
        x = x + self.act(self.enc_conv_res12(self.act(self.enc_conv_res11(x))))
        x = x + self.act(self.enc_conv_res22(self.act(self.enc_conv_res21(x))))
        return self.pre_quantization_conv(x)

    def _quantize(self, latents):
        """Replace every latent vector with its nearest codebook entry.

        Returns the quantised latents (same shape as ``latents``) carrying a
        straight-through gradient, plus the scalar VQ loss.
        """
        B, C, H, W = latents.shape
        # (B, C, H, W) -> (B*H*W, C): one row per spatial position.
        flat = latents.permute(0, 2, 3, 1).reshape(-1, C)

        # Distances are only used for argmin (not differentiable), so skip
        # autograd bookkeeping. The 2-D cdist compares against the single
        # codebook directly, without making B copies of it.
        with torch.no_grad():
            dist = torch.cdist(flat, self.embedding.weight)
            min_encoding_indices = torch.argmin(dist, dim=-1)

        quant = torch.index_select(self.embedding.weight, 0, min_encoding_indices)

        # Commitment loss trains the encoder toward the (frozen) codes.
        # Codebook loss trains the codes toward the (frozen) encoder outputs.
        commitment_loss = torch.mean((flat - quant.detach()) ** 2)
        codebook_loss = torch.mean((flat.detach() - quant) ** 2)
        quantize_losses = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: the forward value is the codebook vector,
        # but the gradient flows through `flat` as if quant and flat were the
        # same tensor, even though their values differ.
        quant = flat + (quant - flat).detach()

        quant = quant.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return quant, quantize_losses

    def _decode(self, quantized):
        """Map quantised latents back to images passed through a sigmoid."""
        x = self.after_quantization_conv(quantized)
        # The decoder residual blocks apply the 1x1 conv before the 3x3 conv
        # (the reverse of the encoder). Trained checkpoints depend on this
        # order, so keep it.
        x = x + self.act(self.dec_conv_res21(self.act(self.dec_conv_res22(x))))
        x = x + self.act(self.dec_conv_res11(self.act(self.dec_conv_res12(x))))
        x = self.act(self.dec_conv2(x))
        return self.sigmoid(self.dec_conv1(x))
        

In [None]:
# Convert a checkpoint saved from an nn.DataParallel-wrapped model (whose
# state_dict keys start with "module.") into keys a bare VQVAE accepts.
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# If the file also pickles a full nn.Module, PyTorch >= 2.6 (weights_only=True
# by default) refuses it unless weights_only=False is passed.
DATAPARALLEL_PREFIX = "module."
temp_dict = torch.load('C:\\Users\\user\\Training Models\\model_params.pt', map_location="cpu")["state_dict"]

# Strip the prefix only where it is present, so keys without it stay intact
# instead of losing their first 7 characters.
new_dict = {
    (key[len(DATAPARALLEL_PREFIX):] if key.startswith(DATAPARALLEL_PREFIX) else key): value
    for key, value in temp_dict.items()
}

In [None]:
# Build the model, restore the weights last written by the training loop, and
# set up the reconstruction loss and optimiser.
model = VQVAE(512, 64)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# map_location keeps a GPU-saved checkpoint loadable on a CPU-only machine.
# To start from the DataParallel checkpoint converted into `new_dict` instead,
# call model.load_state_dict(new_dict).
checkpoint = torch.load('C:\\Users\\user\\Training Models\\new_model_params.pt', map_location=device)
model.load_state_dict(checkpoint["state_dict"])
# NOTE(review): reduction="sum" makes the reconstruction loss grow with batch
# and image size, while the model's VQ losses are means (torch.mean). The
# commitment/codebook terms therefore carry very little weight in the total.
# Confirm this is intended.
loss_fn = nn.MSELoss(reduction="sum")
optim = torch.optim.Adam(model.parameters(), lr = 2e-4)

In [None]:
num_epochs = 70
LOG_EVERY = 100                  # optimiser steps between loss printouts
PLOT_EVERY = 500                 # batches within an epoch between reconstruction plots
SAVE_EVERY = 1000                # optimiser steps between checkpoints
LOSS_PRINT_SCALE = 100_000_000   # the total loss is printed divided by this
CHECKPOINT_PATH = 'C:\\Users\\user\\Training Models\\new_model_params.pt'


def _show_reconstructions(originals, reconstructions, n_show=5):
    """Plot up to n_show inputs (top row) above their reconstructions (bottom row)."""
    # Guard against batches with fewer than n_show images (e.g. a short last batch).
    n_show = min(n_show, originals.size(0))
    fig, axes = plt.subplots(2, n_show, figsize=(20, 8), squeeze=False)
    for col in range(n_show):
        for row, batch in enumerate((originals, reconstructions)):
            axes[row, col].imshow(batch[col].detach().cpu().permute(1, 2, 0).numpy())
            axes[row, col].axis('off')
    plt.show()
    plt.close(fig)  # release the figure; a long run creates many of them


def _save_checkpoint(path):
    """Save model and optimiser state dicts to path.

    No nn.Module object is pickled, so loading the file does not depend on
    unpickling the model class (weights_only=True is the default in PyTorch >= 2.6).
    """
    torch.save({'state_dict': model.state_dict(), 'optimizer': optim.state_dict()}, path)
    print('model saved')


losses = []
rec_losses = []
# Global step counter, not reset per epoch. Logging and checkpointing keyed
# on the per-epoch batch index would never fire when an epoch is shorter
# than the interval.
step = 0
model.train()
for epoch in tqdm(range(1, num_epochs + 1)):
    for index, data in enumerate(test_data_loader):
        img = data.to(device)

        res, vq_loss = model(img)
        rec_loss = loss_fn(res, img)
        total_loss = vq_loss + rec_loss

        optim.zero_grad()
        total_loss.backward()
        optim.step()
        step += 1

        if step % LOG_EVERY == 0:
            rec_losses.append(rec_loss.item())
            losses.append(total_loss.item())
            print(epoch, "rec loss:", round(rec_losses[-1],2), "loss:", round(losses[-1] / LOSS_PRINT_SCALE, 3) )
        if index % PLOT_EVERY == 0:
            _show_reconstructions(img, res)
        if step % SAVE_EVERY == 0:
            _save_checkpoint(CHECKPOINT_PATH)

# Always persist the final weights, whatever the dataset size.
_save_checkpoint(CHECKPOINT_PATH)
    