In [1]:
import os
import numpy as np 
import random
import warnings
import time
warnings.filterwarnings('ignore')

import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


!pip install pytorch-model-summary
from pytorch_model_summary import summary as summary

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
      torch.cuda.manual_seed(seed)
      torch.cuda.manual_seed_all(seed)
set_seed(0)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
path = '/content/drive/My Drive/hw10'
path = './'




# Preparation

## Data

In [2]:
class ImgDataset(data.Dataset):

    def __init__(self, data, mode):

        if mode == 'cnn':
            data = data.transpose([0,3,1,2])
        else:
            data = data.reshape(len(data), -1)

        self.data = data
        self.mode = mode

    def __getitem__(self, i):
        return self.data[i]

    def __len__(self):
        return len(self.data)
    
def get_dataloader(dataset, mode = 'train', batch_size = 128):
    shuffle  = True if mode == 'train' else False
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return loader

In [0]:
train = np.load(os.path.join(path, 'train.npy'), allow_pickle=True)
test = np.load(os.path.join(path, 'test.npy'), allow_pickle=True)

## Model

In [0]:
class fcn_autoencoder(nn.Module):
    def __init__(self):
        super(fcn_autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(32 * 32 * 3, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True), nn.Linear(64, 12), nn.ReLU(True), nn.Linear(12, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True), nn.Linear(128, 32 * 32 * 3
            ), nn.Tanh())

    def forward(self, x):
        code = self.encoder(x)
        x_rec = self.decoder(code)
        return code, x_rec

In [0]:
class conv_autoencoder(nn.Module):
    def __init__(self):
        super(conv_autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 12, 4, stride=2, padding=1),            # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.Conv2d(12, 24, 4, stride=2, padding=1),           # [batch, 24, 8, 8]
            nn.ReLU(),
			      nn.Conv2d(24, 48, 4, stride=2, padding=1),           # [batch, 48, 4, 4]
            nn.ReLU(),
    # 			nn.Conv2d(48, 96, 4, stride=2, padding=1),           # [batch, 96, 2, 2]
    #       nn.ReLU(),
        )
        self.decoder = nn.Sequential(
#             nn.ConvTranspose2d(96, 48, 4, stride=2, padding=1),  # [batch, 48, 4, 4]
#             nn.ReLU(),
			      nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1),  # [batch, 24, 8, 8]
            nn.ReLU(),
			      nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1),  # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1),   # [batch, 3, 32, 32]
            nn.Tanh(),
        )

    def forward(self, x):
        # print(x.dtype)
        code = self.encoder(x)
        x_rec = self.decoder(code)
        return code,x_rec

In [0]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(32*32*3, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 32*32*3)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        std = torch.exp(logvar * 0.5)
        e = torch.tensor(np.random.normal(size = std.size())).to(device, dtype=torch.float)
        return std * e + mu

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        # print(z.dtype)
        x_rec = self.decode(z)
        return  (mu, logvar), (x_rec, mu, logvar)

In [9]:
model = VAE().to(device)
print(summary(model, torch.zeros((1, 3*32*32)).to(device), show_hierarchical=True))

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
          Linear-1            [1, 400]       1,229,200       1,229,200
          Linear-2             [1, 20]           8,020           8,020
          Linear-3             [1, 20]           8,020           8,020
          Linear-4            [1, 400]           8,400           8,400
          Linear-5           [1, 3072]       1,231,872       1,231,872
Total params: 2,485,512
Trainable params: 2,485,512
Non-trainable params: 0
-----------------------------------------------------------------------



VAE(
  (fc1): Linear(in_features=3072, out_features=400, bias=True), 1,229,200 params
  (fc21): Linear(in_features=400, out_features=20, bias=True), 8,020 params
  (fc22): Linear(in_features=400, out_features=20, bias=True), 8,020 params
  (fc3): Linear(in_features=20, out_features=400, bias=True), 8,400 params
  (fc4): Linear(in_features=400, out_f

## Train

In [0]:
def loss_vae(output, x):
    """
    rec_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    rec_x, mu, logvar = output
    mse = nn.MSELoss(reduction = 'sum')(rec_x, x)  # mse loss
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = torch.sum(logvar.exp() - (1 + logvar) + (mu**2)) * (0.5)
    # KLD = torch.sum(-(logvar.exp()+(mu**2)) + (1 + logvar) ) * (-0.5)
    return mse + KLD


In [0]:
def run_epoch(model, dataloader, criterion, optimizer, mode, best_loss = np.inf):
    epoch_loss = 0
    for x_batch in dataloader:
        x_batch = x_batch.to(device, dtype = torch.float)
        # ===================forward=====================
        code, output = model(x_batch)
        # if model_type == 'vae':
        # loss = loss_vae(output[0], img, output[1], output[2], criterion)
        # else:
        loss = criterion(output, x_batch)
        epoch_loss += loss.item()
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # ===================save====================
        if loss.item() < best_loss:
            best_loss = loss.item()
            torch.save(model, os.path.join(path, 'best_model_{}.pt'.format(mode)))
    # ===================log========================
    return epoch_loss


In [0]:
def train_process(mode, learning_rate, batch_size, num_epochs):
    set_seed(0)
    train_set = ImgDataset(train, mode = mode)
    loader = get_dataloader(train_set, mode = mode, batch_size=batch_size)
    model_classes = {'fcn':fcn_autoencoder(), 'cnn':conv_autoencoder(), 'vae':VAE()}
    model = model_classes[mode].to(device)
    criterion = nn.MSELoss(reduction = 'sum') if mode != 'vae' else loss_vae
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    
    for epoch in range(1, num_epochs+1):
        st = time.time()
        model.train()
        epoch_loss = run_epoch(model, loader, criterion, optimizer, mode)
        ed = time.time()
        print('{:.2f}s,epoch [{:0>3d}/{}], loss: {:.8f}'
            .format(ed-st, epoch, num_epochs, epoch_loss/len(train_set)))
    return model

In [0]:
# cnn
num_epochs = 500
batch_size = 128
learning_rate = 1e-3
mode = 'cnn'
model = train_process(mode, learning_rate, batch_size, num_epochs)

4.89s,epoch [001/500], loss: 238.67313157
3.75s,epoch [002/500], loss: 94.70521890
3.49s,epoch [003/500], loss: 70.19295094
3.54s,epoch [004/500], loss: 57.87036293
3.46s,epoch [005/500], loss: 49.02686215
3.46s,epoch [006/500], loss: 44.22183844
3.40s,epoch [007/500], loss: 41.35901985
3.40s,epoch [008/500], loss: 39.40414973
3.45s,epoch [009/500], loss: 37.87608655
3.49s,epoch [010/500], loss: 37.06621351
3.42s,epoch [011/500], loss: 35.97624120
3.47s,epoch [012/500], loss: 35.14260513
3.50s,epoch [013/500], loss: 34.49964633
3.43s,epoch [014/500], loss: 33.97726877
3.42s,epoch [015/500], loss: 33.16445159
3.44s,epoch [016/500], loss: 32.69546522
3.38s,epoch [017/500], loss: 31.58471420
3.48s,epoch [018/500], loss: 30.75670291
3.54s,epoch [019/500], loss: 30.39430333
3.46s,epoch [020/500], loss: 29.45884884
3.48s,epoch [021/500], loss: 28.43016223
3.44s,epoch [022/500], loss: 27.63918558
3.48s,epoch [023/500], loss: 26.75585568
3.46s,epoch [024/500], loss: 25.87029265
3.44s,epoch [02