In [1]:
import torch
from torchvision import datasets, transforms

# Only conversion to tensors is applied to the raw MNIST images.
tensor_transform = transforms.Compose([transforms.ToTensor()])

# One batch size is shared by the training and the test loader.
batch_size = 512

# MNIST splits; both are downloaded into the same root if not already present.
train_dataset = datasets.MNIST(
    root="/home/zhh/data",
    train=True,
    download=True,
    transform=tensor_transform,
)
test_dataset = datasets.MNIST(
    root="/home/zhh/data",
    train=False,
    download=True,
    transform=tensor_transform,
)

# Training batches are reshuffled on every pass; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [2]:
class Avger(list):
    """A list of numbers whose string form is its arithmetic mean.

    ``str()`` yields the mean formatted to four decimals, or ``'N/A'`` while
    the list is still empty.
    """

    def __str__(self):
        # Guard first: an empty list has no mean (and would divide by zero).
        if not len(self) > 0:
            return 'N/A'
        mean = sum(self) / len(self)
        return f'{mean:.4f}'
    
from tqdm.notebook import tqdm,trange
import torchvision.utils as vutils
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class BatchPrepareBase:
    """Interface for objects that turn a raw batch into model input."""

    def process(self, x):
        """Transform batch ``x``; concrete preparers must override this.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

In [4]:
import os
import torchvision.utils as vutils
from tqdm import trange

class SamplerBase:
    """Base class for samplers that write model outputs to disk.

    Subclasses implement :meth:`calc`. :meth:`sample` then stores an image
    grid under ``samples/`` and the model weights under ``checkpoints/``.
    """

    def __init__(self):
        # Both output directories are relative to the current working directory.
        os.makedirs('samples', exist_ok=True)
        os.makedirs('checkpoints', exist_ok=True)

    def calc(self, model, num):
        """Produce ``num`` samples from ``model``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # Previously raised NotADirectoryError (an OSError subclass), which is
        # the wrong signal for an un-overridden abstract hook.
        raise NotImplementedError

    @torch.no_grad()
    def sample(self, model, num, desc):
        """Generate samples, save them as a grid image and checkpoint the model.

        Args:
            model: passed to :meth:`calc`; its ``state_dict()`` is saved.
            num: number of samples requested from :meth:`calc`.
            desc: file stem used for ``samples/{desc}.png`` and
                ``checkpoints/{desc}.pth``; an existing file with the same
                name is replaced.
        """
        x = self.calc(model, num)
        # Grid with 8 images per row.
        grid = vutils.make_grid(x, nrow=8)
        vutils.save_image(grid, f'samples/{desc}.png')
        torch.save(model.state_dict(), f'checkpoints/{desc}.pth')
        

In [5]:
class SanityVAE_BatchPrep(BatchPrepareBase):
    """Preparer that passes the batch through unchanged, only moving it to CUDA."""

    def process(self, x):
        """Return ``x`` transferred to the CUDA device."""
        on_gpu = x.cuda()
        return on_gpu
        
class NoiseVAE_BatchPrep(BatchPrepareBase):
    """Preparer that averages the batch with standard-normal noise."""

    def process(self, x):
        """Return ``(x + noise) / 2`` on the CUDA device.

        ``noise`` is drawn with ``torch.randn_like(x)`` before the transfer,
        i.e. on whatever device ``x`` currently lives on.
        """
        blended = (x + torch.randn_like(x)) / 2
        return blended.cuda()
        

