# Welcome to VQ AutoEncoder clustering tutorial!
ここでは、VQ (Vector Quantized) AutoEncoderによる教師なしクラスタリングを行います。

### 目次

### 注意事項
このチュートリアルは演算負荷が高いためGPU環境が必要です。  
Google Colab で実行している方は、ページ上部から**ランタイム** &rarr; **ランタイムのタイプを変更** をクリックし **ハードウェアアクセラレータ** を *None* から *GPU* に変更してください。  
GPUが使用可能かどうかは次のコードブロックを実行することで分かります。

In [None]:
import torch
# Report whether CUDA is usable, then bind `device` to the first GPU when it
# is available and to the CPU otherwise.
print(f"GPU: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using: {device}")

## ライブラリのインストールとインポート

In [None]:
!pip install pytorch-lightning 
!pip install torchsummaryX 

In [None]:
import torch
import torch.nn as nn 
import pytorch_lightning as pl
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt 
from torchvision.utils import make_grid
from typing import *
from datetime import datetime
from torchsummaryX import summary
from torch.utils import data as dutil
from torchvision import transforms
from pytorch_lightning import loggers as pl_loggers
import os


## アルゴリズムの説明

## モデル

### Quantizing layer

In [None]:
class Quantizing(nn.Module):
    """
    Vector-quantizing layer that maps each input vector to its nearest codebook row.

    The codebook is the parameter `weight` of shape (Q, E), where
    Q = `num_quantizing` and E = `quantizing_dim`.

    Args:
        num_quantizing: number of codebook rows (Q). Must be positive.
        quantizing_dim: length of each codebook row (E). Must be positive.
        _weight: optional initial codebook of shape (Q, E). When given, no
            random or dataset-driven initialization takes place.
        initialize_by_dataset: when True and `_weight` is None, the first
            forward passes overwrite codebook rows with randomly chosen input
            vectors until all Q rows have been replaced.
        mean, std: parameters of the normal distribution used for the random
            initialization of the codebook.
        dtype, device: forwarded to the codebook tensor.

    Raises:
        ValueError: if `_weight` is neither None nor a tensor.
    """

    # Class-level default: the layer counts as initialized unless __init__
    # explicitly requests dataset-driven initialization.
    __initialized:bool = True

    def __init__(
        self, num_quantizing:int, quantizing_dim:int, _weight:torch.Tensor = None,
        initialize_by_dataset:bool = True, mean:float = 0.0, std:float = 1.0,
        dtype:torch.dtype = None, device:torch.device = None
        ):
        super().__init__()
        assert num_quantizing > 0
        assert quantizing_dim > 0
        self.num_quantizing = num_quantizing
        self.quantizing_dim = quantizing_dim
        self.initialize_by_dataset = initialize_by_dataset
        self.mean,self.std = mean,std

        # Always define the counter so that forward() and the is_initialized
        # setter can rely on it, including when `_weight` is supplied.
        self.__initialized_length = num_quantizing

        if _weight is None:
            self.weight = nn.Parameter(
                torch.empty(num_quantizing, quantizing_dim ,dtype=dtype,device=device)
            )
            nn.init.normal_(self.weight, mean=mean, std=std)

            if initialize_by_dataset:
                # Rows will be replaced by data vectors during forward().
                self.__initialized = False
                self.__initialized_length = 0

        elif isinstance(_weight, torch.Tensor):
            assert _weight.dim() == 2
            assert _weight.size(0) == num_quantizing
            assert _weight.size(1) == quantizing_dim
            self.weight = nn.Parameter(_weight.to(device=device, dtype=dtype))

        else:
            raise ValueError("Weight type is unknown type! {}".format(type(_weight)))

    def forward(self,x:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x   : shape is (*, E), and weight shape is (Q, E).
        return -> ( quantized : shape is (*, E), quantized_idx : shape is (*,) )

        While dataset-driven initialization is pending, this call also copies
        randomly selected rows of `x` into the not-yet-replaced codebook rows.
        """
        input_size = x.shape
        # reshape (not view) so that non-contiguous inputs are accepted too.
        h = x.reshape(-1,self.quantizing_dim) # shape is (B,E)

        if not self.__initialized and self.initialize_by_dataset:
            # Take at most as many input rows as there are rows left to fill.
            getting_len = self.num_quantizing - self.__initialized_length
            init_weight = h[torch.randperm(len(h))[:getting_len]]

            _until = self.__initialized_length + init_weight.size(0)
            # Write through .data so the copy is not recorded by autograd.
            self.weight.data[self.__initialized_length:_until] = init_weight
            self.__initialized_length = _until
            print('replaced weight')

            if _until >= self.num_quantizing:
                self.__initialized = True
                print('initialized')

        # Squared Euclidean distance between every input row and every codebook row.
        delta = self.weight.unsqueeze(0) - h.unsqueeze(1) # shape is (B, Q, E)
        dist =torch.sum(delta*delta, dim=-1) # shape is (B, Q)
        q_idx = torch.argmin(dist,dim=-1) # shape is (B,)
        q_data = self.weight[q_idx] # shape is (B, E)

        # Indices get every leading dimension of the input (all but the last, E).
        return q_data.view(input_size), q_idx.view(input_size[:-1])

    @property
    def is_initialized(self):
        """Whether the codebook is considered fully initialized."""
        return self.__initialized

    @is_initialized.setter
    def is_initialized(self, b:bool):
        # Keep the filled-row counter consistent with the flag: all rows when
        # marked initialized, none when dataset initialization is restarted.
        self.__initialized = b
        if b:
            self.__initialized_length = self.num_quantizing
        else:
            self.__initialized_length = 0
    



### Encoder and Decoder

In [None]:

class Encoder(nn.Module):
    """
    Fully connected encoder: flattens an image batch and maps it to a vector
    of length `h.quantizing_dim` squashed into [-1, 1] by Tanh.

    Args:
        h: hyper-parameter object providing `channels`, `width`, `height`
            and `quantizing_dim`.
    """
    def __init__(self,h):
        super().__init__()
        self.channels = h.channels
        self.width = h.width
        self.height = h.height
        # (batch, channels, height, width), the layout the decoder emits.
        self.input_size = (1,h.channels,h.height,h.width)
        self.output_size = (1,h.quantizing_dim)
        self.h = h

        # Derive the flattened size from the hyper-parameters instead of
        # assuming 1x28x28 images (which yields the former constant 784).
        in_features = h.channels * h.height * h.width

        self.layers = nn.Sequential(
            nn.Flatten(1),
            nn.Linear(in_features,256),nn.ReLU(),
            nn.Linear(256,128),nn.ReLU(),
            nn.Linear(128,h.quantizing_dim),nn.Tanh(),
        )

    def forward(self,x:torch.Tensor):
        """Encode `x`; every dimension after the first is flattened."""
        y = self.layers(x)
        return y

    def summary(self):
        """Print a layer summary by running a random dummy input through the encoder."""
        dummy = torch.randn(self.input_size)
        summary(self,dummy)

In [None]:
class Decoder(nn.Module):
    """
    Fully connected decoder: maps a vector of length `h.quantizing_dim` to an
    image batch of shape (batch, channels, height, width) with values in
    [0, 1] (Sigmoid output).

    Args:
        h: hyper-parameter object providing `channels`, `width`, `height`
            and `quantizing_dim`.
    """
    def __init__(self,h):
        super().__init__()
        self.channels = h.channels
        self.width = h.width
        self.height = h.height
        self.input_size = (1,h.quantizing_dim)
        # Matches what forward() actually returns: (batch, C, H, W).
        self.output_size = (1,h.channels,h.height,h.width)
        self.h = h

        # Derive the pixel count from the hyper-parameters instead of
        # assuming 1x28x28 images (which yields the former constant 784).
        out_features = h.channels * h.height * h.width

        self.layers = nn.Sequential(
            nn.Linear(h.quantizing_dim,128),nn.ReLU(),
            nn.Linear(128,256),nn.ReLU(),
            nn.Linear(256,out_features),nn.Sigmoid(),
        )

    def forward(self,x:torch.Tensor):
        """Decode `x` and reshape the flat output into image form."""
        y = self.layers(x)
        y = y.view(-1,self.channels,self.height,self.width)
        return y

    def summary(self):
        """Print a layer summary by running a random dummy input through the decoder."""
        dummy = torch.randn(self.input_size)
        summary(self,dummy)

### VQ AutoEncoder
学習の処理を簡単に書くために、`pytorch-lightning`というライブラリを使用します。
`pytorch-lightning`はPyTorchのラッパーです。

In [None]:

class VQ_AutoEncoder(pl.LightningModule):
    """
    Auto-encoder with a vector-quantizing codebook, wrapped as a LightningModule.

    Training optimizes two terms: the reconstruction error of the decoder
    applied to the un-quantized encoder output, and the distance between the
    detached encoder output and its nearest codebook row (which therefore
    only updates the codebook).

    Args:
        h: hyper-parameter object providing `model_name`, `num_quantizing`,
            `quantizing_dim`, `lr`, `view_interval`, `max_view_imgs`, a
            `get()` method returning a dict to log, and the fields read by
            `Encoder` / `Decoder`.
    """

    def __init__(self,h):
        super().__init__()
        self.model_name = h.model_name
        self.h = h
        self.num_quantizing = h.num_quantizing
        self.quantizing_dim = h.quantizing_dim
        self.lr = h.lr
        self.my_hparams_dict = h.get()

        # set criterion
        self.reconstruction_loss = nn.MSELoss()
        self.quantizing_loss = nn.MSELoss()

        # set histogram (kept as a plain CPU tensor; counts are moved to CPU before adding)
        self.q_hist = torch.zeros(self.num_quantizing,dtype=torch.int)
        self.q_hist_idx = np.arange(self.num_quantizing)

        # Most recent training batch, used for image logging. None until the
        # first training step has run.
        self.view_data = None

        # set layers
        self.encoder = Encoder(h)
        self.quantizer = Quantizing(h.num_quantizing,h.quantizing_dim)
        self.decoder = Decoder(h)

        self.input_size = self.encoder.input_size
        self.output_size = self.input_size

    def forward(self,x:torch.Tensor):
        """Encode `x`, quantize the code and decode the quantized code."""
        h = self.encoder(x)
        Qout,_ = self.quantizer(h)
        y = self.decoder(Qout)
        return y

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate `self.lr`."""
        optim = torch.optim.Adam(self.parameters(),self.lr)
        return optim

    @torch.no_grad()
    def set_quantizing_weight(self,data_loader,device='cpu'):
        """
        Re-initialize the codebook from encoder outputs of `data_loader` batches.

        Batches are expected to unpack as `(data, _)`. Iteration stops as soon
        as the quantizer reports that it is initialized.
        """
        self.quantizer.is_initialized = False
        for batch in data_loader:
            data,_ = batch
            data = data.to(device)
            Eout = self.encoder(data)
            _ = self.quantizer(Eout)
            if self.quantizer.is_initialized:
                break

        torch.cuda.empty_cache()

    def on_fit_start(self) -> None:
        """Log the hyper-parameter dict once fitting starts."""
        self.logger.log_hyperparams(self.my_hparams_dict)

    def training_step(self,batch,idx):
        """Compute and log the losses for one batch and update the usage histogram."""
        data,_  = batch
        self.view_data = data
        Eout = self.encoder(data)
        # Detached target: the quantizing loss must not push gradients into the encoder.
        Qtgt = Eout.detach()
        Qout,Qidx = self.quantizer(Qtgt)
        out = self.decoder(Eout)

        # loss
        r_loss = self.reconstruction_loss(out,data)
        q_loss = self.quantizing_loss(Qout,Qtgt)
        loss = r_loss + q_loss

        # log
        # This value is only logged, never optimized, so skip building a graph for it.
        with torch.no_grad():
            rq_loss = self.reconstruction_loss(self.decoder(Qout),data)
        self.log('loss',loss)
        self.log('reconstruction_loss',r_loss)
        self.log('quantizing_loss',q_loss)
        self.log('reconstructed_quantizing_loss',rq_loss)

        # Count how often each codebook row was selected in this batch.
        used_idx,count = torch.unique(Qidx,return_counts = True)
        self.q_hist[used_idx.cpu()] += count.cpu()
        return loss

    @torch.no_grad()
    def on_epoch_end(self) -> None:
        """
        Every `h.view_interval` epochs, log reconstructions and the codebook
        usage histogram; reset the histogram after every epoch.

        NOTE(review): whether the Trainer calls `on_epoch_end` depends on the
        installed pytorch-lightning version -- confirm against the version in use.
        """
        if (self.current_epoch+1) % self.h.view_interval ==0:
            # image logging (skipped when no training batch has been seen yet)
            if self.view_data is not None:
                data = self.view_data[:self.h.max_view_imgs].float()
                data_len = len(data)
                Eout = self.encoder(data)
                Qout,_ = self.quantizer(Eout)
                out = self.decoder(Eout)
                Qdecoded = self.decoder(Qout)

                # Rows: inputs, plain reconstructions, reconstructions from quantized codes.
                grid_img = make_grid(torch.cat([data,out,Qdecoded],dim=0),nrow=data_len)
                self.logger.experiment.add_image("MNIST Quantizings",grid_img,self.current_epoch)

            # histogram logging
            fig = plt.figure(figsize=(6.4,4.8))
            ax = fig.subplots()
            ax.bar(self.q_hist_idx,self.q_hist.numpy())

            # Number of codebook rows that were selected at least once.
            quantized_num = int((self.q_hist != 0).sum())
            q_text = f'{quantized_num}/{self.num_quantizing}'
            ax.text(0.9,1.05,q_text,ha='center',va='center',transform=ax.transAxes,fontsize=12)
            ax.set_xlabel('weight index')
            ax.set_ylabel('num')
            self.logger.experiment.add_figure('Quantizing Histogram',fig,self.current_epoch)

        self.q_hist.zero_()

    def summary(self,tensorboard=False):
        """
        Print a layer summary using a random dummy input and, when
        `tensorboard` is True, also write the model graph under "VQAE_log".

        NOTE: the dummy input runs through the quantizer, so while dataset
        initialization is still pending the encoded dummy is copied into the
        codebook; `set_quantizing_weight` restarts that initialization.
        """
        dummy = torch.randn(self.input_size)
        summary(self,dummy)

        if tensorboard:
            # Imported lazily so a plain summary does not require tensorboard.
            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter(comment=self.model_name,log_dir="VQAE_log")
            writer.add_graph(self,dummy)
            writer.close()


`pytorch-lightning`では下のような形でモデルを定義することができます。  
```python
class ModelName(pl.LightningModule):

    def __init__(self, arguments): # required
        super().__init__()
        ################################################
        # この中でレイヤーや、criterion(Loss関数)を定義します。
        self.criterion = SomeLossFunc()
        self.layer = SomeLayers()
        ################################################

    def forward(self, input): # not required (when training)
        ################################################
        # この中にデータの流れを書きますが、学習する際にこの
        # forward methodは使われません。
        output = self.layer(input)
        ################################################
        return output

    def configure_optimizers(self): # required
        ################################################
        # この中でoptimizerを定義して、returnで返します。
        optim = torch.optim.SomeOptimizer(self.parameters(), lr=lr)
        ################################################
        return optim
    
    def training_step(self, batch, idx): # required
        ################################################
        # この中で学習する時のデータの流れを書きます。学習するとき
        # は"損失"まで計算し、それを return で返します。学習時に
        # 記録したい値は self.log("name", value) で記録するこ
        # とができます。
        input, answer = batch # extracting
        output = self.layer(input) # flow
        loss = self.criterion(output, answer) # calculate loss
        self.log("loss",loss) # log
        return loss # return (required)
        ################################################

    def validation_step(self, batch, idx): # not required when you don't need the validation.
        ################################################
        # この中に検証用データセットでの処理を書きます。この関数で
        # は値をreturnしなくて良いです。self.log("name", value)
        # で値を記録してください。
        # Trainerにvalidation用のDataLoaderを与えなかった場合、
        # この関数はつかわれません。　
        ################################################
```

Event drivenなので、他にも様々な関数が用意されています。pytorch_lightningの`Trainer`が自動的にそのタイミングでオーバーライドされた関数を呼び出します。
```python
model = ModelName(arguments)
trainer = pl.Trainer(gpus=1,precision=16, max_epochs=10,...)
trainer.fit(model,train_loader, validation_loader)
```

### hyper parameter の定義

In [None]:
@dataclass
class hparam:
    """
    Hyper-parameters of the VQ auto-encoder experiment.

    Every attribute is an annotated dataclass field, so each one can be set
    through the constructor and is included in `get()`.
    """
    model_name:str = "VQ_AutoEncoder"

    lr:float = 0.001

    num_quantizing:int = 32
    quantizing_dim:int = 32

    channels:int = 1
    width:int = 28
    height:int = 28

    # Logging controls. Declared after the pre-existing fields so that
    # positional construction keeps its original argument order.
    max_view_imgs:int = 16
    view_interval:int = 10

    def get(self):
        """Return the hyper-parameters as a new dict (a copy, not the live instance dict)."""
        return dict(self.__dict__)

## 学習する

 ### Tensor Board

In [None]:
%load_ext tensorboard
%tensorboard --logdir="VQAE_log"

### データセットのロード(MNIST)

In [None]:
from torchvision.datasets import MNIST
# MNIST images stored under "data" (downloaded on demand); every image is
# converted to a tensor by ToTensor. train=False is passed through as in the
# original cell -- see the torchvision MNIST documentation for which split it selects.
dataset = MNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

### 時刻を取得する関数

In [None]:
def get_now(strf:str = '%Y-%m-%d_%H-%M-%S'):
    """Return the current local time formatted with the strftime pattern `strf`."""
    return datetime.now().strftime(strf)

### パラメータの保存について

In [None]:
def param_dir():
    """Ensure that the "params" directory exists in the current working directory."""
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs("params", exist_ok=True)
        
def save_params(model:VQ_AutoEncoder, now):
    """
    Save the state dicts of the encoder, decoder, quantizer and whole model
    under "params/", using `<model_name>_<now>` as the common file stem.
    """
    param_dir()
    stem = f"params/{model.model_name}_{now}"
    targets = (
        (model.encoder, ".encoder.pth"),
        (model.decoder, ".decoder.pth"),
        (model.quantizer, ".quantizing.pth"),
        (model, ".pth"),
    )
    for module, suffix in targets:
        torch.save(module.state_dict(), f"{stem}{suffix}")

### データローダーの定義

In [None]:
# Number of samples per mini-batch.
BATCH_SIZE = 1024
# Shuffled loader over `dataset`; an incomplete final batch is dropped.
dataloader = dutil.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)

