## 导入库

In [1]:
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np

## 超参数设置

In [2]:
# Training / model hyperparameters for the notebook.
batch_size = 512      # samples per DataLoader batch
seed = 1              # RNG seed used for reproducibility below
epochs = 10           # presumably the number of training epochs (training loop not shown here)
cuda = True           # prefer the GPU when one is available
log_interval = 10     # presumably batches between log lines (training loop not shown here)
h_d = 512             # width of the VAE's outer hidden layers (fc1 / fc4)
l_d = 32              # latent dimension (fc21 / fc22 / fc3)
u_d = 1               # NOTE(review): not used in the cells shown here — confirm its purpose

# Seed numpy as well as torch: numpy's global RNG may be used during data
# generation (vae_mnist is not shown — unconfirmed), and seeding it is cheap.
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x1eddcb78bf0>

## 数据

In [3]:
import vae_mnist

In [4]:
# Build the HealingMNIST data from the local vae_mnist module.
# NOTE(review): the argument meanings below follow the original inline
# comments; confirm them against vae_mnist.HealingMNIST.
hmnist = vae_mnist.HealingMNIST(
    seq_len=5,         # 5 rotations of each digit per sequence
    square_count=0,    # no images get an occluding square
    square_size=5,     # square side length (presumably has no effect while square_count=0)
    noise_ratio=0.10,  # on average ~10% of each image is corrupted by noise
    digits=range(10),  # include all ten digits 0-9
    test=False,
)

In [5]:
# Collapse all leading axes so every frame becomes one 28x28 sample.
# NOTE(review): assumes each frame holds 28*28 pixels — confirm in vae_mnist.
train_X = hmnist.train_images.reshape(-1, 28, 28)
train_Y = hmnist.train_targets.reshape(-1, 28, 28)
test_X = hmnist.test_images.reshape(-1, 28, 28)
test_Y = hmnist.test_targets.reshape(-1, 28, 28)

# Honour the `cuda` flag, but fall back to the CPU when no GPU is present so the
# notebook still runs top to bottom on CPU-only machines.
device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")

In [7]:
class HMNISTDataSet(torch.utils.data.Dataset):
    """Paired (image, target) dataset over pre-split HealingMNIST arrays.

    Args:
        train_img, train_tar: training inputs and targets (indexable, same length).
        test_img, test_tar: test inputs and targets (indexable, same length).
        test: if true, serve the test split; otherwise the training split.
        transform: optional callable applied to each image and target after
            reshaping it to (28, 28, 1). When None, items are returned unchanged.
    """

    def __init__(self, train_img, train_tar, test_img, test_tar, test=False, transform=None):
        self.test = test
        self.transform = transform

        if test:
            self.images, self.targets = test_img, test_tar
        else:
            self.images, self.targets = train_img, train_tar

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair at ``index``."""
        img = self.images[index]
        tar = self.targets[index]
        if self.transform is not None:
            # ToTensor-style transforms expect an HxWxC array.
            img = self.transform(img.reshape(28, 28, 1))
            tar = self.transform(tar.reshape(28, 28, 1))
        return img, tar

In [8]:
# Wrap the reshaped arrays as datasets (training split first, then test split).
# ToTensor converts each (28, 28, 1) array into a (1, 28, 28) tensor.
train_set, test_set = (
    HMNISTDataSet(train_X, train_Y, test_X, test_Y, test=is_test, transform=transforms.ToTensor())
    for is_test in (False, True)
)

# Both loaders use the same batch size and shuffle their samples.
train_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (train_set, test_set)
)

## 模型定义

In [9]:
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened images.

    The encoder maps input_dim -> hidden_dim -> mid_dim and then to two
    latent_dim heads (mean and log-variance). The decoder mirrors this back to
    input_dim and applies a sigmoid to the output.

    Args:
        input_dim: flattened input size (784 = 28 * 28 by default).
        hidden_dim: width of the outer hidden layers; defaults to the notebook's ``h_d``.
        latent_dim: latent size; defaults to the notebook's ``l_d``.
        mid_dim: width of the inner hidden layers (128, as before).
    """

    def __init__(self, input_dim=784, hidden_dim=None, latent_dim=None, mid_dim=128):
        super(VAE, self).__init__()
        # Fall back to the notebook-level hyperparameters so ``VAE()`` behaves as before.
        hidden_dim = h_d if hidden_dim is None else hidden_dim
        latent_dim = l_d if latent_dim is None else latent_dim
        self.input_dim = input_dim

        # Encoder (attribute names unchanged so saved state_dicts still load).
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, mid_dim)
        self.fc21 = nn.Linear(mid_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(mid_dim, latent_dim)  # log-variance head

        # Decoder.
        self.fc3 = nn.Linear(latent_dim, mid_dim)
        self.fc4 = nn.Linear(mid_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map a (batch, input_dim) tensor to ``(mu, logvar)``, each (batch, latent_dim)."""
        # Inputs may arrive in another dtype (e.g. float64 from numpy); the
        # Linear weights are float32 by default.
        x = x.float()
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        return self.fc21(h2), self.fc22(h2)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes (batch, latent_dim) to reconstructions in [0, 1]."""
        z = z.float()
        h3 = F.relu(self.fc3(z))
        h4 = F.relu(self.fc4(h3))
        return torch.sigmoid(self.fc5(h4))

    def forward(self, x):
        """Encode, sample, and decode; return ``(reconstruction, mu, logvar)``."""
        # Flatten any (batch, C, H, W) input to (batch, input_dim).
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar