In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [11]:
# 튜닝 파라미터
kernel_size = 4
stride = 1
padding = 0
init_kernel = 16 

In [34]:
class ConvVAE(nn.Module):
    def __init__(self):
        super(ConvVAE, self).__init__()
 
        # encoder
        self.enc1 = nn.Conv2d(
            in_channels=1, out_channels=init_kernel, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc2 = nn.Conv2d(
            in_channels=init_kernel, out_channels=init_kernel*2, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc3 = nn.Conv2d(
            in_channels=init_kernel*2, out_channels=init_kernel*4, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc4 = nn.Conv2d(
            in_channels=init_kernel*4, out_channels=init_kernel*8, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc5 = nn.Conv2d(
            in_channels=init_kernel*8, out_channels=init_kernel, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        # decoder 
        self.dec1 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=init_kernel*8, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec2 = nn.ConvTranspose2d(
            in_channels=init_kernel*8, out_channels=init_kernel*4, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec3 = nn.ConvTranspose2d(
            in_channels=init_kernel*4, out_channels=init_kernel*2, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec4 = nn.ConvTranspose2d(
            in_channels=init_kernel*2, out_channels=init_kernel, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec5 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=1, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5*log_var) 
        eps = torch.randn_like(std) 
        sample = mu + (eps * std) 
        return sample
 
    def forward(self, x):
        # 인코더
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))
        x = self.enc5(x)
        
        mu = x
        log_var = x
        
        z = self.reparameterize(mu, log_var)
 
        # 디코더
        x = F.relu(self.dec1(z))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        reconstruction = torch.sigmoid(self.dec5(x))
        return reconstruction, mu, log_var

In [35]:
class vae_decoder(nn.Module):
    def __init__(self):
        super(ConvVAE, self).__init__()
        # 디코더
        self.dec1 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=init_kernel*8, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec2 = nn.ConvTranspose2d(
            in_channels=init_kernel*8, out_channels=init_kernel*4, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec3 = nn.ConvTranspose2d(
            in_channels=init_kernel*4, out_channels=init_kernel*2, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec4 = nn.ConvTranspose2d(
            in_channels=init_kernel*2, out_channels=init_kernel, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec5 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=1, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5*log_var) 
        eps = torch.randn_like(std) 
        sample = mu + (eps * std) 
        return sample
    
    def forward(self, mu, log_var):
        z = reparameterize(mu, log_var)
        x = F.relu(self.dec1(z))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        reconstruction = torch.sigmoid(self.dec5(x))
        return reconstruction

In [14]:
import scipy.io
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torchvision.utils import save_image

In [15]:
# 학습 parameter
batch_size = 64
lr = 0.001
epochs = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [18]:
# get the data into NumPy format
mat_data = scipy.io.loadmat('./input/frey_rawface.mat')
data = mat_data['ff'].T.reshape(-1, 1, 28, 20)
data = data.astype('float32') / 255.0
print(f"Number of images: {len(data)}")

Number of images: 1965


In [19]:
# divide the data into train and validation set
x_train = data[:-300]
x_val = data[-300:]
print(f"Training images: {len(x_train)}")
print(f"Validation images: {len(x_val)}")

Training images: 1665
Validation images: 300


In [20]:
class FreyDataset(Dataset):
    def __init__(self, X):
        self.X = X
        
    def __len__(self):
        return (len(self.X))
    
    def __getitem__(self, index):
        data = self.X[index]
        return torch.tensor(data, dtype=torch.float)

In [21]:
train_data = FreyDataset(x_train)
val_data = FreyDataset(x_val)

train_loader = DataLoader(train_data, batch_size=batch_size)
val_loader = DataLoader(val_data, batch_size=batch_size)

In [22]:
model = ConvVAE().to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)

criterion = nn.BCELoss(reduction='sum')

In [23]:
def final_loss(bce_loss, mu, logvar):
    BCE = bce_loss 
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [24]:
def fit(model, dataloader):
    model.train()
    running_loss = 0.0
    
    for i, data in tqdm(enumerate(dataloader), total=int(len(train_data)/dataloader.batch_size)):
        data= data
        data = data.to(device)
        data = data
        optimizer.zero_grad()
        
        reconstruction, mu, logvar = model(data)
        
        bce_loss = criterion(reconstruction, data)
        loss = final_loss(bce_loss, mu, logvar)
        
        loss.backward()
        running_loss += loss.item()
        
        optimizer.step()
        
    train_loss = running_loss/len(dataloader.dataset)
    return train_loss

In [25]:
def validate(model, dataloader):
    model.eval()
    running_loss = 0.0
    
    with torch.no_grad():
        
        for i, data in tqdm(enumerate(dataloader), total=int(len(val_data)/dataloader.batch_size)):
            data= data
            data = data.to(device)
            data = data
            
            reconstruction, mu, logvar = model(data)
            
            bce_loss = criterion(reconstruction, data)
            loss = final_loss(bce_loss, mu, logvar)
            
            running_loss += loss.item()
        
            # save the last batch input and output of every epoch
            if i == int(len(val_data)/dataloader.batch_size) - 1:
                num_rows = 8
                both = torch.cat((data[:8], 
                                  reconstruction[:8]))
                save_image(both.cpu(), f"./outputs/output{epoch}.png", nrow=num_rows)
    val_loss = running_loss/len(dataloader.dataset)
    return val_loss

In [26]:
train_loss = []
val_loss = []

for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    
    train_epoch_loss = fit(model, train_loader)
    val_epoch_loss = validate(model, val_loader)
    
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Epoch 1 of 20


27it [00:17,  1.51it/s]                                                                                                
5it [00:00,  5.17it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 384.7129
Val Loss: 367.8688
Epoch 2 of 20


27it [00:17,  1.54it/s]                                                                                                
5it [00:00,  5.46it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 365.6746
Val Loss: 359.6353
Epoch 3 of 20


27it [00:17,  1.54it/s]                                                                                                
5it [00:00,  5.16it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 361.5346
Val Loss: 357.5319
Epoch 4 of 20


27it [00:18,  1.45it/s]                                                                                                
5it [00:00,  5.31it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 359.7239
Val Loss: 355.1587
Epoch 5 of 20


27it [00:18,  1.48it/s]                                                                                                
5it [00:00,  5.60it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 357.8003
Val Loss: 354.9748
Epoch 6 of 20


27it [00:18,  1.45it/s]                                                                                                
5it [00:01,  4.01it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 356.6208
Val Loss: 354.5795
Epoch 7 of 20


27it [00:17,  1.52it/s]                                                                                                
5it [00:00,  5.42it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 356.0859
Val Loss: 353.8754
Epoch 8 of 20


27it [00:18,  1.48it/s]                                                                                                
5it [00:00,  5.32it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 355.6603
Val Loss: 352.7354
Epoch 9 of 20


27it [00:20,  1.32it/s]                                                                                                
5it [00:01,  4.60it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 355.2623
Val Loss: 352.3571
Epoch 10 of 20


27it [00:18,  1.46it/s]                                                                                                
5it [00:00,  5.27it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 355.0254
Val Loss: 352.2613
Epoch 11 of 20


27it [00:18,  1.47it/s]                                                                                                
5it [00:00,  5.27it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 354.8195
Val Loss: 352.1491
Epoch 12 of 20


27it [00:18,  1.44it/s]                                                                                                
5it [00:00,  5.26it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 354.6538
Val Loss: 351.8351
Epoch 13 of 20


27it [00:19,  1.41it/s]                                                                                                
5it [00:01,  4.71it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 354.4761
Val Loss: 351.7328
Epoch 14 of 20


27it [00:18,  1.47it/s]                                                                                                
5it [00:01,  4.80it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 354.4059
Val Loss: 351.6214
Epoch 15 of 20


27it [00:18,  1.43it/s]                                                                                                
5it [00:00,  5.59it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 354.4007
Val Loss: 351.7551
Epoch 16 of 20


27it [00:18,  1.46it/s]                                                                                                
5it [00:01,  4.14it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 354.4228
Val Loss: 351.9450
Epoch 17 of 20


27it [00:18,  1.44it/s]                                                                                                
5it [00:01,  5.00it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 354.4009
Val Loss: 352.1975
Epoch 18 of 20


27it [00:18,  1.45it/s]                                                                                                
5it [00:00,  5.43it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 354.3984
Val Loss: 352.0419
Epoch 19 of 20


27it [00:18,  1.46it/s]                                                                                                
5it [00:01,  3.85it/s]                                                                                                 
  0%|                                                                                           | 0/26 [00:00<?, ?it/s]

Train Loss: 354.2363
Val Loss: 351.6248
Epoch 20 of 20


27it [00:18,  1.47it/s]                                                                                                
5it [00:01,  4.68it/s]                                                                                                 

Train Loss: 353.8467
Val Loss: 351.5383