### pytorch lightning の key words


In [None]:
# Keyword arguments forwarded to pl.Trainer for the first training run.
pl_kwds = {
    "gpus": 1 if torch.cuda.is_available() else 0,
    # 16-bit precision is only requested together with a GPU; the CPU
    # fallback uses full 32-bit precision.
    "precision": 16 if torch.cuda.is_available() else 32,
    "max_epochs": 500,
    "log_every_n_steps":5,
}

### 学習（1回目）

##### インスタンス

In [None]:
# Default hyper-parameters, overriding only the model name for this run.
h = hparam(model_name="VQAE_pure")
model = VQ_AutoEncoder(h)
# Print the layer summary; True additionally writes the graph for TensorBoard.
model.summary(True)

##### Quantizing Weight をセット

In [None]:
# Seed the quantizer codebook with encoder outputs of samples from the dataloader.
model.set_quantizing_weight(dataloader)

##### 実行

In [None]:
# TensorBoard logs of this run go to "VQAE_log/pure".
logger = pl_loggers.TensorBoardLogger("VQAE_log/pure")
trainer = pl.Trainer(logger=logger, **pl_kwds)
trainer.fit(model, dataloader)
# NOTE(review): `close()` on the logger depends on the installed
# pytorch-lightning version -- confirm it is available.
logger.close()
# Timestamp taken after training; the parameter-saving cell uses it in the file names.
now = get_now()
print(now)

