In [1]:
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter
from six.moves import xrange
import umap
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid


In [2]:
# Use the CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Load CIFAR-10 (train and test splits), downloading into ./data if needed.
# Both splits share one transform: to tensor, then subtract 0.5 per channel (std 1 leaves the scale unchanged).
_cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (1, 1, 1)),
])

training_set = datasets.CIFAR10(root="data", train=True, download=True, transform=_cifar10_transform)
validation_set = datasets.CIFAR10(root="data", train=False, download=True, transform=_cifar10_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data\cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data\cifar-10-python.tar.gz to data
Files already downloaded and verified


In [None]:
# VQ layer
class VectorQuantizer(nn.Module):
    """Vector-quantization layer of a VQ-VAE.

    Each spatial feature vector of the input is replaced by its nearest entry in a
    learnable codebook, stored as the weight of an ``nn.Embedding`` with shape
    [num_channels, embedding_dim] = [K, D].

    Args:
        num_channels: number of codebook entries (K).
        embedding_dim: dimension of each codebook entry (D); expected to equal the
            channel count of the feature map passed to ``forward``.
        commitment_cost: weight of the commitment term in the returned loss.
    """

    def __init__(self, num_channels: int, embedding_dim: int, commitment_cost: float):
        # The original called `super.__init__()` (on the builtin type, not an instance),
        # which raises TypeError before nn.Module is initialised.
        super().__init__()
        self._embedding_dim = embedding_dim
        self._num_channels = num_channels

        # The embedding weight [K, D] is used directly as the codebook.
        self._embedding = nn.Embedding(self._num_channels, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1 / self._num_channels, 1 / self._num_channels)
        self._commitment_cost = commitment_cost

    def forward(self, inputs: torch.Tensor):
        """Quantize a [B, D, H, W] feature map.

        Returns:
            loss: codebook loss + commitment_cost * commitment loss.
            quantized: [B, D, H, W] quantized features; gradients are passed straight
                through to ``inputs``.
            perplexity: exp(entropy) of the average code usage in this batch.
            encodings: [B*H*W, K] one-hot code assignments.
        """
        # BCHW -> BHWC so the channel (D) axis is last; contiguous() so view() is valid.
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten to [B*H*W, D].
        flat_input = inputs.view(-1, self._embedding_dim)

        # Squared Euclidean distance ||x||^2 + ||e||^2 - 2 x.e between every input
        # vector and every codebook entry -> [B*H*W, K].
        distance = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                    + torch.sum(self._embedding.weight ** 2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Nearest code per vector, as an [N, 1] column so it can drive scatter_.
        encoding_indices = torch.argmin(distance, dim=1).unsqueeze(1)
        # One-hot assignments [B*H*W, K]: start from zeros and set the chosen column to 1.
        encodings = torch.zeros(encoding_indices.shape[0], self._num_channels, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # One-hot [N, K] @ codebook [K, D] picks the nearest code vectors -> [N, D],
        # reshaped back to BHWC. (The original used `self._embedding_dim.weight`, but
        # `_embedding_dim` is an int and has no `.weight`.)
        quantized: torch.Tensor = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Commitment loss trains the encoder side (codebook output detached);
        # codebook loss trains the codebook (encoder output detached).
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + e_latent_loss * self._commitment_cost

        # Straight-through estimator: the forward value is `quantized`, while the
        # gradient w.r.t. `quantized` is copied unchanged to `inputs` (argmin is not
        # differentiable).
        quantized = inputs + (quantized - inputs).detach()

        # Perplexity = exp(entropy of average code usage); ~K means all codes are used,
        # ~1 means the codebook has collapsed to a single entry.
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings





In [None]:
# Residual blocks
class Residual(nn.Module):
    """Pre-activation residual block: x + Conv1x1(ReLU(Conv3x3(ReLU(x)))).

    The skip connection adds the block output to the input, so ``num_hiddens`` must
    equal ``in_channels``.

    Args:
        in_channels: channels of the input.
        num_hiddens: channels produced by the block (added to the input).
        num_residual_hidden: bottleneck channels between the two convs.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hidden):
        super().__init__()
        self.block = nn.Sequential(
            # Not in-place: the block input is also the skip-connection term. An in-place
            # ReLU here overwrote it (and the caller's tensor) with relu(x), so forward
            # computed relu(x) + block(x) instead of x + block(x).
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=num_residual_hidden, kernel_size=3, stride=1, padding=1, bias=False),

            # In-place is fine here: it only modifies the first conv's fresh output.
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hidden, out_channels=num_hiddens, kernel_size=1, stride=1, bias=False)
        )

    def forward(self, inputs):
        return inputs + self.block(inputs)


class ResidualStack(nn.Module):
    """``num_residual_layers`` Residual blocks applied in sequence, followed by a ReLU."""

    def __init__(self, in_channels, num_hiddens, num_residual_hidden, num_residual_layers):
        super().__init__()
        self._num_residual_layers = num_residual_layers
        # Flat list of modules. The original wrapped the comprehension in an extra list,
        # and nn.ModuleList raises TypeError because a list is not an nn.Module.
        self._layers = nn.ModuleList(
            [Residual(in_channels, num_hiddens, num_residual_hidden) for _ in range(self._num_residual_layers)]
        )

    def forward(self, inputs):
        for layer in self._layers:
            inputs = layer(inputs)
        return F.relu(inputs)

In [None]:
#Encoder
class Encoder(nn.Module):
    """Convolutional encoder: two stride-2 convs, a 3x3 conv, then a residual stack.

    Args:
        in_channels: channels of the input image.
        num_hiddens: channel width of the last conv and of the residual stack.
        num_residual_hidden: bottleneck width inside each residual block.
        num_residual_layers: number of residual blocks.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hidden, num_residual_layers):
        super().__init__()
        half_hiddens = num_hiddens // 2

        # kernel 4 / stride 2 / padding 1 halves H and W (for even sizes) at each of these two convs.
        self._conv1 = nn.Conv2d(in_channels=in_channels, out_channels=half_hiddens,
                                kernel_size=4, stride=2, padding=1)
        self._conv2 = nn.Conv2d(in_channels=half_hiddens, out_channels=num_hiddens,
                                kernel_size=4, stride=2, padding=1)
        # Size-preserving 3x3 conv before the residual stack.
        self._conv3 = nn.Conv2d(in_channels=num_hiddens, out_channels=num_hiddens,
                                kernel_size=3, stride=1, padding=1)

        self._residual_stack = ResidualStack(
            in_channels=num_hiddens,
            num_hiddens=num_hiddens,
            num_residual_hidden=num_residual_hidden,
            num_residual_layers=num_residual_layers,
        )

    def forward(self, inputs):
        hidden = F.relu(self._conv1(inputs))
        hidden = F.relu(self._conv2(hidden))
        return self._residual_stack(self._conv3(hidden))


In [None]:
#Decoder
class Decoder(nn.Module):
    """Convolutional decoder: 3x3 conv, residual stack, then two stride-2 transposed convs.

    Args:
        in_channels: channels of the (quantized) latent input.
        num_hiddens: channel width of the first conv and of the residual stack.
        num_residual_hidden: bottleneck width inside each residual block.
        num_residual_layers: number of residual blocks.
        out_channels: channels of the reconstructed image (default 3, matching the
            CIFAR-10 images used in this notebook).
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hidden, num_residual_layers, out_channels=3):
        # Missing in the original: nn.Module.__init__ must run before submodules are
        # assigned, otherwise the first `self._conv1 = ...` raises AttributeError.
        super().__init__()
        self._conv1 = nn.Conv2d(in_channels=in_channels, out_channels=num_hiddens, kernel_size=3, stride=1, padding=1)

        # Attribute keeps its original spelling so state_dict keys stay unchanged.
        self._redidual_stack = ResidualStack(in_channels=num_hiddens, num_hiddens=num_hiddens, num_residual_layers=num_residual_layers, num_residual_hidden=num_residual_hidden)

        # kernel 4 / stride 2 / padding 1 doubles H and W at each transposed conv.
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens // 2,
                                                kernel_size=4,
                                                stride=2,
                                                padding=1)

        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
                                                out_channels=out_channels,
                                                kernel_size=4,
                                                stride=2,
                                                padding=1)

    def forward(self, inputs):
        x = self._conv1(inputs)

        x = self._redidual_stack(x)

        x = self._conv_trans_1(x)
        x = F.relu(x)

        # No output activation: returns unbounded values.
        return self._conv_trans_2(x)

In [None]:
# Training hyper-parameters
batch_size, num_training_updates = 256, 15_000  # num_training_updates counts optimizer steps, not epochs

# Encoder / decoder widths
num_hiddens, num_residual_hiddens, num_residual_layers = 128, 32, 2

# Codebook: dimension of each entry (D) and number of entries (K)
embedding_dim, num_embeddings = 64, 512

# Weight of the commitment term in the VQ loss
commitment_cost = 0.25

learning_rate = 1e-3


In [None]:
# Shuffled training batches of `batch_size`; pin_memory speeds host-to-GPU copies.
training_loader = DataLoader(training_set, batch_size=batch_size,
                             shuffle=True, pin_memory=True)

In [None]:
# Small shuffled validation batches, so each visualisation pass draws a random batch of 32.
validation_loader = DataLoader(validation_set, batch_size=32,
                               shuffle=True, pin_memory=True)

In [None]:
class Model(nn.Module):
    """VQ-VAE: encoder -> 1x1 projection -> vector quantizer -> decoder.

    ``forward`` returns ``(vq_loss, reconstruction, perplexity)``.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self._encoder = Encoder(
            3,
            num_hiddens=num_hiddens,
            num_residual_hidden=num_residual_hiddens,
            num_residual_layers=num_residual_layers,
        )

        # 1x1 conv projecting the encoder's num_hiddens channels down to embedding_dim,
        # the channel count the quantizer expects.
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens, out_channels=embedding_dim,
                                      kernel_size=1, stride=1)

        # Codebook with num_embeddings entries of size embedding_dim.
        self._vq_vae = VectorQuantizer(
            num_channels=num_embeddings,
            embedding_dim=embedding_dim,
            commitment_cost=commitment_cost,
        )

        self._decoder = Decoder(
            in_channels=embedding_dim,
            num_hiddens=num_hiddens,
            num_residual_layers=num_residual_layers,
            num_residual_hidden=num_residual_hiddens,
        )

    def forward(self, inputs: torch.Tensor):
        latents = self._pre_vq_conv(self._encoder(inputs))
        vq_loss, quantized, perplexity, _unused_encodings = self._vq_vae(latents)
        return vq_loss, self._decoder(quantized), perplexity


In [None]:
# Build the VQ-VAE from the hyper-parameter cell and move it to the selected device.
model = Model(
    num_hiddens=num_hiddens,
    num_residual_layers=num_residual_layers,
    num_residual_hiddens=num_residual_hiddens,
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    commitment_cost=commitment_cost,
).to(device)

In [None]:
# Plain Adam (AMSGrad variant disabled) over all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, amsgrad=False)

In [None]:
# Train for a fixed number of optimizer updates (not epochs).
model.train()

train_res_recon_error = []
train_res_perplexity = []

# Keep one iterator alive and restart it only when an epoch is exhausted. The original
# `next(iter(training_loader))` rebuilt the iterator (and reshuffled the whole dataset)
# on every single update.
train_iter = iter(training_loader)

for i in range(num_training_updates):
    try:
        data, _ = next(train_iter)
    except StopIteration:
        train_iter = iter(training_loader)
        data, _ = next(train_iter)
    data = data.to(device)
    optimizer.zero_grad()

    vq_loss, data_recon, perplexity = model(data)
    recon_error = F.mse_loss(data_recon, data)

    loss = recon_error + vq_loss

    loss.backward()

    optimizer.step()

    # .item() converts a one-element tensor into a plain Python float, so the history
    # lists do not hold on to tensors.
    train_res_recon_error.append(recon_error.item())
    train_res_perplexity.append(perplexity.item())

    if (i + 1) % 100 == 0:
        print('%d iteration' % (i + 1))


In [None]:
# Smooth the noisy per-update curves with a Savitzky-Golay filter
# (window 201, polynomial order 7; needs at least 201 logged updates).
train_res_recon_error_smooth = savgol_filter(train_res_recon_error, 201, 7)
train_res_perplexity_smooth = savgol_filter(train_res_perplexity, 201, 7)


f = plt.figure(figsize=(16, 8))
ax = f.add_subplot(1, 2, 1)
ax.plot(train_res_recon_error_smooth)
# Was ' log': matplotlib rejects the leading space as an unknown scale name.
ax.set_yscale('log')
# NOTE(review): the training loop logs F.mse_loss(data_recon, data) without dividing
# by the data variance, so this curve is plain MSE despite the "NMSE" title.
ax.set_title('smoothed NMSE. ')
ax.set_xlabel('iteration ')


ax = f.add_subplot(1, 2, 2)
ax.plot(train_res_perplexity_smooth)
ax.set_title('Smoothed Average codebook usage (perplexity) . ')
# Was `set_xlabei('ieration')`: misspelled method (AttributeError) and label.
ax.set_xlabel('iteration ')


In [None]:
# Model evaluation: reconstruct a validation batch and show it next to the originals.
model.eval()

# Inference only; no autograd graph is needed for visualisation.
with torch.no_grad():
    (valid_originals, _) = next(iter(validation_loader))
    valid_originals = valid_originals.to(device)

    vq_output_eval = model._pre_vq_conv(model._encoder(valid_originals))
    _, valid_quantize, _, _ = model._vq_vae(vq_output_eval)
    valid_reconstructions = model._decoder(valid_quantize)

    (train_originals, _) = next(iter(training_loader))
    train_originals = train_originals.to(device)
    # The original fed raw RGB images straight into the VQ layer, which expects
    # encoder features with embedding_dim channels; run the whole model instead.
    _, train_reconstructions, _ = model(train_originals)


def show(img, title=None):
    """Display a [C, H, W] image tensor (values expected in [0, 1]) in its own figure.

    Args:
        img: CPU tensor, e.g. the output of ``make_grid``.
        title: optional figure title.
    """
    # Reconstructions can fall slightly outside [0, 1]; clip instead of relying on
    # imshow's clipping warning.
    npimg = np.clip(img.numpy(), 0, 1)
    # New figure per call, so a later call does not draw over this image.
    plt.figure()
    # Was 'nearest ' (trailing space), which matplotlib rejects as an unknown interpolation.
    fig = plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    if title is not None:
        plt.title(title)


# Undo the Normalize(mean=0.5, std=1) shift before display. Apply it before make_grid
# for both batches so the grid padding is handled the same way.
show(make_grid(valid_reconstructions.cpu() + 0.5), title='reconstructions')
show(make_grid(valid_originals.cpu() + 0.5), title='originals')


In [None]:
# View the learned codebook: project its [num_embeddings, embedding_dim] weight to 2-D with UMAP.
# Was `model._vg_vae` (typo -> AttributeError); explicit numpy conversion for UMAP.
codebook = model._vq_vae._embedding.weight.detach().cpu().numpy()
proj = umap.UMAP(n_neighbors=3,
                 min_dist=0.1,
                 metric='cosine').fit_transform(codebook)  # was 'cosine ' (trailing space): unknown metric

plt.scatter(proj[:, 0], proj[:, 1], alpha=0.3)
