In [1]:
from abc import ABC, abstractmethod
import torch
from torch import utils
from torch import nn
from torch import distributions
from torch import optim
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToPILImage

In [3]:
data_dir='./data/'

In [4]:
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_built()
if use_cuda:
    device = torch.device('cuda')
elif use_mps:
    device = torch.device('cpu')
else:
    device = torch.device('cpu')
cpu = torch.device('cpu')

In [5]:
default_batch_size = 256
loader_args = {'batch_size' : default_batch_size, 'shuffle' : True}
score_args = {'batch_size' : default_batch_size, 'shuffle' : False}
if use_cuda:
    loader_args.update({'pin_memory' : True})
    score_args.update({'pin_memory' : True})

In [6]:
class Reporter(ABC):
    @abstractmethod
    def report(self, typ, **metric):
        pass
    @abstractmethod
    def reset(self):
        pass

In [7]:
class SReporter(Reporter):
    def __init__(self):
        self.log = []
    def report(self, typ, **data):
        self.log.append((typ, data))
    def reset(self):
        self.log.clear()
    def loss(self, t):
        losses = []
        for (typ, data) in self.log:
            if typ == t:
                losses.append(data['loss'])
        return losses
    def loss(self, t, idx):
        if idx >= 0:
            count = 0
            for (typ, data) in self.log:
                if typ == t:
                    if count == idx:
                        return data['loss']
                    count += 1
        else:
            count = -1
            for (typ, data) in reversed(self.log):
                if typ == t:
                    if count == idx:
                        return data['loss']
                    count -= 1
        return float("inf")
    def eval_loss(self):
        return self.loss('eval')
    def train_loss(self):
        return self.loss('train')
    def eval_loss(self, idx):
        return self.loss('eval', idx)
    def train_loss(self, idx):
        return self.loss('train', idx)
    def get_record(self, t, idx):
        if idx >= 0:
            count = 0
            for (typ, data) in self.log:
                if typ == t:
                    if count == idx:
                        return data
                    count += 1
        else:
            count = -1
            for (typ, data) in reversed(self.log):
                if typ == t:
                    if count == idx:
                        return data
                    count -= 1
        return dict()
    def eval_record(self, idx):
        return self.get_record('eval', idx)
    def train_record(self, idx):
        return self.get_record('train', idx)

In [8]:
class VAELoss(nn.Module):
    def __init__(self):
        super(VAELoss, self).__init__()
    
    def forward(self, pred, target, mu, sig):
        recon_loss = ((target - pred)**2.).sum()
        dkl_loss = (sig**2. + mu**2. - torch.log(sig) - 0.5).sum()
        return recon_loss + dkl_loss

In [9]:
def relu_activation():
    return nn.ReLU(inplace=True)

In [10]:
def downsampling2DV2(in_c, out_c, stride, norm_layer):
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, 1, stride=stride),
        norm_layer(out_c),
    )

In [11]:
def upsampling2DV1(in_c, out_c, stride, norm_layer):
    return nn.Sequential(
        nn.ConvTranspose2d(in_c, out_c, 2, stride=stride),
        norm_layer(out_c),
    )

In [12]:
class ResidualLayer2DV4(nn.Module):
    def __init__(self, in_c, out_c, ksz, act_layer, norm_layer, stride=1):
        super(ResidualLayer2DV4, self).__init__()
        if in_c <= out_c:
            self.c1 = nn.Conv2d(in_c, out_c, ksz, stride=stride, padding=int((ksz-1)/2))
            self.c2 = nn.Conv2d(out_c, out_c, ksz, stride=1, padding=int((ksz-1)/2))
        else:
            self.c1 = nn.ConvTranspose2d(in_c, out_c, ksz+1, stride=stride, padding=int((ksz-1)/2))
            self.c2 = nn.ConvTranspose2d(out_c, out_c, ksz, stride=1, padding=int((ksz-1)/2))
        self.a1 = act_layer()
        self.a2 = act_layer()
        self.b1 = norm_layer(in_c)
        self.b2 = norm_layer(out_c)
        
        if in_c < out_c:
            self.residual = downsampling2DV2(in_c, out_c, stride, norm_layer)
        elif in_c > out_c:
            self.residual = upsampling2DV1(in_c, out_c, stride, norm_layer)
        elif stride > 1:
            self.residual = downsampling2DV2(in_c, out_c, stride, norm_layer)
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        s = x
        x = self.b1(x)
        x = self.a1(x)
        x = self.c1(x)
        x = self.b2(x)
        x = self.a2(x)
        x = self.c2(x)
        s = self.residual(s)
        x = x + s
        return x