In [6]:
class VAE_Sampler(SamplerBase):
    """Sampler that saves generations, reconstructions and the reference data.

    The model is expected to expose ``latent_dim``, ``encode`` and ``decode``
    (those are the only members used here).
    """

    def calc(self, model, num):
        """Decode ``num`` latent vectors drawn from a standard normal."""
        latent = torch.randn(num, model.latent_dim).cuda()
        return model.decode(latent)

    @torch.no_grad()
    def sample(self, model, num, desc):
        """Write three image grids and a checkpoint tagged with ``desc``.

        Files written: ``samples/{desc}_gen.png``, ``samples/{desc}_recon.png``,
        ``samples/{desc}_data.png`` and ``checkpoints/{desc}.pth``.
        """
        def save_grid(images, path):
            # Every grid uses 8 images per row.
            vutils.save_image(vutils.make_grid(images, nrow=8), path)

        # Reference images: the first `num` items of the first test batch.
        first_batch = next(iter(test_loader))
        x_data = first_batch[0][:num].cuda()

        # Generation from random latents.
        generated = self.calc(model, num)
        save_grid(generated, f'samples/{desc}_gen.png')

        # Reconstruction: latent is mu + eps * exp(logvar / 2).
        mu, logvar = model.encode(x_data)
        eps = torch.randn_like(mu)
        std = torch.exp(logvar / 2)
        x_recon = model.decode(mu + eps * std)
        save_grid(x_recon, f'samples/{desc}_recon.png')
        save_grid(x_data, f'samples/{desc}_data.png')

        torch.save(model.state_dict(), f'checkpoints/{desc}.pth')
        

