# Welcome to VQ AutoEncoder clustering tutorial!
ここでは、VQ (Vector Quantized) AutoEncoderによる教師なしクラスタリングを行います。

## ライブラリのインストールとインポート

In [None]:
!pip install pytorch-lightning 
!pip install torchsummaryX 

In [2]:
import torch
import torch.nn as nn 
import pytorch_lightning as pl
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt 
from torchvision.utils import make_grid
from typing import *
from datetime import datetime
from torchsummaryX import summary
from torch.utils import data as dutil
from torchvision import transforms
from pytorch_lightning import loggers as pl_loggers


## アルゴリズムの説明

## モデル

### Quantizing layer

In [5]:
class Quantizing(nn.Module):
    """Nearest-neighbour vector quantization layer.

    Holds a codebook ``weight`` of shape (Q, E) = (num_quantizing,
    quantizing_dim). ``forward`` replaces every E-dimensional input vector
    with the codebook row that has the smallest squared Euclidean distance
    to it, and also returns the index of that row.

    When no ``_weight`` is supplied and ``initialize_by_dataset`` is True,
    the codebook rows are overwritten with randomly chosen input vectors
    during the first ``forward`` call(s), until all Q rows have been filled.
    """

    # Class-level default: an instance counts as initialized unless
    # __init__ explicitly arms the dataset-driven initialization.
    __initialized:bool = True

    def __init__(
        self, num_quantizing:int, quantizing_dim:int, _weight:torch.Tensor = None,
        initialize_by_dataset:bool = True, mean:float = 0.0, std:float = 1.0,
        dtype:torch.dtype = None, device:torch.device = None
        ):
        """
        num_quantizing        : number of codebook rows (Q), must be > 0.
        quantizing_dim        : dimension of each codebook row (E), must be > 0.
        _weight               : optional initial codebook of shape (Q, E).
        initialize_by_dataset : if True and _weight is None, fill the codebook
                                from the first inputs passed to forward().
        mean, std             : parameters of the normal initialization used
                                when _weight is None.
        dtype, device         : dtype / device of the codebook parameter.

        Raises ValueError when _weight is neither None nor a tensor.
        """
        super().__init__()
        assert num_quantizing > 0
        assert quantizing_dim > 0
        self.num_quantizing = num_quantizing
        self.quantizing_dim = quantizing_dim
        self.initialize_by_dataset = initialize_by_dataset
        self.mean,self.std = mean,std

        if _weight is None:
            self.weight = nn.Parameter(
                torch.empty(num_quantizing, quantizing_dim ,dtype=dtype,device=device)
            )
            nn.init.normal_(self.weight, mean=mean, std=std)

            if initialize_by_dataset:
                # Arm the lazy initialization performed in forward().
                self.__initialized = False
                self.__initialized_length= 0

        elif isinstance(_weight, torch.Tensor):
            assert _weight.dim() == 2
            assert _weight.size(0) == num_quantizing
            assert _weight.size(1) == quantizing_dim
            self.weight = nn.Parameter(_weight.to(device=device, dtype=dtype))

        else:
            raise ValueError("Weight type is unknown type! {}".format(type(_weight)))

    def forward(self,x:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x   : shape is (*, E), and weight shape is (Q, E). 
        return -> ( quantized : shape is (*, E), quantized_idx : shape is (*,) )
        """
        input_size = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        h = x.reshape(-1,self.quantizing_dim) # shape is (B,E)

        if not self.__initialized and self.initialize_by_dataset:
            # Fill the not-yet-initialized codebook rows with random samples
            # of the current batch; may take several calls if B < Q.
            getting_len = self.num_quantizing - self.__initialized_length
            # detach: the copy into weight.data must not record autograd history.
            init_weight = h.detach()[torch.randperm(len(h))[:getting_len]]
            
            _until = self.__initialized_length + init_weight.size(0)
            self.weight.data[self.__initialized_length:_until] = init_weight
            self.__initialized_length = _until
            print('replaced weight')

            if _until >= self.num_quantizing:
                self.__initialized = True
                print('initialized')
        
        delta = self.weight.unsqueeze(0) - h.unsqueeze(1) # shape is (B, Q, E)
        dist =torch.sum(delta*delta, dim=-1) # shape is (B, Q)
        q_idx = torch.argmin(dist,dim=-1) # shape is (B,)
        q_data = self.weight[q_idx] # shape is (B, E)

        # Indices take every leading dimension of x (all but the last, E),
        # not only the first one, so inputs with more than 2 dims work.
        return q_data.view(input_size), q_idx.view(input_size[:-1])
    
    @property
    def is_initialized(self):
        """True once the codebook no longer waits for dataset-driven initialization."""
        return self.__initialized



### Encoder and Decoder

In [4]:

class Encoder(nn.Module):
    """Fully connected encoder: flattens an image and maps it to a
    ``h.quantizing_dim``-dimensional vector.

    The final Tanh keeps every output component inside [-1, 1].
    """

    def __init__(self,h):
        """
        h : hyper-parameter object; reads ``channels``, ``width``, ``height``
            and ``quantizing_dim``.
        """
        super().__init__()
        self.input_size = (1,h.channels,h.width,h.height)
        self.output_size = (1,h.quantizing_dim)
        self.h = h

        # Width of the flattened image; 784 for 1x28x28 inputs.
        in_features = h.channels * h.width * h.height

        self.layers = nn.Sequential(
            nn.Flatten(1),
            nn.Linear(in_features,256),nn.ReLU(),
            nn.Linear(256,128),nn.ReLU(),
            nn.Linear(128,h.quantizing_dim),nn.Tanh(),
        )

    def forward(self,x:torch.Tensor):
        """Encode ``x`` (batch first, flattened from dim 1 on) into shape (B, quantizing_dim)."""
        y = self.layers(x)
        return y
    
    def summary(self):
        """Print a layer summary using a random dummy input of ``self.input_size``."""
        dummy = torch.randn(self.input_size)
        summary(self,dummy)

In [None]:
class Decoder(nn.Module):
    """Fully connected decoder: maps a ``h.quantizing_dim``-dimensional
    vector back to an image tensor.

    The final Sigmoid keeps every output value inside [0, 1].
    """

    def __init__(self,h):
        """
        h : hyper-parameter object; reads ``channels``, ``width``, ``height``
            and ``quantizing_dim``.
        """
        super().__init__()
        self.input_size = (1,h.quantizing_dim)
        self.output_size = (1,h.channels,h.width,h.height)
        self.h = h

        # Number of values in one output image; 784 for 1x28x28 images.
        out_features = h.channels * h.width * h.height

        self.layers = nn.Sequential(
            nn.Linear(h.quantizing_dim,128),nn.ReLU(),
            nn.Linear(128,256),nn.ReLU(),
            nn.Linear(256,out_features),nn.Sigmoid(),
        )
    
    def forward(self,x:torch.Tensor):
        """Decode ``x`` of shape (B, quantizing_dim) into (B, channels, height, width)."""
        y = self.layers(x)
        # The image dimensions live on the hyper-parameter object; the module
        # itself has no channels/height/width attributes.
        # NOTE(review): output_size lists width before height while this view
        # uses height before width; identical for square images.
        y = y.view(-1,self.h.channels,self.h.height,self.h.width)
        return y

    def summary(self):
        """Print a layer summary using a random dummy input of ``self.input_size``."""
        dummy = torch.randn(self.input_size)
        summary(self,dummy)

### VQ AutoEncoder

In [6]:

class VQ_AutoEncoder(pl.LightningModule):
    """Auto-encoder with a vector-quantization codebook on the latent space.

    The encoder/decoder pair is trained on the reconstruction loss of the
    un-quantized latent, while the quantizer codebook is trained to match
    the (detached) encoder outputs. A per-epoch histogram of the selected
    codebook indices is kept in ``q_hist`` for logging.
    """

    def __init__(self,h):
        """
        h : hyper-parameter object; reads ``model_name``, ``num_quantizing``,
            ``quantizing_dim``, ``lr`` and ``get()`` here, plus
            ``view_interval`` / ``max_view_imgs`` in ``on_epoch_end``.
        """
        super().__init__()
        self.model_name = h.model_name
        self.h = h
        self.num_quantizing = h.num_quantizing
        self.quantizing_dim = h.quantizing_dim
        self.lr = h.lr
        self.my_hparams_dict = h.get()

        # set criterion
        self.reconstruction_loss = nn.MSELoss()
        self.quantizing_loss = nn.MSELoss()
        
        # set histogram
        # Plain tensor attribute (not a registered buffer), so it is not moved
        # by .to()/.cuda(); updates below go through .cpu().
        self.q_hist = torch.zeros(self.num_quantizing,dtype=torch.int)
        self.q_hist_idx = np.arange(self.num_quantizing)
        # set layers
        self.encoder = Encoder(h)
        self.quantizer = Quantizing(h.num_quantizing,h.quantizing_dim)
        self.decoder = Decoder(h)

        self.input_size = self.encoder.input_size
        self.output_size = self.input_size

    def forward(self,x:torch.Tensor):
        """Encode ``x``, quantize the latent and decode the quantized latent."""
        h = self.encoder(x)
        Qout,Qidx = self.quantizer(h)
        y = self.decoder(Qout)
        return y

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        optim = torch.optim.Adam(self.parameters(),self.lr)
        return optim
    
    @torch.no_grad()
    def set_quantizing_weight(self,data_loader,device='cpu'):
        """Feed encoder outputs of ``data_loader`` batches through the quantizer
        until it reports ``is_initialized``.

        data_loader : iterable yielding ``(data, _)`` pairs.
        device      : device the batches are moved to before encoding.
        """
        for batch in data_loader:
            data,_ = batch
            data = data.to(device)
            Eout = self.encoder(data)
            _ = self.quantizer(Eout)
            if self.quantizer.is_initialized:
                break

        torch.cuda.empty_cache()

    def on_fit_start(self) -> None:
        """Log the hyper-parameter dict captured in ``__init__``."""
        self.logger.log_hyperparams(self.my_hparams_dict)

    def training_step(self,batch,idx):
        """Compute and log the losses of one batch and update the index histogram.

        batch : ``(data, _)`` pair; the second element is ignored.
        idx   : batch index (unused).
        return -> reconstruction loss + quantizing loss.
        """
        data,_  = batch
        # Keep the latest batch for the image logging in on_epoch_end.
        self.view_data = data
        Eout = self.encoder(data)
        # Detached target: the quantizing loss only updates the codebook.
        Qtgt = Eout.detach()
        Qout,Qidx = self.quantizer(Qtgt)
        out = self.decoder(Eout)

        # loss
        r_loss = self.reconstruction_loss(out,data)
        q_loss = self.quantizing_loss(Qout,Qtgt)
        loss = r_loss + q_loss

        # log
        # This value is only logged and never back-propagated, so skip
        # building an autograd graph for the extra decoder pass.
        with torch.no_grad():
            rq_loss = self.reconstruction_loss(self.decoder(Qout),data)
        self.log('loss',loss)
        self.log('reconstruction_loss',r_loss)
        self.log('quantizing_loss',q_loss)
        self.log('reconstructed_quantizing_loss',rq_loss)

        # Accumulate how often each codebook row was selected in this batch.
        used_idx,count = torch.unique(Qidx,return_counts = True)
        self.q_hist[used_idx.cpu()] += count.cpu()
        return loss

    # NOTE(review): `on_epoch_end` is a Lightning hook name that newer
    # Lightning releases appear to have replaced with `on_train_epoch_end`;
    # confirm against the installed version.
    @torch.no_grad()
    def on_epoch_end(self) -> None:
        """Every ``h.view_interval`` epochs log an image grid and the codebook
        usage histogram, then reset the histogram."""
        if (self.current_epoch+1) % self.h.view_interval ==0:
            # image logging
            data = self.view_data[:self.h.max_view_imgs].float()
            data_len = len(data)
            Eout = self.encoder(data)
            Qout,Qidx = self.quantizer(Eout)
            out = self.decoder(Eout)
            Qdecoded = self.decoder(Qout)

            # Rows of the grid: inputs, reconstructions, quantized reconstructions.
            grid_img = make_grid(torch.cat([data,out,Qdecoded],dim=0),nrow=data_len)
            self.logger.experiment.add_image("MNIST Quantizings",grid_img,self.current_epoch)

            # histogram logging
            fig = plt.figure(figsize=(6.4,4.8))
            ax = fig.subplots()
            ax.bar(self.q_hist_idx,self.q_hist)
            
            # Number of codebook rows that were selected at least once.
            quantized_num = len(self.q_hist[self.q_hist!=0])
            q_text = f'{quantized_num}/{self.num_quantizing}'
            ax.text(0.9,1.05,q_text,ha='center',va='center',transform=ax.transAxes,fontsize=12)
            ax.set_xlabel('weight index')
            ax.set_ylabel('num')
            self.logger.experiment.add_figure('Quantizing Histogram',fig,self.current_epoch)
            
        self.q_hist.zero_()

    def summary(self,tensorboard=False):
        """Print a layer summary; optionally also write the graph to TensorBoard.

        tensorboard : when True, write the model graph with a ``SummaryWriter``
                      whose comment is ``self.model_name``.
        """
        from torch.utils.tensorboard import SummaryWriter
        dummy = torch.randn(self.input_size)
        summary(self,dummy)

        if tensorboard:
            writer = SummaryWriter(comment=self.model_name)
            writer.add_graph(self,dummy)
            writer.close()


### hyper parameter の定義

In [7]:
@dataclass
class hparam:
    """Hyper-parameters of the VQ auto-encoder tutorial.

    All attributes are real dataclass fields, so each one can be set through
    the constructor and is included in ``get()``. The first four fields keep
    their original positional order; the remaining ones are appended after
    them.
    """
    model_name:str = "VQ_AutoEncoder"

    lr:float = 0.001

    num_quantizing:int = 32
    quantizing_dim:int = 32

    # Logging controls (annotated so they are dataclass fields, not bare
    # class attributes that get() would miss).
    max_view_imgs:int = 16
    view_interval:int = 10

    # Input image geometry read by Encoder/Decoder; 1x28x28 gives the 784
    # values per image used by the MNIST tutorial.
    channels:int = 1
    width:int = 28
    height:int = 28

    def get(self):
        """Return the hyper-parameters as a new dict (changing it does not
        modify this instance)."""
        return dict(self.__dict__)

## 学習する

### データセットのロード(MNIST)

### 実行

### 結果を見る

In [None]:
## 考えること