In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.functional as F
from typing import List

In [20]:
# NOTE: input image tensors are expected to look like `[B, 3, 64, 64]`.

# CONSTANTS
IN_CHANNELS = 3
LATENT_DIM = 128 # temp
HIDDEN_DIM = [32, 64, 128, 256, 512]
ENCODER_HIDDEN_DIM = HIDDEN_DIM
# The decoder walks the same channel widths in the opposite order; copy first
# so reversing does not also reverse HIDDEN_DIM / ENCODER_HIDDEN_DIM.
DECODER_HIDDEN_DIM = HIDDEN_DIM.copy()
DECODER_HIDDEN_DIM.reverse()

# Prefer Apple's MPS backend (the original hard-coded choice), but fall back to
# CUDA and then CPU so that `.to(DEVICE)` does not fail on other machines.
# `torch.device('mps')` itself never fails; the error only shows up on first use.
if getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

HIDDEN_DIM, ENCODER_HIDDEN_DIM, DECODER_HIDDEN_DIM

([32, 64, 128, 256, 512], [32, 64, 128, 256, 512], [512, 256, 128, 64, 32])

#### Encoder

In [3]:
# Channel widths seen by the encoder: the image channels followed by every
# hidden width, so adjacent pairs give (in_channels, out_channels) per stage.
encoder_channels = [IN_CHANNELS, *ENCODER_HIDDEN_DIM]

# One stage per hidden width: strided 3x3 conv -> batch norm -> LeakyReLU.
# With stride 2 and padding 1, each stage halves an even spatial size.
encoder_modules = [
    nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(c_out),
        nn.LeakyReLU(),
    )
    for c_in, c_out in zip(encoder_channels[:-1], encoder_channels[1:])
]

encoder = nn.Sequential(*encoder_modules)

#### Mean and Variance

In [4]:
# Size of the flattened encoder output: the last hidden width times 4
# (the recorded encoder output for 64x64 inputs is [B, 512, 2, 2], i.e. 2*2 = 4).
flattened_features = HIDDEN_DIM[-1] * 4

# Two parallel linear heads over the same flattened features: one produces the
# mean and the other the log-variance of the latent Gaussian.
fc_mu = nn.Linear(in_features=flattened_features, out_features=LATENT_DIM)
fc_var = nn.Linear(in_features=flattened_features, out_features=LATENT_DIM)

#### Decoder

In [51]:
# Linear layer that expands a latent vector to HIDDEN_DIM[-1] * 4 features,
# the same flattened size the encoder heads consume.
decoder_input = nn.Linear(in_features=LATENT_DIM, out_features=HIDDEN_DIM[-1] * 4)