##### パラメータ保存

In [None]:
# Write the trained weights to "params/", tagged with the timestamp taken after training.
save_params(model,now)

##### なぜうまくいかないのか

### 学習（通常）

##### パラメータのロード

In [None]:
def load_trained_param(
    model:VQ_AutoEncoder,
    encoder_path:str = "params/~.encoder.pth",
    decoder_path:str = "params/~.decoder.pth",
):
    """
    Load previously saved encoder and decoder weights into `model`.

    Args:
        model: the model whose `encoder` and `decoder` receive the weights.
        encoder_path: file holding the encoder state dict. The default is the
            notebook's placeholder and must be replaced by a real file name
            (for example one written by `save_params`).
        decoder_path: file holding the decoder state dict; same remark.
    """
    # map_location="cpu" lets checkpoints saved on a GPU be read on a
    # CPU-only machine; load_state_dict then copies the values into the
    # model's own parameters.
    model.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"))
    model.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))

##### インスタンスと実行

In [None]:
def train(settings:Dict[str,Any], pl_kwds:Dict[str, Any]):
    """
    Run one training experiment starting from pre-trained encoder/decoder weights.

    `settings` are keyword arguments for `hparam`; `pl_kwds` are keyword
    arguments for `pl.Trainer`. The module-level `dataloader` supplies the
    data for both the codebook initialization and the fit. Weights are saved
    with a timestamp taken right after fitting.
    """
    vqae = VQ_AutoEncoder(hparam(**settings))
    load_trained_param(vqae)
    vqae.set_quantizing_weight(dataloader)

    tb_logger = pl_loggers.TensorBoardLogger("VQAE_log/fromTraining")
    pl.Trainer(logger=tb_logger, **pl_kwds).fit(vqae, dataloader)

    finished_at = get_now()
    tb_logger.close()
    save_params(vqae, finished_at)

##### Ex) changing `num_quantizing`

In [None]:
# Keyword arguments forwarded to pl.Trainer for the num_quantizing sweep.
pl_kwds = {
    "gpus": 1 if torch.cuda.is_available() else 0,
    # 16-bit precision is only requested together with a GPU; the CPU
    # fallback uses full 32-bit precision.
    "precision": 16 if torch.cuda.is_available() else 32,
    "max_epochs": 100,
    "log_every_n_steps":5,
}

In [None]:
# Codebook sizes to compare; one full training run is started per value.
num_quantizings = [8,32,128]
for nq in num_quantizings:
    setting = dict(
        model_name=f"vqae_changing_num_quantizing_{nq}",
        num_quantizing=nq,
    )
    train(setting, pl_kwds)

### 結果を解析する

### 考えること