In [7]:
from tqdm import tqdm
class VAE_trainer:
    """Runs the optimisation loop for a model and periodically samples from it.

    Args:
        model: callable module whose forward call returns ``(loss, rec, kl)``.
        epochs: number of passes over ``train_loader``.
        lr: learning rate handed to Adam.
        desc: tag embedded in the name passed to the sampler.
        preparer: converts each raw batch into the model input.
        sampler: invoked as ``sampler.sample(model, 64, name)`` on sampling epochs.
        sample_ep: sampling period in epochs; the first and the last epoch
            are always sampled as well.
    """

    def __init__(self, model, epochs, lr, desc,
                 preparer: BatchPrepareBase, sampler: SamplerBase, sample_ep=4):
        self.model = model
        self.desc = desc
        self.lr = lr
        self.epochs = epochs
        self.sample_ep = sample_ep
        self.preparer = preparer
        self.sampler = sampler

    def run(self):
        """Train for ``self.epochs`` epochs with a fixed-learning-rate Adam optimizer."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        for epoch in range(self.epochs):
            self.model.train()

            # Per-epoch running means shown in the progress bar.
            train_loss = Avger()
            rec_loss = Avger()
            kl_loss = Avger()

            with tqdm(train_loader) as bar:
                for x, _ in bar:
                    loss, rec, kl = self.model(self.preparer.process(x))

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    for tracker, value in ((train_loss, loss), (rec_loss, rec), (kl_loss, kl)):
                        tracker.append(value.item())
                    bar.set_description(f'Epoch {epoch+1}/{self.epochs}, loss={train_loss}, rec={rec_loss}, kl={kl_loss}')

            # Sample after the first epoch, every `sample_ep` epochs, and at the end.
            is_first = epoch == 0
            is_periodic = (epoch + 1) % self.sample_ep == 0
            is_last = epoch == self.epochs - 1
            if not (is_first or is_periodic or is_last):
                continue

            self.model.eval()
            self.sampler.sample(self.model, 64, f'ep{epoch+1}_{self.desc}')
            print(f'Epoch {epoch+1}, sample saved')

# VAE sanity

In [8]:
from vae import VAE

In [9]:
# Sanity run: train on clean inputs to check that the pipeline learns.
model = VAE().cuda()
print('number of parameters:', sum(param.numel() for param in model.parameters()))
trainer = VAE_trainer(
    model,
    epochs=50,
    lr=1e-3,
    desc='sanity',
    preparer=SanityVAE_BatchPrep(),
    sampler=VAE_Sampler(),
)
trainer.run()

number of parameters: 944600


Epoch 1/50, loss=53.3275, rec=52.7250, kl=0.6025: 100%|██████████| 118/118 [00:01<00:00, 66.77it/s]


Epoch 1, sample saved


Epoch 2/50, loss=45.7120, rec=42.9938, kl=2.7182: 100%|██████████| 118/118 [00:01<00:00, 67.33it/s]
Epoch 3/50, loss=38.9316, rec=33.7055, kl=5.2261: 100%|██████████| 118/118 [00:01<00:00, 61.87it/s]
Epoch 4/50, loss=36.3023, rec=30.3727, kl=5.9295: 100%|██████████| 118/118 [00:01<00:00, 69.59it/s]


Epoch 4, sample saved


Epoch 5/50, loss=35.1426, rec=28.9163, kl=6.2263: 100%|██████████| 118/118 [00:01<00:00, 71.20it/s]
Epoch 6/50, loss=34.4248, rec=28.0152, kl=6.4096: 100%|██████████| 118/118 [00:01<00:00, 71.74it/s]
Epoch 7/50, loss=33.9588, rec=27.4305, kl=6.5283: 100%|██████████| 118/118 [00:01<00:00, 63.06it/s]
Epoch 8/50, loss=33.6262, rec=26.9684, kl=6.6578: 100%|██████████| 118/118 [00:01<00:00, 64.67it/s]


Epoch 8, sample saved


Epoch 9/50, loss=33.2716, rec=26.5151, kl=6.7565: 100%|██████████| 118/118 [00:01<00:00, 61.29it/s]
Epoch 10/50, loss=32.9813, rec=26.1608, kl=6.8205: 100%|██████████| 118/118 [00:01<00:00, 63.67it/s]
Epoch 11/50, loss=32.7866, rec=25.8793, kl=6.9073: 100%|██████████| 118/118 [00:01<00:00, 68.03it/s]
Epoch 12/50, loss=32.5499, rec=25.6144, kl=6.9355: 100%|██████████| 118/118 [00:01<00:00, 72.41it/s]


Epoch 12, sample saved


Epoch 13/50, loss=32.4562, rec=25.4358, kl=7.0204: 100%|██████████| 118/118 [00:01<00:00, 72.36it/s]
Epoch 14/50, loss=32.2337, rec=25.1889, kl=7.0448: 100%|██████████| 118/118 [00:01<00:00, 62.11it/s]
Epoch 15/50, loss=32.0522, rec=24.9594, kl=7.0928: 100%|██████████| 118/118 [00:01<00:00, 60.35it/s]
Epoch 16/50, loss=31.9188, rec=24.7761, kl=7.1426: 100%|██████████| 118/118 [00:01<00:00, 68.56it/s]


Epoch 16, sample saved


Epoch 17/50, loss=31.8146, rec=24.6668, kl=7.1478: 100%|██████████| 118/118 [00:01<00:00, 72.61it/s]
Epoch 18/50, loss=31.6990, rec=24.4893, kl=7.2097: 100%|██████████| 118/118 [00:01<00:00, 73.41it/s]
Epoch 19/50, loss=31.6114, rec=24.3921, kl=7.2194: 100%|██████████| 118/118 [00:01<00:00, 73.47it/s]
Epoch 20/50, loss=31.5191, rec=24.2740, kl=7.2451: 100%|██████████| 118/118 [00:01<00:00, 63.43it/s]


Epoch 20, sample saved


Epoch 21/50, loss=31.4328, rec=24.1562, kl=7.2766: 100%|██████████| 118/118 [00:01<00:00, 64.17it/s]
Epoch 22/50, loss=31.3414, rec=24.0494, kl=7.2920: 100%|██████████| 118/118 [00:01<00:00, 67.27it/s]
Epoch 23/50, loss=31.2361, rec=23.9165, kl=7.3195: 100%|██████████| 118/118 [00:01<00:00, 65.89it/s]
Epoch 24/50, loss=31.1812, rec=23.8335, kl=7.3477: 100%|██████████| 118/118 [00:01<00:00, 68.93it/s]


Epoch 24, sample saved


Epoch 25/50, loss=31.1253, rec=23.7612, kl=7.3642: 100%|██████████| 118/118 [00:01<00:00, 66.42it/s]
Epoch 26/50, loss=31.0521, rec=23.6683, kl=7.3838: 100%|██████████| 118/118 [00:02<00:00, 56.73it/s]
Epoch 27/50, loss=31.0035, rec=23.5924, kl=7.4111: 100%|██████████| 118/118 [00:01<00:00, 64.81it/s]
Epoch 28/50, loss=30.9243, rec=23.5179, kl=7.4064: 100%|██████████| 118/118 [00:01<00:00, 64.76it/s]


Epoch 28, sample saved


Epoch 29/50, loss=30.8713, rec=23.4367, kl=7.4347: 100%|██████████| 118/118 [00:01<00:00, 69.52it/s]
Epoch 30/50, loss=30.8196, rec=23.3655, kl=7.4541: 100%|██████████| 118/118 [00:01<00:00, 67.68it/s]
Epoch 31/50, loss=30.7958, rec=23.3155, kl=7.4803: 100%|██████████| 118/118 [00:01<00:00, 64.86it/s]
Epoch 32/50, loss=30.7159, rec=23.2279, kl=7.4880: 100%|██████████| 118/118 [00:01<00:00, 62.18it/s]


Epoch 32, sample saved


Epoch 33/50, loss=30.6609, rec=23.1566, kl=7.5043: 100%|██████████| 118/118 [00:01<00:00, 68.79it/s]
Epoch 34/50, loss=30.6354, rec=23.1090, kl=7.5264: 100%|██████████| 118/118 [00:01<00:00, 73.68it/s]
Epoch 35/50, loss=30.5866, rec=23.0623, kl=7.5243: 100%|██████████| 118/118 [00:01<00:00, 73.79it/s]
Epoch 36/50, loss=30.5626, rec=23.0152, kl=7.5474: 100%|██████████| 118/118 [00:01<00:00, 62.85it/s]


Epoch 36, sample saved


Epoch 37/50, loss=30.5194, rec=22.9433, kl=7.5762: 100%|██████████| 118/118 [00:01<00:00, 60.25it/s]
Epoch 38/50, loss=30.4751, rec=22.8952, kl=7.5799: 100%|██████████| 118/118 [00:01<00:00, 60.12it/s]
Epoch 39/50, loss=30.4096, rec=22.8181, kl=7.5915: 100%|██████████| 118/118 [00:01<00:00, 71.44it/s]
Epoch 40/50, loss=30.4115, rec=22.8034, kl=7.6080: 100%|██████████| 118/118 [00:01<00:00, 73.09it/s]


Epoch 40, sample saved


Epoch 41/50, loss=30.3377, rec=22.7280, kl=7.6098: 100%|██████████| 118/118 [00:01<00:00, 69.59it/s]
Epoch 42/50, loss=30.3521, rec=22.7375, kl=7.6146: 100%|██████████| 118/118 [00:01<00:00, 72.27it/s]
Epoch 43/50, loss=30.2775, rec=22.6604, kl=7.6171: 100%|██████████| 118/118 [00:01<00:00, 63.90it/s]
Epoch 44/50, loss=30.2645, rec=22.6219, kl=7.6426: 100%|██████████| 118/118 [00:02<00:00, 58.31it/s]


Epoch 44, sample saved


Epoch 45/50, loss=30.2205, rec=22.5707, kl=7.6498: 100%|██████████| 118/118 [00:01<00:00, 71.93it/s]
Epoch 46/50, loss=30.2165, rec=22.5380, kl=7.6784: 100%|██████████| 118/118 [00:01<00:00, 68.97it/s]
Epoch 47/50, loss=30.1951, rec=22.5050, kl=7.6901: 100%|██████████| 118/118 [00:01<00:00, 73.72it/s]
Epoch 48/50, loss=30.1453, rec=22.4502, kl=7.6951: 100%|██████████| 118/118 [00:01<00:00, 73.10it/s]


Epoch 48, sample saved


Epoch 49/50, loss=30.1244, rec=22.4212, kl=7.7032: 100%|██████████| 118/118 [00:01<00:00, 59.05it/s]
Epoch 50/50, loss=30.0479, rec=22.3436, kl=7.7044: 100%|██████████| 118/118 [00:01<00:00, 64.30it/s]


Epoch 50, sample saved


# VAE task

In [10]:
# Noisy-input run: a fresh model trained on batches averaged with Gaussian noise.
model = VAE().cuda()
print('number of parameters:', sum(p.numel() for p in model.parameters()))
# Use a distinct tag. The sampler names its images and checkpoints after `desc`,
# and 'sanity' is already used by the clean-input run, so reusing it here
# would overwrite that run's samples/*.png and checkpoints/*.pth files.
trainer = VAE_trainer(model, 50, 1e-3, 'noise', NoiseVAE_BatchPrep(), VAE_Sampler())
trainer.run()

number of parameters: 944600


Epoch 1/50, loss=209.6225, rec=209.5991, kl=0.0234: 100%|██████████| 118/118 [00:02<00:00, 41.22it/s]


Epoch 1, sample saved


Epoch 2/50, loss=209.0398, rec=208.6465, kl=0.3933: 100%|██████████| 118/118 [00:02<00:00, 44.75it/s]
Epoch 3/50, loss=208.7570, rec=208.1151, kl=0.6419: 100%|██████████| 118/118 [00:02<00:00, 50.63it/s]
Epoch 4/50, loss=208.5560, rec=207.8448, kl=0.7112: 100%|██████████| 118/118 [00:02<00:00, 42.76it/s]


Epoch 4, sample saved


Epoch 5/50, loss=208.5300, rec=207.7790, kl=0.7510: 100%|██████████| 118/118 [00:02<00:00, 40.09it/s]
Epoch 6/50, loss=208.4826, rec=207.7016, kl=0.7811: 100%|██████████| 118/118 [00:02<00:00, 57.54it/s]
Epoch 7/50, loss=208.4211, rec=207.5853, kl=0.8357: 100%|██████████| 118/118 [00:02<00:00, 51.29it/s]
Epoch 8/50, loss=208.3985, rec=207.4244, kl=0.9741: 100%|██████████| 118/118 [00:02<00:00, 48.26it/s]


Epoch 8, sample saved


Epoch 9/50, loss=208.3156, rec=207.2000, kl=1.1156: 100%|██████████| 118/118 [00:03<00:00, 32.26it/s]
Epoch 10/50, loss=208.2786, rec=207.0650, kl=1.2136: 100%|██████████| 118/118 [00:02<00:00, 52.98it/s]
Epoch 11/50, loss=208.2128, rec=206.8599, kl=1.3529: 100%|██████████| 118/118 [00:02<00:00, 44.83it/s]
Epoch 12/50, loss=208.0745, rec=206.6191, kl=1.4554: 100%|██████████| 118/118 [00:03<00:00, 36.51it/s]


Epoch 12, sample saved


Epoch 13/50, loss=208.0466, rec=206.4950, kl=1.5516: 100%|██████████| 118/118 [00:03<00:00, 34.25it/s]
Epoch 14/50, loss=207.9935, rec=206.3769, kl=1.6166: 100%|██████████| 118/118 [00:02<00:00, 50.20it/s]
Epoch 15/50, loss=207.9295, rec=206.2373, kl=1.6922: 100%|██████████| 118/118 [00:02<00:00, 47.67it/s]
Epoch 16/50, loss=207.9223, rec=206.1673, kl=1.7550: 100%|██████████| 118/118 [00:02<00:00, 42.54it/s]


Epoch 16, sample saved


Epoch 17/50, loss=207.9150, rec=206.1114, kl=1.8036: 100%|██████████| 118/118 [00:02<00:00, 47.87it/s]
Epoch 18/50, loss=207.8435, rec=206.0035, kl=1.8400: 100%|██████████| 118/118 [00:02<00:00, 55.12it/s]
Epoch 19/50, loss=207.8539, rec=205.9775, kl=1.8764: 100%|██████████| 118/118 [00:02<00:00, 56.02it/s]
Epoch 20/50, loss=207.7860, rec=205.8829, kl=1.9031: 100%|██████████| 118/118 [00:03<00:00, 36.28it/s]


Epoch 20, sample saved


Epoch 21/50, loss=207.8356, rec=205.9010, kl=1.9346: 100%|██████████| 118/118 [00:02<00:00, 41.19it/s]
Epoch 22/50, loss=207.6393, rec=205.6740, kl=1.9654: 100%|██████████| 118/118 [00:02<00:00, 49.36it/s]
Epoch 23/50, loss=207.7962, rec=205.7997, kl=1.9965: 100%|██████████| 118/118 [00:02<00:00, 52.54it/s]
Epoch 24/50, loss=207.7019, rec=205.6946, kl=2.0074: 100%|██████████| 118/118 [00:02<00:00, 46.59it/s]


Epoch 24, sample saved


Epoch 25/50, loss=207.8404, rec=205.8126, kl=2.0278: 100%|██████████| 118/118 [00:02<00:00, 44.25it/s]
Epoch 26/50, loss=207.7140, rec=205.6813, kl=2.0328: 100%|██████████| 118/118 [00:02<00:00, 45.02it/s]
Epoch 27/50, loss=207.7055, rec=205.6447, kl=2.0608: 100%|██████████| 118/118 [00:02<00:00, 53.27it/s]
Epoch 28/50, loss=207.6705, rec=205.5936, kl=2.0769: 100%|██████████| 118/118 [00:03<00:00, 31.75it/s]


Epoch 28, sample saved


Epoch 29/50, loss=207.7363, rec=205.6599, kl=2.0763: 100%|██████████| 118/118 [00:02<00:00, 50.99it/s]
Epoch 30/50, loss=207.7076, rec=205.5972, kl=2.1104: 100%|██████████| 118/118 [00:02<00:00, 39.95it/s]
Epoch 31/50, loss=207.6804, rec=205.5676, kl=2.1128: 100%|██████████| 118/118 [00:02<00:00, 40.56it/s]
Epoch 32/50, loss=207.6335, rec=205.5072, kl=2.1263: 100%|██████████| 118/118 [00:03<00:00, 36.77it/s]


Epoch 32, sample saved


Epoch 33/50, loss=207.7228, rec=205.5881, kl=2.1347: 100%|██████████| 118/118 [00:02<00:00, 46.24it/s]
Epoch 34/50, loss=207.5931, rec=205.4416, kl=2.1514: 100%|██████████| 118/118 [00:02<00:00, 43.79it/s]
Epoch 35/50, loss=207.6117, rec=205.4625, kl=2.1492: 100%|██████████| 118/118 [00:03<00:00, 33.39it/s]
Epoch 36/50, loss=207.6516, rec=205.4860, kl=2.1657: 100%|██████████| 118/118 [00:03<00:00, 33.43it/s]


Epoch 36, sample saved


Epoch 37/50, loss=207.6021, rec=205.4142, kl=2.1879: 100%|██████████| 118/118 [00:02<00:00, 47.25it/s]
Epoch 38/50, loss=207.6304, rec=205.4571, kl=2.1733: 100%|██████████| 118/118 [00:03<00:00, 38.99it/s]
Epoch 39/50, loss=207.6239, rec=205.4347, kl=2.1891: 100%|██████████| 118/118 [00:03<00:00, 33.54it/s]
Epoch 40/50, loss=207.6041, rec=205.4005, kl=2.2037: 100%|██████████| 118/118 [00:02<00:00, 57.32it/s]


Epoch 40, sample saved


Epoch 41/50, loss=207.5682, rec=205.3688, kl=2.1994: 100%|██████████| 118/118 [00:02<00:00, 52.02it/s]
Epoch 42/50, loss=207.5372, rec=205.3254, kl=2.2118: 100%|██████████| 118/118 [00:03<00:00, 38.16it/s]
Epoch 43/50, loss=207.5547, rec=205.3333, kl=2.2213: 100%|██████████| 118/118 [00:02<00:00, 43.07it/s]
Epoch 44/50, loss=207.5840, rec=205.3495, kl=2.2345: 100%|██████████| 118/118 [00:02<00:00, 53.37it/s]


Epoch 44, sample saved


Epoch 45/50, loss=207.5945, rec=205.3416, kl=2.2529: 100%|██████████| 118/118 [00:02<00:00, 45.01it/s]
Epoch 46/50, loss=207.6351, rec=205.3824, kl=2.2527: 100%|██████████| 118/118 [00:02<00:00, 47.60it/s]
Epoch 47/50, loss=207.5696, rec=205.3127, kl=2.2569: 100%|██████████| 118/118 [00:02<00:00, 41.47it/s]
Epoch 48/50, loss=207.5538, rec=205.2753, kl=2.2785: 100%|██████████| 118/118 [00:02<00:00, 51.19it/s]


Epoch 48, sample saved


Epoch 49/50, loss=207.6184, rec=205.3265, kl=2.2919: 100%|██████████| 118/118 [00:02<00:00, 52.04it/s]
Epoch 50/50, loss=207.4788, rec=205.1935, kl=2.2852: 100%|██████████| 118/118 [00:02<00:00, 46.93it/s]


Epoch 50, sample saved
