In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data_mnist', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data_mnist', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),batch_size=4, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [9]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE,self).__init__()
        self.fc1 = nn.Linear(784,400)
        self.fc21=nn.Linear(400,20)
        self.fc22 = nn.Linear(400,20)
        self.fc3 = nn.Linear(20,400)
        self.fc4 = nn.Linear(400,784)
        
    def encode(self,x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1),self.fc22(h1)
    
    def reparameterize(self,mu,logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.rand_like(std)
        return eps.mul(std).add_(mu)
    def decode(self,z):
        h3 =F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))
    
    def forward(self,x):
        mu,logvar = self.encode(x.view(-1,784))
        z = self.reparameterize(mu,logvar)
        return self.decode(z),mu,logvar
    
        

In [10]:
model = VAE().to(device)

In [11]:
optimizer = optim.Adam(model.parameters(),lr=1e-3) 

In [12]:
def loss_function(recon_x,x,mu,logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE +KLD

In [13]:
def train(epoch):
    model.train()
    train_loss =0
    for batch_idx,(data,_) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch,mu,logvar = model(data)
        loss = loss_function(recon_batch,data,mu,logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx%1000 ==999:
            
            print(f'{epoch+1} {batch_idx+1} loss:{train_loss/1000}')
            train_loss=0
    print('training ended')
                

In [15]:
def test(epoch):
    model.eval()
    test_loss=0
    with torch.no_grad():
        for i,(data,_) in enumerate(test_loader):
            data = data.to(device)
            recon_batch,mu,logvar = model(data)
            test_loss += loss_function(recon_batch,data,mu,logvar).item()
            if i==0:
                n = min(data.size(0),8)
                camparison = torch.cat([data[:n],recon_batch.view(4,1,28,28)[:n]])
                save_image(camparison.cpu(),'results/reconstruuction_' +str(epoch)+'.png',nrow=n)
    test_loss/= len(test_loader.dataset)
    print(f'====> Test set loss : {test_loss:.4f}')

In [16]:
for epoch in range(2):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(4,20).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(4,28,28),'results/sample_'+str(epoch)+'.png')

1 1000 loss:-56458.29045059204
1 2000 loss:-77805.46035546875
1 3000 loss:-81718.43291015625
1 4000 loss:-84160.46915234375
1 5000 loss:-86284.5628828125
1 6000 loss:-87851.4858359375
1 7000 loss:-88940.1206484375
1 8000 loss:-89801.9001796875
1 9000 loss:-90491.81290625
1 10000 loss:-91282.230625
1 11000 loss:-91584.4523046875
1 12000 loss:-92159.7143515625
1 13000 loss:-92141.7953515625
1 14000 loss:-92233.72809375
1 15000 loss:-92700.31775
training ended
====> Test set loss : -23315.8543
2 1000 loss:-93016.101921875
2 2000 loss:-92919.80846875
2 3000 loss:-93131.5691796875
2 4000 loss:-92918.663546875
2 5000 loss:-93545.7057109375
2 6000 loss:-93802.20909375
2 7000 loss:-93457.8109765625
2 8000 loss:-93739.9311640625
2 9000 loss:-93937.0112578125
2 10000 loss:-94071.78959375
2 11000 loss:-93960.5300546875
2 12000 loss:-94281.1640078125
2 13000 loss:-94328.97396875
2 14000 loss:-95808.581890625
2 15000 loss:-96019.923015625
training ended
====> Test set loss : -24022.2221


In [17]:
 with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_' + str(3) + '.png')

ImportError: cannot import name 'image' from 'PIL' (/home/172mc001/anaconda3/lib/python3.7/site-packages/PIL/__init__.py)