In [None]:
import os.path as path
import numpy as np
from typing import Tuple, List
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import gzip
from matplotlib import pyplot as plt
from IPython import display
from IPython.display import clear_output
import time
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import random

In [None]:
class Timer:
    """Stopwatch that keeps a history of measured intervals (in seconds)."""

    def __init__(self):
        # Completed intervals are collected here; the first one starts right away.
        self.times = []
        self.start()

    def start(self):
        """Mark the beginning of a new interval."""
        self.tik = time.time()

    def stop(self):
        """Close the current interval, store its length and return it."""
        elapsed = time.time() - self.tik
        self.times.append(elapsed)
        return elapsed

    def avg(self):
        """Mean length of all recorded intervals."""
        return self.sum() / len(self.times)

    def sum(self):
        """Total length of all recorded intervals."""
        return sum(self.times)

    def cumsum(self):
        """Running total of the recorded intervals, as a plain list."""
        return np.cumsum(self.times).tolist()

In [None]:
def try_gpu(i = 0):
    """Return ``cuda:i`` when at least ``i + 1`` CUDA devices are visible, otherwise the CPU device."""
    has_requested_gpu = i + 1 <= torch.cuda.device_count()
    return torch.device(f'cuda:{i}') if has_requested_gpu else torch.device('cpu')

