In [91]:
import os 
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

In [92]:
base_path = './Data preprocessed/CSM1'

In [93]:
folder_list = os.listdir('./Data preprocessed/CSM1')
practice_folder_list=[]

for fd in folder_list:
    if '9' in fd:
        practice_folder_list.append(fd)
practice_folder_list =[os.path.join(base_path,pfl) for pfl in practice_folder_list]

true_img_list = [os.path.join(practice_folder_list[0],tip) for tip in os.listdir(practice_folder_list[0])]
false_img_list = [os.path.join(practice_folder_list[1],tip) for tip in os.listdir(practice_folder_list[1])]

In [94]:
zzz =cv2.imread(true_img_list[0],cv2.IMREAD_GRAYSCALE)

In [95]:
train_imgs = np.zeros((len(true_img_list),1,300,300))

for idx,img in enumerate(true_img_list):
    tmp_grayscale_img = cv2.imread(img,cv2.IMREAD_GRAYSCALE)
    tmp_grayscale_img = cv2.resize(tmp_grayscale_img,(300,300))
    train_imgs[idx,0,:,:] = tmp_grayscale_img
    

In [97]:
kernel_size = 4
stride = 1
padding = 0
init_kernel = 16

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        # encoder
        self.enc1 = nn.Conv2d(
            in_channels=1, out_channels=init_kernel, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc2 = nn.Conv2d(
            in_channels=init_kernel, out_channels=init_kernel*2, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc3 = nn.Conv2d(
            in_channels=init_kernel*2, out_channels=init_kernel*4, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc4 = nn.Conv2d(
            in_channels=init_kernel*4, out_channels=init_kernel*8, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc5 = nn.Conv2d(
            in_channels=init_kernel*8, out_channels=init_kernel, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )

        # decoder 
        self.dec1 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=init_kernel*8, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec2 = nn.ConvTranspose2d(
            in_channels=init_kernel*8, out_channels=init_kernel*4, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec3 = nn.ConvTranspose2d(
            in_channels=init_kernel*4, out_channels=init_kernel*2, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec4 = nn.ConvTranspose2d(
            in_channels=init_kernel*2, out_channels=init_kernel, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec5 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=1, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )

    def reparameterize(self, mu, log_var):
        """
        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5*log_var) # standard deviation
        eps = torch.randn_like(std) # `randn_like` as we need the same size
        sample = mu + (eps * std) # sampling
        return sample
 
    def forward(self, x):
        # encoding
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))
        x = self.enc5(x)

        # get `mu` and `log_var`
        mu = x
        log_var = x

        # get the latent vector through reparameterization
        z = self.reparameterize(mu, log_var)
 
        # decoding
        x = F.relu(self.dec1(z))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        reconstruction = torch.sigmoid(self.dec5(x))
        return reconstruction, mu, log_var

In [98]:
vae = VAE()
optimizer = torch.optim.Adam(params=vae.parameters(), lr=0.0001, weight_decay=1e-5)
def vae_loss(recon_x,x,mu,log_var):
    recon_loss = F.binary_cross_entropy(recon_x.view(-1,1,300,300), x.view(-1, 1,300,300), reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_loss

In [102]:
def fit(model, dataloader,epochs):
    model.train()
    running_loss = 0.0
    for epoch in range(0,epochs): 
        print(f"Epoch {epoch+1}")
        for i,data in enumerate(dataloader): 
            optimizer.zero_grad()
            data =  torch.tensor(data, dtype = torch.float32)
            reconstruction, mu, logvar = model(data)
            loss = vae_loss(reconstruction,data, mu, logvar)
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        train_loss = running_loss/len(dataloader.dataset)
        print(f"Train Loss: {train_loss:.4f}")
    return train_loss

In [103]:
train_dataloader = DataLoader(train_imgs, batch_size=16,
                                         shuffle=True)

In [None]:
fit(vae,train_dataloader,10)

Epoch 1


  data =  torch.tensor(data, dtype = torch.float32)


In [None]:
res = vae.forward(torch.tentrain_imgs[0]torch.tensor(data, dtype = torch.float32).view((1,1,28,28)))