In [13]:
class ConvVariationalEncoderV1(nn.Module):
    def __init__(self, ic, chmuls, hmul):
        super(ConvVariationalEncoderV1, self).__init__()
        layer1 = []
        outmul = 1
        for mul in chmuls:
            layer1.append(nn.Conv2d(ic*outmul, ic*mul, (3,3), (2,2), (1,1)))
            layer1.append(nn.ReLU(inplace=True))
            outmul = mul
        self.layer1 = nn.ModuleList(layer1)
        self.mu_layer = nn.Conv2d(ic*outmul, ic*hmul, (3,3), (2,2), (1,1))
        self.sig_layer = nn.Sequential(
            nn.Conv2d(ic*outmul, ic*hmul, (3,3), (2,2), (1,1)),
            nn.Softplus(threshold=6),
        )
    
    def forward(self, x):
        for layer in self.layer1:
            x = layer(x)
        mu = self.mu_layer(x)
        sig = self.sig_layer(x)
        return (mu, sig)

In [14]:
class ConvVariationalEncoderV2(nn.Module):
    def __init__(self, ic, chmuls, hmul):
        super(ConvVariationalEncoderV2, self).__init__()
        layer1 = []
        outmul = 1
        for mul in chmuls:
            layer1.append(ResidualLayer2DV4(ic*outmul, ic*mul, 3, relu_activation, nn.BatchNorm2d, stride=2))
            outmul = mul
        self.layer1 = nn.ModuleList(layer1)
        self.mu_layer = nn.Conv2d(ic*outmul, ic*hmul, (3,3), (2,2), (1,1))
        self.sig_layer = nn.Sequential(
            nn.Conv2d(ic*outmul, ic*hmul, (3,3), (2,2), (1,1)),
            nn.Softplus(threshold=6),
        )
    
    def forward(self, x):
        for layer in self.layer1:
            x = layer(x)
        mu = self.mu_layer(x)
        sig = self.sig_layer(x)
        return (mu, sig)

In [15]:
class ConvDecoderV1(nn.Module):
    def __init__(self, ic, chmuls, hmul):
        super(ConvDecoderV1, self).__init__()
        layers = []
        outmul = hmul
        for mul in reversed(chmuls):
            layers.append(nn.ConvTranspose2d(ic*outmul, ic*mul, (4,4), (2,2), (1,1)))
            layers.append(nn.ReLU(inplace=True))
            outmul = mul
        layers.append(nn.ConvTranspose2d(ic*outmul, ic, (4,4), (2,2), (1,1)))
        self.layers = nn.ModuleList(layers)
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [16]:
class ConvDecoderV2(nn.Module):
    def __init__(self, ic, chmuls, hmul):
        super(ConvDecoderV2, self).__init__()
        layers = []
        outmul = hmul
        for mul in reversed(chmuls):
            layers.append(ResidualLayer2DV4(ic*outmul, ic*mul, 3, relu_activation, nn.BatchNorm2d, stride=2))
            outmul = mul
        layers.append(ResidualLayer2DV4(ic*outmul, ic, 3, relu_activation, nn.BatchNorm2d, stride=2))
        self.layers = nn.ModuleList(layers)
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [17]:
class ConvVariationalAutoEncoderV1(nn.Module):
    def __init__(self, ic, chmuls, hmul, dist):
        super(ConvVariationalAutoEncoderV1, self).__init__()
        self.encoder = ConvVariationalEncoderV1(ic, chmuls, hmul)
        self.decoder = ConvDecoderV1(ic, chmuls, hmul)
        self.dist = dist
    
    def forward(self, x, device):
        mu, sig = self.encoder(x)
        s = self.dist.sample(sig.shape).to(device)
        z = mu + sig * s
        x_h = self.decoder(z)
        return (x_h, mu, sig)

    def encode(self, x):
        mu, sig = self.encoder(x)
        return (mu, sig)

    def decode(self, mu, sig, device):
        s = self.dist.sample(mu.shape).to(device)
        z = mu + sig * s
        x_h = self.decoder(z)
        return x_h


In [18]:
class ConvVariationalAutoEncoderV2(nn.Module):
    def __init__(self, ic, chmuls, hmul, dist):
        super(ConvVariationalAutoEncoderV2, self).__init__()
        self.encoder = ConvVariationalEncoderV2(ic, chmuls, hmul)
        self.decoder = ConvDecoderV2(ic, chmuls, hmul)
        self.dist = dist
    
    def forward(self, x, device):
        mu, sig = self.encoder(x)
        s = self.dist.sample(sig.shape).to(device)
        z = mu + sig * s
        x_h = self.decoder(z)
        return (x_h, mu, sig)

    def encode(self, x):
        mu, sig = self.encoder(x)
        return (mu, sig)

    def decode(self, mu, sig, device):
        s = self.dist.sample(mu.shape).to(device)
        z = mu + sig * s
        x_h = self.decoder(z)
        return x_h