# One upsampling stage per adjacent pair of decoder widths:
# transposed 3x3 conv -> batch norm -> LeakyReLU.
# stride=2, padding=1, output_padding=1 doubles the spatial size at each stage.
decoder_modules = [
    nn.Sequential(
        nn.ConvTranspose2d(c_in, c_out,
                           kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(c_out),
        nn.LeakyReLU(),
    )
    for c_in, c_out in zip(DECODER_HIDDEN_DIM[:-1], DECODER_HIDDEN_DIM[1:])
]

decoder = nn.Sequential(*decoder_modules)

#### Head

In [52]:
# Output head: one more x2 upsampling stage at the narrowest decoder width,
# then a 3x3 conv that maps back to image channels, squashed to [-1, 1] by Tanh.
# The output channel count follows IN_CHANNELS (instead of a literal 3) so the
# reconstruction keeps matching the input if that constant is changed.
final_layer = nn.Sequential(
    nn.ConvTranspose2d(in_channels=DECODER_HIDDEN_DIM[-1],
                       out_channels=DECODER_HIDDEN_DIM[-1],
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       output_padding=1),
    nn.BatchNorm2d(DECODER_HIDDEN_DIM[-1]),
    nn.LeakyReLU(),
    nn.Conv2d(in_channels=DECODER_HIDDEN_DIM[-1],
              out_channels=IN_CHANNELS,
              kernel_size=3,
              padding=1),
    nn.Tanh()
)

In [53]:
def encode(input: torch.Tensor) -> List[torch.Tensor]:
    """Map a batch of images to the parameters of the latent Gaussian.

    The batch is passed through the module-level `encoder`, flattened from
    dimension 1 onwards, and fed to the two linear heads `fc_mu` and `fc_var`.

    Args:
        input: image batch; assumed to be `[B, 3, 64, 64]` per the notebook's
            note, since the heads are sized for a 2x2 final feature map.

    Returns:
        `[mu, log_var]`, each with `LATENT_DIM` features per sample.
    """
    features = torch.flatten(encoder(input), start_dim=1)
    return [fc_mu(features), fc_var(features)]

### ―――――――――――――――――――――――――――――― Test  ―――――――――――――――――――――――――――――――――――

In [54]:
# Smoke test: push a dummy all-ones batch through the encoder and show the
# feature-map shape next to its flattened shape (what fc_mu / fc_var receive).
# `fc_mu.weight.shape` and `encode(input_tensor)` can be inspected here as well.
input_tensor = torch.ones(32, 3, 64, 64)
encoded_tensor = encoder(input_tensor)
print(encoded_tensor.shape, torch.flatten(encoded_tensor, start_dim=1).shape)


torch.Size([32, 512, 2, 2]) torch.Size([32, 2048])


### ―――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――

In [55]:
def decode(z: torch.Tensor) -> torch.Tensor:
    """Map latent vectors back to image space.

    Args:
        z: latent batch with `LATENT_DIM` features per sample (the input size
            of `decoder_input`).

    Returns:
        The output of `final_layer`; for the configuration in this notebook
        the recorded shape is `[B, 3, 64, 64]`.
    """
    result = decoder_input(z)
    # decoder_input emits HIDDEN_DIM[-1] * 4 features per sample; fold them
    # into HIDDEN_DIM[-1] channels on a 2x2 grid. The channel count is read
    # from the config rather than hard-coded as 512 so the two stay in sync.
    result = result.view(-1, HIDDEN_DIM[-1], 2, 2)
    result = decoder(result)
    return final_layer(result)

In [56]:
def reparameterize(mu: torch.Tensor, log_var: torch.Tensor):
    """Draw z = mu + std * eps with eps ~ N(0, 1) (reparameterization trick).

    This is a Monte Carlo estimate with a single sample (M = 1).

    Args:
        mu: mean of the latent Gaussian.
        log_var: log-variance of the latent Gaussian, same shape as `mu`.

    Returns:
        A sample with the same shape as the inputs.
    """
    # log_var = log(sigma^2)  =>  sigma = exp(0.5 * log_var)
    sigma = (0.5 * log_var).exp()

    # Noise with the same shape as log_var, sampled once from N(0, 1).
    # If we sampled M times instead, the average of the M draws would be
    #   ((eps_1 * std + mu) + ... + (eps_M * std + mu)) / M
    #   = (eps_1 + ... + eps_M) / M * std + mu
    noise = torch.randn_like(sigma)

    return mu + sigma * noise


In [57]:
def forward(input: torch.Tensor) -> List[torch.Tensor]:
    """Run one full VAE pass: encode, sample a latent, decode.

    Args:
        input: image batch handed to `encode`.

    Returns:
        `[reconstruction, input, mu, log_var]`, the positional layout that
        `loss_function` unpacks.
    """
    mu, log_var = encode(input)
    reconstruction = decode(reparameterize(mu, log_var))
    return [reconstruction, input, mu, log_var]

## Question was: encoder의 파라미터 phi랑 decoder 파라미터 pi가 있는데, 한번에 backpropagate해도 되는건가.. 싶었는데 될것같긴한데, 일단 cs330-12듣고 다시 이어서 하기.

### ―――――――――――――――――――――――――――――― Test  ―――――――――――――――――――――――――――――――――――

In [62]:
## Testing Forward pass
# Walks a dummy batch through every stage by hand and prints the shape after
# each one, mirroring what encode -> reparameterize -> decode do internally.
input_tensor = torch.ones([32, 3, 64, 64]) # [B, C, H, W]

# 1. Encode
print('\n1. Encode =====')
encoded_tensor = encoder(input_tensor)
print('encoded_tensor: ', encoded_tensor.shape)

# 2. Flatten
print('\n2. Flatten =====')  # label fixed: this step was previously printed as "1."
flatten_tensor = encoded_tensor.flatten(start_dim=1)
print('flatten_tensor: ', flatten_tensor.shape)

# 3. Mean and log(var)
print('\n3. Mean and log(Var) =====')
mu = fc_mu(flatten_tensor)
log_var = fc_var(flatten_tensor)
print('mu.weights: ', fc_mu.weight.shape)
print('var.weights: ', fc_var.weight.shape)
print('mu: ', mu.shape)
print('var: ', log_var.shape)

# 4. re-parameterize
print('\n4. ReParameterize =====')
latent_var = reparameterize(mu, log_var)
print('latent_var(z): ', latent_var.shape)

# 5. decoder input
print('\n5. Decoder Input (FC) =====')
latent_input = decoder_input(latent_var)
print('latent_input: ', latent_input.shape)
print('decoder_input_linear: ', decoder_input.weight.shape)

# 6. Reshape Decoder Input
print('\n6. Reshape Decoder Input =====')
# Channel count comes from the config (HIDDEN_DIM[-1]) instead of a literal 512.
latent_input = latent_input.view(-1, HIDDEN_DIM[-1], 2, 2)
print('latent_input (reshape): ', latent_input.shape)

# 7. Decode
print('\n7. Decode =====')
decoded_tensor = decoder(latent_input)
print('decoded_tensor: ', decoded_tensor.shape)

# 8. Final Layer
print('\n8. Final Layer =====')
out = final_layer(decoded_tensor)
print('out: ', out.shape)

# The reconstruction must have exactly the shape of the input batch.
assert out.shape == input_tensor.shape, f'expected {input_tensor.shape}, got {out.shape}'






1. Encode =====
encoded_tensor:  torch.Size([32, 512, 2, 2])

1. Flatten =====
flatten_tensor:  torch.Size([32, 2048])

3. Mean and log(Var) =====
mu.weights:  torch.Size([128, 2048])
var.weights:  torch.Size([128, 2048])
mu:  torch.Size([32, 128])
var:  torch.Size([32, 128])

4. ReParameterize =====
latent_var(z):  torch.Size([32, 128])

5. Decoder Input (FC) =====
latent_input:  torch.Size([32, 2048])
decoder_input_linear:  torch.Size([2048, 128])

6. Reshape Decoder Input =====
latent_input (reshape):  torch.Size([32, 512, 2, 2])

7. Decode =====
decoded_tensor:  torch.Size([32, 32, 32, 32])

8. Final Layer =====
out:  torch.Size([32, 3, 64, 64])


### ―――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――

In [63]:
def loss_function(*args, **kwargs) -> dict:
    """Compute the VAE training loss: reconstruction MSE plus weighted KL term.

    Args:
        *args: positional values in the order returned by `forward`:
            `recons, input, mu, log_var`. Extra positionals are ignored.
        **kwargs: must contain `M_N`, the weight applied to the KL term
            (accounts for minibatch samples from the dataset).

    Returns:
        A dict with
        - `'loss'`: `recons_loss + M_N * kld_loss` (keeps the autograd graph),
        - `'Reconstruction_Loss'`: detached MSE between reconstruction and input,
        - `'KLD'`: detached *negated* KL term (sign kept as in the original
          code so existing logging stays comparable).

    Raises:
        IndexError: if fewer than four positional arguments are given.
        KeyError: if `M_N` is not supplied.
    """
    recons, target, mu, log_var = args[0], args[1], args[2], args[3]

    kld_weight = kwargs['M_N']  # Account for minibatch samples from the dataset.

    # Reconstruction term of the (negative) ELBO. `mse_loss` lives in
    # torch.nn.functional; it is reached through `nn` because this file's `F`
    # alias points at `torch.functional`, which has no `mse_loss`.
    recons_loss = nn.functional.mse_loss(recons, target)

    # Closed-form KL(q(z|x) || p(z)) for q = N(mu, var) and p = N(0, I):
    # summed over latent dimensions, then averaged over the batch.
    kld_loss = torch.mean(
        -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1),
        dim=0,
    )

    loss = recons_loss + kld_weight * kld_loss

    return {'loss': loss, 'Reconstruction_Loss': recons_loss.detach(), 'KLD': -kld_loss.detach()}
    
    

In [64]:
def sample(num_samples:int,
           **kwargs) -> torch.Tensor:
    """Generate images by decoding latents drawn from the standard normal prior.

    The KL term of the loss encourages the encoder's q(z|x) to stay close to
    N(0, I), so decoding z ~ N(0, I) is expected to produce plausible samples.

    Args:
        num_samples: number of latent vectors (and therefore images) to draw.
        **kwargs: accepted for call-site compatibility; not used.

    Returns:
        The decoded batch returned by `decode`.
    """
    z = torch.randn(num_samples, LATENT_DIM)

    # Put z on whatever device the decoder's weights actually live on. Nothing
    # visible in this notebook moves the modules to DEVICE, so moving z there
    # unconditionally would fail with a device mismatch inside `decode`.
    z = z.to(next(decoder_input.parameters()).device)
    return decode(z)

In [None]:
def generate(x: torch.Tensor) -> torch.Tensor:
    """Return only the reconstruction of `x` (the first element of `forward`'s output)."""
    reconstruction = forward(x)[0]
    return reconstruction

In [None]:
# Scratch check: append a backtick to every key of a small dict and print it.
a = dict(a=0, b=1, c=3)
backticked = {f'{key}`': val for key, val in a.items()}
print(backticked)