In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision as vision
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

In [None]:
# Notebook-wide hyperparameters.
z_dim = 60            # dimensionality of the latent code
learning_rate = 5e-4  # step size handed to the optimizer
batch_size = 16       # samples per mini-batch for the data loaders

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class q_zx(nn.Module):
    """Convolutional encoder producing the parameters of q(z | x).

    The input passes through a stack of conv -> batch-norm -> ReLU blocks
    (3x3 kernels, 'same' padding, so the spatial size is unchanged), is
    flattened, and is projected by two independent linear heads to the mean
    and log-variance of the latent code.

    Args:
        z_dims: size of the latent code produced by each head.
        in_channels: number of channels of the input image.
        channels: output width of each convolutional block, in order; one
            block is built per entry.
        image_size: height/width of the (square) input; used to size the
            linear heads.
    """

    def __init__(self, z_dims, in_channels=1, channels=(32, 64, 64), image_size=64):
        super(q_zx, self).__init__()
        modules = []
        prev_channels = in_channels
        for width in channels:
            # _convblock returns a one-element list, hence extend().
            modules.extend(
                self._convblock(prev_channels, width, 3, padding='same', bias=True)
            )
            prev_channels = width
        self.model = nn.Sequential(*modules)

        # 'same' padding keeps image_size x image_size, so the flattened
        # feature count is last-block-width * H * W.
        flat_features = prev_channels * image_size * image_size
        self.mu_fc = nn.Linear(flat_features, z_dims)
        self.logvar_fc = nn.Linear(flat_features, z_dims)
        self.flatten = nn.Flatten()

    def _convblock(self, in_channels, out_channels, kernel_size, padding, bias):
        """Build one conv -> batch-norm -> ReLU block.

        Returns:
            A one-element list holding the block as an ``nn.Sequential``.
        """
        block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      padding=padding,
                      bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        return [block]

    def forward(self, x):
        """Encode ``x`` and return ``(mu_z, logvar_z)``, each (batch, z_dims)."""
        x = self.model(x)
        x = self.flatten(x)
        mu_z = self.mu_fc(x)
        logvar_z = self.logvar_fc(x)
        return mu_z, logvar_z
    
class p_xz(nn.Module):
    """Convolutional decoder mapping a latent code to an image.

    A linear layer expands the latent vector to a 48 x 64 x 64 feature map,
    three conv -> batch-norm -> ReLU blocks refine it, and a final bias-free
    3x3 convolution followed by a ReLU produces the output image.

    Args:
        z_dims: size of the incoming latent code.
        out_channels: number of channels of the produced image.
        channels: output widths of the three convolutional blocks.
        batch_size: stored on the instance; not read anywhere in this class.
    """

    def __init__(self, z_dims, out_channels=1, channels=[48, 90, 90], batch_size=1):
        super(p_xz, self).__init__()
        self.batch_size = batch_size
        self.fc = nn.Linear(z_dims, 64*64*48)
        self.relu_fc = nn.ReLU()

        first_width, second_width, third_width = channels[0], channels[1], channels[2]
        # Each cN attribute is the one-element list returned by _convblock;
        # the blocks themselves are registered through self.c below.
        self.c1 = self._convblock(48, first_width, kernel_size=3, padding='same', bias=True)
        self.c2 = self._convblock(first_width, second_width, kernel_size=3, padding='same', bias=True)
        self.c3 = self._convblock(second_width, third_width, kernel_size=3, padding='same', bias=True)
        self.c = nn.Sequential(*self.c1, *self.c2, *self.c3)

        self.mu_conv = nn.Conv2d(third_width, out_channels, 3, padding='same', bias=False)

    def _convblock(self, in_channels, out_channels, kernel_size, padding, bias):
        """Return a one-element list holding a conv -> batch-norm -> ReLU block."""
        convolution = nn.Conv2d(in_channels, out_channels, kernel_size,
                                padding=padding, bias=bias)
        block = nn.Sequential(convolution, nn.BatchNorm2d(out_channels), nn.ReLU())
        return [block]

    def forward(self, x):
        """Decode latent codes ``x`` into images of shape (batch, out_channels, 64, 64)."""
        hidden = self.relu_fc(self.fc(x))
        # Reinterpret the flat activations as a 48-channel 64x64 feature map.
        feature_map = hidden.view(-1, 48, 64, 64)
        image = self.mu_conv(self.c(feature_map))
        # NOTE(review): the final ReLU clamps the output to >= 0; if targets are
        # normalised to [-1, 1] (the notebook's transform uses Normalize(0.5, 0.5))
        # negative pixel values cannot be reconstructed -- confirm this is intended.
        return self.relu_fc(image)
    