In [19]:
def vae_image_train(model, device, loader, optimizer, loss, epoch, reporter):
    model.train()
    total_loss = 0.
    for x, _ in loader:
        optimizer.zero_grad()
        x = x.to(device)
        x_h, mu, sig = model(x, device)
        l = loss(x_h, x, mu, sig)
        l.backward()
        optimizer.step()
        total_loss += l.item()
    total_loss /= float(len(loader.dataset))
    reporter.report(typ='train', loss=total_loss)
    print(f"Train Loss: {total_loss}")

In [20]:
def vae_image_validate(model, device, loader, loss, train_epoch, reporter):
    model.eval()
    total_loss = 0.
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            x_h, mu, sig = model(x, device)
            total_loss += loss(x_h, x, mu, sig)
    total_loss /= float(len(loader.dataset))
    reporter.report(typ='eval', loss=total_loss)

In [21]:
def vae_image_train_validate(
        model,
        device,
        train_loader,
        eval_loader,
        optimizer,
        scheduler,
        loss,
        total_epoch,
        patience,
        patience_decay,
        reporter,
):
    validation_loss = float("inf")
    patience_count = patience
    patience = int(patience * patience_decay)
    reset_patience = False
    for epoch in range(total_epoch):
        vae_image_train(model, device, train_loader, optimizer, loss, epoch, reporter)
        vae_image_validate(model, device, eval_loader, loss, epoch, reporter)
        new_validation_loss = reporter.eval_loss(-1)
        print(f"Epoch {epoch} VLoss: {new_validation_loss}")
        scheduler.step(new_validation_loss)
        if new_validation_loss < validation_loss:
            validation_loss = new_validation_loss
            patience_count = patience
            if reset_patience:
                patience = int(patience * patience_decay)
                reset_patience = False
        else:
            validation_loss = new_validation_loss
            patience_count -= 1
            reset_patience = True
            if patience_count <= 0:
                print(f"Improvement stopped. VLoss: {validation_loss}")
                break

In [22]:
### datasets

In [None]:
trainset = datasets.MNIST(root=data_dir, train=True, transform=transforms.ToTensor(), download=True)
evalset  = datasets.MNIST(root=data_dir, train=False, transform=transforms.ToTensor(), download=True)

In [23]:
trainset = datasets.CIFAR10(root=data_dir, train=True, transform=transforms.ToTensor(), download=True)
evalset = datasets.CIFAR10(root=data_dir, train=True, transform=transforms.ToTensor(), download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data/
Files already downloaded and verified


In [None]:
trainset = datasets.CIFAR100(root=data_dir, train=True, transform=transforms.ToTensor(), download=True)
evalset = datasets.CIFAR100(root=data_dir, train=False, transform=transforms.ToTensor(), download=True)

In [24]:
len(trainset), len(evalset)

(50000, 50000)

In [25]:
train_loader = utils.data.DataLoader(dataset=trainset, **loader_args)
eval_loader = utils.data.DataLoader(dataset=evalset, **score_args)

In [26]:
### training

In [27]:
norm_dist = distributions.Normal(0, 1)
in_channel = trainset[0][0].shape[0]
model = ConvVariationalAutoEncoderV2(in_channel, [in_channel*2], in_channel*4, norm_dist)
model = model.to(device)

In [28]:
learning_rate = 0.0001
total_epochs = 60
patience = 8
patience_decay = 0.9
optimizer = optim.Adam(model.parameters(recurse=True), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=patience/4, threshold=0.01)
loss = VAELoss()
reporter = SReporter()

In [None]:
vae_image_train_validate(model, device, train_loader, eval_loader, optimizer, scheduler, loss, total_epochs, patience, patience_decay, reporter)

Train Loss: 3939.8987575
Epoch 0 VLoss: 2688.265869140625
Train Loss: 2102.2113778125
Epoch 1 VLoss: 1743.615966796875
Train Loss: 1507.27000078125
Epoch 2 VLoss: 1343.02587890625
Train Loss: 1238.6249925
Epoch 3 VLoss: 1157.8890380859375
Train Loss: 1103.719003125
Epoch 4 VLoss: 1063.71533203125
Train Loss: 1034.91108
Epoch 5 VLoss: 1013.8659057617188
Train Loss: 999.460096875
Epoch 6 VLoss: 988.7303466796875
Train Loss: 979.3172359375
Epoch 7 VLoss: 972.2191162109375
Train Loss: 966.49983859375
Epoch 8 VLoss: 960.8378295898438
Train Loss: 957.39248375
Epoch 9 VLoss: 953.56982421875
Train Loss: 950.5582271875
Epoch 10 VLoss: 947.914794921875
Train Loss: 945.2179196875
Epoch 11 VLoss: 943.3482055664062
Train Loss: 940.7150909375
Epoch 12 VLoss: 938.4014282226562
Train Loss: 937.19405875
Epoch 13 VLoss: 936.0698852539062
Train Loss: 934.218559375
Epoch 14 VLoss: 932.833984375
Train Loss: 931.66979859375
Epoch 15 VLoss: 930.550048828125
Train Loss: 930.4141521875
Epoch 16 VLoss: 929.8525