In [None]:
class Accumulator:
    """Keeps running sums for ``n`` quantities at once."""

    def __init__(self,n):
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add each positional value (converted to float) to the matching running sum."""
        updated = []
        for total, increment in zip(self.data, args):
            updated.append(total + float(increment))
        self.data = updated

    def reset(self):
        """Zero every running sum while keeping their count."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Apply labels, limits and scales to ``axes``, then switch the grid on.

    A legend is drawn only when ``legend`` is truthy.
    """
    # Setters are invoked in this fixed order: labels, limits, scales.
    settings = (
        ('set_xlabel', xlabel),
        ('set_ylabel', ylabel),
        ('set_xlim', xlim),
        ('set_ylim', ylim),
        ('set_xscale', xscale),
        ('set_yscale', yscale),
    )
    for setter_name, value in settings:
        getattr(axes, setter_name)(value)
    if legend:
        axes.legend(legend)
    axes.grid()

In [None]:
class Animator:
    """Incrementally plots one or more curves and redraws the figure in the notebook.

    Only the first axes (``self.axes[0]``) is drawn on and configured.
    """

    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None, ylim=None, xscale='linear', yscale='linear',fmts=('-','m--','g-.','r:'),nrows=1,ncols=1,figsize=(5,3)):
        """Create the figure and remember how its axes should be configured.

        Args:
            xlabel, ylabel, xlim, ylim, xscale, yscale, legend: forwarded to
                ``set_axes`` on every redraw.
            fmts: matplotlib format strings, one per curve; curves beyond
                ``len(fmts)`` are not drawn because plotting zips over them.
            nrows, ncols, figsize: forwarded to ``plt.subplots``.
        """
        if legend is None:
            legend = []
        # Use the caller-supplied figsize (a fixed (5, 3) used to be passed here).
        self.fig, self.axes = plt.subplots(nrows,ncols,figsize=figsize)
        if nrows*ncols == 1:
            # A single Axes object is wrapped so that self.axes[0] always works.
            self.axes = [self.axes, ]
        # A lambda captures the configuration arguments for later redraws.
        self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None,None,fmts

    def add(self, x, y):
        """Append one point per curve and redraw the figure.

        Args:
            x: a scalar shared by all curves, or one x value per curve.
            y: a scalar (single curve) or one y value per curve. A pair in
                which either value is ``None`` is skipped for that curve.
        """
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x]*n
        # Per-curve point lists are created lazily on the first call.
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x,y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        # Clear and redraw everything so the axes rescale to the new data.
        self.axes[0].cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(xs, ys, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

In [None]:
# Folder holding the four gzip-compressed MNIST files listed in `files_name`.
# Written as a raw string: '\P', '\V' and '\M' are not valid escape sequences,
# so the plain literal triggers an invalid-escape warning on recent Python
# versions. The resulting path value is unchanged.
dataset_folder = r'D:\Project\VAE\MNIST'
# File name of each split, keyed by the names load_mnist_data() looks up.
files_name = {
    'train_img':'train-images-idx3-ubyte.gz',
    'train_label':'train-labels-idx1-ubyte.gz',
    'test_img':'t10k-images-idx3-ubyte.gz',
    'test_label':'t10k-labels-idx1-ubyte.gz'
}

In [None]:
def load_mnist_data(files_name)->Tuple:
    """Load the MNIST train/test images and labels from gzip-compressed files.

    Files are looked up inside the module-level ``dataset_folder``.

    Args:
        files_name: mapping with the keys 'train_img', 'train_label',
            'test_img' and 'test_label', each giving a file name.

    Returns:
        ``((train_img, train_label), (test_img, test_label))`` as uint8
        tensors. Images are reshaped to (-1, 1, 28, 28); labels stay 1-D.
    """
    def read_payload(key, header_bytes):
        # Decompress one file and expose everything after its header as uint8.
        with gzip.open(path.join(dataset_folder, files_name[key]), mode='rb') as stream:
            # torch.frombuffer shares memory with its buffer and warns when the
            # buffer is immutable `bytes`, so hand it a writable copy instead.
            payload = bytearray(stream.read())
        return torch.frombuffer(payload, dtype=torch.uint8, offset=header_bytes)

    # The first 16 bytes of image files and the first 8 bytes of label files
    # are skipped (file header).
    train_img = read_payload('train_img', 16).reshape(-1,1,28,28)
    train_label = read_payload('train_label', 8)
    test_img = read_payload('test_img', 16).reshape(-1,1,28,28)
    test_label = read_payload('test_label', 8)
    return (train_img, train_label),(test_img, test_label)

In [None]:
class MNIST_dataset(Dataset):
    """Map-style dataset that pairs each sample with its label by position."""

    def __init__(self, data:List, label:List):
        """Store the samples and their labels; both are indexed with the same key."""
        self.__data = data
        self.__label = label

    def __getitem__(self, item):
        """Return ``(data[item], label[item])``.

        Raises:
            IndexError: if ``item`` is not smaller than the dataset length.
        """
        # Raise instead of returning an error string: a string would be handed
        # to callers as if it were a sample, and plain iteration over the
        # dataset relies on IndexError to know where to stop.
        if not item < len(self):
            raise IndexError(f'index {item} is out of range for a dataset of size {len(self)}')
        return self.__data[item], self.__label[item]

    def __len__(self):
        """Number of samples, taken from the data container."""
        return len(self.__data)

In [None]:
# Read both splits as raw uint8 tensors.
train_data, test_data = load_mnist_data(files_name)
# Training images and labels are cast to float; images are then multiplied by
# 0.00390625 (= 1/256).
train_data_image, train_data_label = (part.to(torch.float) for part in train_data)
train_data_image = train_data_image * 0.00390625
train_data = (train_data_image, train_data_label)
# NOTE(review): the test split is wrapped exactly as loaded (uint8, unscaled).
train_dataset, test_dataset = MNIST_dataset(*train_data), MNIST_dataset(*test_data)
len(train_dataset), len(test_dataset)

In [None]:
# Shuffled mini-batches of 100 training samples.
train_iter = DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=True,
)
# test_iter = DataLoader(test_dataset, batch_size = 100)

In [None]:
class VAENET(nn.Module):
    """Fully connected variational auto-encoder for 28x28 inputs with a 25-dimensional latent.

    Encoder: 784 -> 1000 -> 500 -> 250, followed by two linear heads of 25
    units each (mean and log standard deviation).
    Decoder: 25 -> 250 -> 500 -> 1000 -> 784 with no final activation.
    """

    def __init__(self):
        super(VAENET, self).__init__()
        self.flat = nn.Flatten()
        # Encoder trunk.
        self.linear1 = nn.Linear(28*28, 1000)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(1000, 500)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(500, 250)
        self.relu3 = nn.ReLU()
        # Latent heads: linear4 -> mean, linear5 -> log standard deviation.
        self.linear4 = nn.Linear(250, 25)
        self.linear5 = nn.Linear(250, 25)
        # Decoder.
        self.linear6 = nn.Linear(25, 250)
        self.relu4 = nn.ReLU()
        self.linear7 = nn.Linear(250, 500)
        self.relu5 = nn.ReLU()
        self.linear8 = nn.Linear(500, 1000)
        self.relu6 = nn.ReLU()
        self.linear9 = nn.Linear(1000, 784)
        # Kept so existing references to `model.gaussian` keep working;
        # forward() no longer samples from it.
        self.gaussian = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0]))

    def forward(self, x):
        """Encode ``x``, draw a latent sample and decode it.

        Args:
            x: input batch; everything after the first axis is flattened and
                must amount to 784 values per sample.

        Returns:
            ``(meansq, logsd, var, y2)``: squared latent mean, latent log
            standard deviation, latent variance (each ``(batch, 25)``) and the
            decoder output ``(batch, 784)`` before any output activation.
        """
        y = self.flat(x)
        y = self.relu1(self.linear1(y))
        y = self.relu2(self.linear2(y))
        y = self.linear3(y)
        y1 = self.relu3(y)
        mu = self.linear4(y1)
        meansq = mu*mu
        # NOTE(review): the log-sd head reads the pre-ReLU activation `y`,
        # while the mean head reads the post-ReLU `y1`. Kept as is — confirm
        # whether this asymmetry is intentional.
        logsd = self.linear5(y)
        sd = torch.exp(logsd)
        var = sd*sd
        # Standard-normal noise with the same shape, dtype and device as `sd`.
        # The shape used to be fixed at (100, 25) and the noise was moved to
        # try_gpu(), which broke other batch sizes and models kept on the CPU.
        noise = torch.randn_like(sd)
        sample = mu + sd*noise
        y2 = self.relu4(self.linear6(sample))
        y2 = self.relu5(self.linear7(y2))
        y2 = self.relu6(self.linear8(y2))
        y2 = self.linear9(y2)
        return meansq, logsd, var, y2

In [None]:
def genVAE(model, device=None):
    """Decode 100 randomly perturbed latent vectors with the model's decoder layers.

    Each latent row starts from ``np.arange(-1.2, 1.3, 0.1)`` (one value per
    latent dimension) and has ``3 * random.random()`` subtracted from every
    entry. The rows are passed through ``linear6`` ... ``linear9`` (with the
    ReLUs in between) and a final sigmoid.

    Args:
        model: object exposing ``linear6``-``linear9`` and ``relu4``-``relu6``.
            If it is an ``nn.Module`` it is switched to eval mode.
        device: device the latent batch is moved to. Defaults to the device of
            ``model.linear6``'s weight so the input always matches the decoder.

    Returns:
        Tensor with 100 rows holding the sigmoid-activated decoder output.
    """
    if isinstance(model, nn.Module):
        model.eval()
    if device is None:
        # The `device` argument used to be computed and then ignored in favour
        # of try_gpu(), which fails for a CPU model on a machine with a GPU.
        device = model.linear6.weight.device
    with torch.no_grad():
        float_arr = np.arange(-1.2, 1.3, 0.1)
        # Drawn row by row from `random`, i.e. in the same order as a nested
        # loop over (row, column), so seeding `random` gives the same latents.
        jitter = np.array([[random.random() for _ in range(25)] for _ in range(100)])
        va_arr = (np.ones((100,25))*float_arr - 3*jitter).astype(np.float32)
        gen_noise = torch.tensor(va_arr).to(device)
        gen_noise = model.relu4(model.linear6(gen_noise))
        gen_noise = model.relu5(model.linear7(gen_noise))
        gen_noise = model.relu6(model.linear8(gen_noise))
        gen_noise = torch.sigmoid(model.linear9(gen_noise))

    return gen_noise

In [None]:
def kldiv(meansq, logsd, var):
    """KL divergence of N(mu, sd^2) from N(0, 1), averaged over batch and latent axes.

    Args:
        meansq: squared means, first axis is the batch.
        logsd: log standard deviations, same shape as ``meansq``.
        var: variances (sd^2), same shape as ``meansq``.

    Returns:
        Scalar tensor: element-wise ``0.5*mu^2 + 0.5*sd^2 - log(sd) - 0.5``
        averaged over the first axis and then over the remaining entries.
    """
    divloss = 0.5*meansq + 0.5*var - logsd - 0.5
    # Average over the batch using its real size; the former fixed factor 0.01
    # was only correct for batches of exactly 100 samples.
    divloss = torch.sum(divloss, dim = 0)/divloss.shape[0]
    return torch.mean(divloss)

In [None]:
def SigmCrosEtpL(x, y2):
    """Sigmoid cross-entropy reconstruction loss between inputs and decoder outputs.

    The previous implementation fed sigmoid outputs into ``nn.CrossEntropyLoss``,
    which applies a softmax across all features; that treats the pixels as
    mutually exclusive classes and leaves a large loss even for a perfect
    reconstruction. Binary cross-entropy on the raw decoder outputs is used
    instead.

    Args:
        x: target batch; flattened to ``(batch, features)``. Binary
            cross-entropy expects target values in [0, 1].
        y2: raw decoder outputs (before the sigmoid), shape ``(batch, features)``.

    Returns:
        Scalar tensor: cross-entropy summed over the features of each sample
        and averaged over the batch.
    """
    target = nn.Flatten()(x)
    # BCEWithLogitsLoss applies the sigmoid internally in a numerically stable way.
    summed = nn.BCEWithLogitsLoss(reduction='sum')(y2, target)
    return summed / target.shape[0]

In [None]:
# Build the VAE and print the module to list its registered layers.
model = VAENET()
print(model)

In [None]:
def train(model, train_iter, num_epochs, lr, device):
    """Train ``model`` with Adam on reconstruction + KL loss, plotting the running loss.

    Every ``nn.Linear`` in the model is re-initialised first (weights drawn
    from N(0, 0.1^2), biases set to zero), so calling this restarts training.

    Args:
        model: module whose forward returns ``(meansq, logsd, var, y2)``.
        train_iter: iterable with a length that yields ``(X, y)`` batches;
            the labels are not used.
        num_epochs: number of passes over ``train_iter``.
        lr: Adam learning rate (weight decay is fixed at 0.0005).
        device: device the model and the batches are moved to.
    """
    def init_weights(m):
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, mean = 0.0, std = 0.1)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
    model.apply(init_weights)
    print('training on', device)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr = lr, weight_decay = 0.0005)
    # Only the loss is plotted: no accuracy is computed for this model, and
    # the former 'train acc' curve was a constant 1.
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], legend=['train loss'])
    timer, num_batches = Timer(), len(train_iter)
    # Roughly five plot updates per epoch; max() avoids a modulo by zero when
    # an epoch has fewer than five batches.
    plot_every = max(1, num_batches // 5)
    num_examples = 0.0
    for epoch in range(num_epochs):
        # metric[0]: batch losses weighted by batch size, metric[1]: examples seen.
        metric = Accumulator(2)
        model.train()
        for i, (X, _) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X = X.squeeze(1).to(device)
            meansq, logsd, var, y2 = model(X)
            loss = SigmCrosEtpL(X, y2) + kldiv(meansq, logsd, var)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(loss * X.shape[0], X.shape[0])
            timer.stop()
            if (i+1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i+1) / num_batches, metric[0]/metric[1])
        num_examples += metric[1]

    # Throughput counts examples (not batches) over all epochs and is reported
    # as 0 when nothing was timed (e.g. num_epochs == 0).
    elapsed = timer.sum()
    throughput = num_examples/elapsed if elapsed > 0 else 0.0
    print(f'{throughput:.1f} examples/sec ' f'on {str(device)}')

In [None]:
# Adam learning rate and number of passes over the training set.
lr = 0.0005
num_epochs = 30
train(model, train_iter, num_epochs=num_epochs, lr=lr, device=try_gpu())

In [None]:
# Decode randomly perturbed latent vectors with the model; genVAE returns the
# sigmoid-activated decoder output.
pre_image = genVAE(model)