class VAE(nn.Module):
    """Variational autoencoder built from the ``q_zx`` encoder and ``p_xz`` decoder.

    Args:
        z_dims: size of the latent code shared by encoder and decoder.
    """

    def __init__(self, z_dims=60):
        super(VAE, self).__init__()
        self.encoder = q_zx(z_dims=z_dims)
        # `batch_size` is the notebook-level global defined in the config cell.
        self.decoder = p_xz(z_dims=z_dims, batch_size=batch_size)

    def encode(self, x):
        """Return the encoder outputs ``(mu_z, logvar_z)``, each (batch, z_dims)."""
        mu_z, logvar_z = self.encoder(x)
        return mu_z, logvar_z

    def decode(self, x):
        """Return the decoder output ``mu_x`` for latent codes ``x``."""
        mu_x = self.decoder(x)
        return mu_x

    def sample(self, mu_z, logvar_z):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation)."""
        std = torch.exp(0.5*logvar_z)
        eps = torch.randn_like(std)
        return mu_z + eps*std

    def loss_function(self, mu_x, x, mu_z, logvar_z):
        """Return the training loss: summed reconstruction error plus KL term.

        Args:
            mu_x: reconstruction produced by the decoder.
            x: target images, same shape as ``mu_x``.
            mu_z: latent means.
            logvar_z: latent log-variances.

        Returns:
            A scalar tensor. Both terms are summed over every element of the
            batch, so the value scales with the batch size.
        """
        # Squared-error reconstruction term, summed over the whole batch.
        rec_loss = F.mse_loss(mu_x, x, reduction='sum')
        # KL(q(z|x) || N(0, I)), likewise summed over batch and latent dims.
        kld = -0.5 * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp())
        return rec_loss + kld

    def forward(self, x):
        """Encode, sample a latent code, decode; return ``(mu_z, logvar_z, mu_x)``."""
        mu_z, logvar_z = self.encode(x)
        z_sampled = self.sample(mu_z, logvar_z)
        mu_x = self.decode(z_sampled)
        return mu_z, logvar_z, mu_x
        

In [None]:
# Preprocessing applied to every image: resize to 64x64, convert to a single
# grayscale channel, turn into a tensor and normalise with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((64,64)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(
            0.5, 0.5
        ),
    ]
)

# Dataset root. Set the OASIS_PNG_DIR environment variable to point the
# notebook at a different copy of the data; the default is the original path.
DATA_DIR = os.environ.get(
    'OASIS_PNG_DIR',
    '/home/Student/s4606685/summer_research/oasis-3/png_data',
)

dataset = vision.datasets.ImageFolder(DATA_DIR, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# NOTE(review): test_loader iterates over the same images as train_loader, so
# anything measured with it is not a held-out estimate -- confirm whether a
# separate split is intended.
test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
def get_data_sample():
    """Return the inputs of one randomly drawn sample from ``dataset``.

    A fresh shuffled loader with batch size 1 is created on each call and its
    first batch is fetched; only the first field of that batch (the inputs)
    is returned, so the result keeps a leading batch dimension of 1.
    """
    single_item_loader = DataLoader(dataset, batch_size=1, shuffle=True)
    first_batch = next(iter(single_item_loader))
    return first_batch[0]

In [None]:
# Build the model with the latent size from the config cell (z_dim) so that
# changing the hyperparameter there actually takes effect, then move it to
# the selected device. nn.Module.to() returns the module itself.
model = VAE(z_dims=z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
def train(epoch, log_interval=1200):
    """Run one training epoch over ``train_loader`` and report the loss.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``.

    Args:
        epoch: epoch number, used only in the progress messages.
        log_interval: print a progress line every this many batches.

    Returns:
        The accumulated loss divided by the number of samples in the dataset.
    """
    model.train()
    train_loss = 0.0
    num_samples = len(train_loader.dataset)
    num_batches = len(train_loader)

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        mu_z, logvar_z, mu_x = model(data)
        loss = model.loss_function(mu_x, data, mu_z, logvar_z)
        loss.backward()
        optimizer.step()

        # The loss is a scalar; read it once and reuse it for the running
        # total and the progress line.
        batch_loss = loss.item()
        train_loss += batch_loss
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), num_samples,
                100. * batch_idx / num_batches,
                batch_loss / len(data)))

    average_loss = train_loss / num_samples
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, average_loss))
    return average_loss