<a href="https://colab.research.google.com/github/CastleJin/2021_01_12_winter_internship_magnet/blob/main/magnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Setup & Prerequisites**

In [None]:
# Install the CUDA 10.1 builds of torch 1.7.1 / torchvision 0.8.2 / torchaudio 0.7.2 from the PyTorch wheel index.
!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
# Mount Google Drive at /content/drive; the dataset, checkpoints and loss logs used later all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Pin Pillow to 7.0.0.
# NOTE(review): the reason for this exact pin is not recorded here -- presumably compatibility with the torchvision install above; confirm.
!pip install Pillow==7.0.0

In [None]:
# Print the Python version of the Colab runtime.
!python -V

In [None]:
# tqdm provides the per-epoch progress bar used by the training loop.
!pip install tqdm

**MODEL**

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# modules
class res_blk(nn.Module):
  """Residual block: conv -> ReLU -> conv, added back onto the input.

  Args:
    layer_dims: number of input and output channels (equal, so the skip
      connection can be added).
    ks: convolution kernel size; for odd sizes the reflect padding of
      (ks - 1) // 2 keeps the spatial size unchanged at stride 1.
    s: convolution stride. Only s == 1 keeps the output the same shape as the
      input, which the residual addition requires.
  """
  def __init__(self, layer_dims, ks, s):
    super(res_blk, self).__init__()
    p = (ks - 1) // 2
    # Two independent convolutions. Previously conv1 was applied twice, which
    # silently tied the weights of both layers of the block.
    self.conv1 = nn.Conv2d(layer_dims, layer_dims, kernel_size=ks, stride=s, padding=p, padding_mode='reflect', bias=False)
    self.conv2 = nn.Conv2d(layer_dims, layer_dims, kernel_size=ks, stride=s, padding=p, padding_mode='reflect', bias=False)
    self.activation = nn.ReLU()

  def forward(self, input):
    out = self.activation(self.conv1(input))
    out = self.conv2(out)
    return input + out

def multi_res_blk(num_res_blk, layer_dims, ks, s):
  """Stack `num_res_blk` independent res_blk modules into an nn.Sequential (empty -> identity)."""
  return nn.Sequential(*[res_blk(layer_dims, ks, s) for _ in range(num_res_blk)])

class res_manipulator(nn.Module):
  """Amplify the difference between two shape encodings.

  Computes enc_b + h(amp_factor * relu(g(enc_b - enc_a))), where g is a 7x7
  conv and h is a 3x3 conv followed by one residual block. All convs are
  bias-free, so identical inputs (enc_a == enc_b) return enc_b unchanged.
  """
  def __init__(self, layer_dims=32):
    super(res_manipulator, self).__init__()
    self.conv1 = nn.Conv2d(layer_dims, layer_dims, kernel_size=7, stride=1, padding=3, padding_mode='reflect', bias=False)
    self.conv2 = nn.Conv2d(layer_dims, layer_dims, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=False)
    self.residual = multi_res_blk(1, layer_dims, 3, 1)
    self.activation = nn.ReLU()

  def forward(self, enc_a, enc_b, amp_factor):
    """amp_factor must broadcast against the feature maps (e.g. shape (B, 1, 1, 1))."""
    out = enc_b - enc_a
    out = self.activation(self.conv1(out))
    # Out-of-place multiply: ReLU keeps its output for the backward pass, so the
    # former in-place `out *= amp_factor` made backward() raise an
    # "modified by an inplace operation" RuntimeError.
    out = out * amp_factor
    out = self.conv2(out)
    out = self.residual(out)
    return enc_b + out

class res_encoder(nn.Module):
  """Shared trunk of the encoder.

  7x7 conv (3 -> layer_dims/2 channels) + ReLU, a stride-2 3x3 conv
  (-> layer_dims channels) + ReLU, then `num_res_blk` residual blocks.
  The stride-2 conv roughly halves the spatial resolution (ceil(H/2)).
  """
  def __init__(self, layer_dims=32, num_res_blk=3):
    super(res_encoder, self).__init__()
    half_dims = int(layer_dims / 2)
    self.conv1 = nn.Conv2d(3, half_dims, kernel_size=7, stride=1, padding=3,
                           padding_mode='reflect', bias=False)
    self.conv2 = nn.Conv2d(half_dims, layer_dims, kernel_size=3, stride=2, padding=1,
                           padding_mode='reflect', bias=False)
    self.residual = multi_res_blk(num_res_blk, layer_dims, 3, 1)
    self.activation = nn.ReLU()

  def forward(self, x):
    hidden = self.activation(self.conv1(x))
    downsampled = self.activation(self.conv2(hidden))
    return self.residual(downsampled)

class res_decoder(nn.Module):
  """Decode a feature map back to a 3-channel image.

  `num_res_blk` residual blocks, 2x nearest-neighbour upsampling, a 3x3 conv
  (layer_dims -> layer_dims/2) + ReLU, then a 7x7 conv to 3 channels with no
  output activation.
  """
  def __init__(self, layer_dims=64, num_res_blk=9):
    super(res_decoder, self).__init__()
    half_dims = int(layer_dims / 2)
    self.residual = multi_res_blk(num_res_blk, layer_dims, 3, 1)
    self.up_sample = nn.Upsample(scale_factor=2, mode='nearest')
    self.conv1 = nn.Conv2d(layer_dims, half_dims, kernel_size=3, stride=1, padding=1,
                           padding_mode='reflect', bias=False)
    self.conv2 = nn.Conv2d(half_dims, 3, kernel_size=7, stride=1, padding=3,
                           padding_mode='reflect', bias=False)
    self.activation = nn.ReLU()

  def forward(self, x):
    features = self.up_sample(self.residual(x))
    features = self.activation(self.conv1(features))
    return self.conv2(features)

# magnet
class encoder(nn.Module):
  """Split an image into a texture encoding and a shape encoding.

  A shared res_encoder trunk feeds two branches. The texture branch uses a
  stride-2 conv and the shape branch a stride-1 conv; each is followed by ReLU
  and its own stack of residual blocks. Returns (texture_enc, shape_enc).
  """
  def __init__(self):
    super(encoder, self).__init__()
    # hyper-parameters
    self.num_enc_resblk = 3
    self.res_enc_dim = 32
    self.num_texture_resblk = 2
    self.num_shape_resblk = 2

    dims = self.res_enc_dim
    self.res_encoder = res_encoder(dims, self.num_enc_resblk)
    # Stride 2 downsamples the texture branch (texture_downsample = True); stride 1 would keep full resolution.
    self.conv_tex = nn.Conv2d(dims, dims, kernel_size=3, stride=2, padding=1,
                              padding_mode='reflect', bias=False)
    self.conv_sha = nn.Conv2d(dims, dims, kernel_size=3, stride=1, padding=1,
                              padding_mode='reflect', bias=False)
    self.texture_resblk = multi_res_blk(self.num_texture_resblk, dims, 3, 1)
    self.shape_resblk = multi_res_blk(self.num_shape_resblk, dims, 3, 1)
    self.activation = nn.ReLU()

  def forward(self, img):
    shared = self.res_encoder(img)
    texture_enc = self.texture_resblk(self.activation(self.conv_tex(shared)))
    shape_enc = self.shape_resblk(self.activation(self.conv_sha(shared)))
    return texture_enc, shape_enc

class decoder(nn.Module):
  """Rebuild an image from a texture encoding and a shape encoding.

  The texture encoding is upsampled 2x and passed through a 3x3 conv + ReLU.
  It is then concatenated with the shape encoding along the channel axis
  (32 + 32 = 64 channels) and decoded by res_decoder.
  """
  def __init__(self):
    super(decoder, self).__init__()
    # hyper-parameters
    self.num_dec_resblk = 9
    self.texture_dims = 32
    self.shape_dims = 32
    self.decoder_dims = self.texture_dims + self.shape_dims

    # Upsampling is only needed because the texture encoding was downsampled in the encoder.
    self.up_sample = nn.Upsample(scale_factor=2, mode='nearest')
    self.conv_tex_aft_upsample = nn.Conv2d(self.texture_dims, self.texture_dims, kernel_size=3, stride=1,
                                           padding=1, padding_mode='reflect', bias=False)
    self.res_decoder = res_decoder(self.decoder_dims, self.num_dec_resblk)
    self.activation = nn.ReLU()

  def forward(self, texture_enc, shape_enc):
    # Bring the (downsampled) texture encoding back to the shape encoding's resolution.
    texture = self.activation(self.conv_tex_aft_upsample(self.up_sample(texture_enc)))
    merged = torch.cat([texture, shape_enc], dim=1)
    return self.res_decoder(merged)

class magnet(nn.Module):
  """Full motion-magnification network: encoder, shape manipulator and decoder.

  forward() encodes all four frames with the shared encoder. It magnifies the
  shape difference between frame A and frame B by `amp_factor`, then decodes
  frame B's texture together with the magnified shape.

  Returns:
    (reconstruction, texture_a, texture_c, texture_b, texture_amp, shape_b, shape_c)
  """
  def __init__(self):
    super(magnet, self).__init__()
    self.encoder = encoder()
    self.decoder = decoder()
    self.res_manipulator = res_manipulator()

  def forward(self, amplified, image_a, image_b, image_c, amp_factor):
    encodings = [self.encoder(frame) for frame in (amplified, image_a, image_b, image_c)]
    (texture_amp, _), (texture_a, shape_a), (texture_b, shape_b), (texture_c, shape_c) = encodings
    magnified_shape = self.res_manipulator(shape_a, shape_b, amp_factor)
    reconstruction = self.decoder(texture_b, magnified_shape)
    return reconstruction, texture_a, texture_c, texture_b, texture_amp, shape_b, shape_c


**Dataset & DataLoader Setup**

In [None]:
# References
# 1. https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
# 2. https://wikidocs.net/57165
# 3. https://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler

import os
from PIL import Image
from __future__ import print_function
import numpy as np
import json
from matplotlib.pyplot import imshow
import random

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data.sampler import Sampler

class three(Dataset):
  """Motion-magnification training set read from five parallel sub-directories.

  `data_path` must contain five non-hidden sub-directories whose sorted order
  is: amplified target frames, frame A, frame B, frame C, and per-sample JSON
  metadata (each file carrying an 'amplification_factor' key). Sample `idx`
  is built from the idx-th file (in sorted order) of each sub-directory.

  Each sample is a dict with keys 'amplified', 'frameA', 'frameB', 'frameC'
  (float32 arrays of the decoded images, rescaled from [0, 255] to [-1, 1])
  and 'mag_factor' (float32 array holding amplification_factor - 1). The
  sample is optionally passed through `transform`.
  """

  def __init__(self, data_path, transform = None):
    self.data_path = data_path  # root of dataset
    # os.listdir order is arbitrary, so sort it. Hidden entries (e.g. .ipynb_checkpoints)
    # and plain files are skipped because they would shift the positional indices below.
    self.sub_dir = sorted(
        name for name in os.listdir(self.data_path)
        if not name.startswith('.') and os.path.isdir(os.path.join(self.data_path, name)))

    # data path
    self.amplified_path = os.path.join(self.data_path, self.sub_dir[0])
    self.frameA_path = os.path.join(self.data_path, self.sub_dir[1])
    self.frameB_path = os.path.join(self.data_path, self.sub_dir[2])
    self.frameC_path = os.path.join(self.data_path, self.sub_dir[3])
    self.meta_path = os.path.join(self.data_path, self.sub_dir[4])
    self.transform = transform

    # List every directory once here. Previously all five were re-listed and
    # re-sorted on every __getitem__ call.
    self.amplified_list = self._sorted_files(self.amplified_path)
    self.frameA_list = self._sorted_files(self.frameA_path)
    self.frameB_list = self._sorted_files(self.frameB_path)
    self.frameC_list = self._sorted_files(self.frameC_path)
    self.meta_list = self._sorted_files(self.meta_path)

  @staticmethod
  def _sorted_files(path):
    """Return the sorted names of the non-hidden entries in directory `path`."""
    return sorted(name for name in os.listdir(path) if not name.startswith('.'))

  @staticmethod
  def _load_image(path):
    """Decode the image at `path` to float32 and rescale [0, 255] -> [-1, 1].

    The context manager closes the file handle, which the previous bare
    Image.open calls leaked.
    """
    with Image.open(path) as img:
      return np.array(img, dtype = 'float32') / 127.5 - 1.0

  def _read_json(self, path):
    """Return the 'amplification_factor' entry of the JSON file at `path`."""
    with open(path,'r') as f:
      json_data = json.load(f)
    return json_data['amplification_factor']

  def __len__(self):
    return len(self.amplified_list)

  def __getitem__(self,idx):
    # read image & json (images are converted and normalized to [-1, 1] by _load_image)
    amplified = self._load_image(os.path.join(self.amplified_path, self.amplified_list[idx]))
    frameA = self._load_image(os.path.join(self.frameA_path, self.frameA_list[idx]))
    frameB = self._load_image(os.path.join(self.frameB_path, self.frameB_list[idx]))
    frameC = self._load_image(os.path.join(self.frameC_path, self.frameC_list[idx]))
    mag_factor = self._read_json(os.path.join(self.meta_path, self.meta_list[idx]))

    # stored as (amplification_factor - 1)
    mag_factor -= 1.0
    mag_factor = np.array(mag_factor, dtype = 'float32')

    sample = {'amplified': amplified, 'frameA': frameA, 'frameB': frameB, 'frameC': frameC, 'mag_factor': mag_factor}

    if self.transform is not None:
      sample = self.transform(sample)

    return sample

class ToTensor(object):
  """Convert a numpy sample dict into torch tensors.

  The four image entries go from numpy's H x W x C layout to torch's
  C x H x W. 'mag_factor' gets three trailing singleton dimensions. For a 0-d
  factor (as `three` produces) this gives shape (1, 1, 1), so after
  DataLoader batching it is (B, 1, 1, 1) and broadcasts over feature maps.
  Keys other than these five are dropped.
  """
  def __call__(self, sample):
    image_keys = ('amplified', 'frameA', 'frameB', 'frameC')
    images = [sample[key] for key in image_keys]
    factor = sample['mag_factor']

    converted = {}
    for key, image in zip(image_keys, images):
      converted[key] = torch.from_numpy(image.transpose((2, 0, 1)))
    converted['mag_factor'] = torch.from_numpy(factor)[..., None, None, None]
    return converted

class shot_noise(object):
  """Add approximate Poisson (shot) noise to frames A, B and C of a tensor sample.

  Shot noise is approximated by a zero-mean Gaussian whose standard deviation
  grows with sqrt(intensity). This is deliberate: for Poisson noise the
  variance equals the mean count, so the *standard deviation* (what scales
  the unit-Gaussian draw) is proportional to the square root of intensity,
  not to its square.

  For each frame, the noise is scaled by an independent uniform draw from
  [0, n). 'amplified' and 'mag_factor' pass through untouched. Images are
  expected in [-1, 1], as produced by the dataset normalization.
  """
  def __init__(self, n):
    self.n = n  # upper bound of the uniformly drawn noise scale

  def _get_shot_noise(self, image):
    gaussian = torch.zeros_like(image).normal_(mean=0.0, std=1.0)
    # sqrt(image + 1) / sqrt(127.5) == sqrt((image + 1) * 127.5) / 127.5:
    # (image + 1) * 127.5 is the pixel value back in [0, 255], and the final
    # division by 127.5 converts the standard deviation back to the [-1, 1] range.
    strength = torch.sqrt(torch.as_tensor(image + 1.0)) / torch.sqrt(torch.as_tensor(127.5))
    return gaussian * strength

  def _preproc_shot_noise(self, image, n):
    scale = np.random.uniform(0, n)
    return image + scale * self._get_shot_noise(image)

  def __call__(self, sample):
    amplified, frame_a, frame_b, frame_c, mag_factor = (
        sample[key] for key in ('amplified', 'frameA', 'frameB', 'frameC', 'mag_factor'))
    frame_a, frame_b, frame_c = (
        self._preproc_shot_noise(frame, self.n) for frame in (frame_a, frame_b, frame_c))
    return {
        'amplified': amplified,
        'frameA': frame_a,
        'frameB': frame_b,
        'frameC': frame_c,
        'mag_factor': mag_factor,
    }

class num_sampler(Sampler):
  """Deterministic train/validation split by index.

  Every `num`-th index (i % num == num - 1) belongs to the validation split;
  all other indices belong to the training split.

  Args:
    data: the dataset; only len(data) is used.
    is_val: yield the validation indices if true, else the training indices.
    shuffle: reshuffle the yielded indices (with `random`) on every iteration.
    num: split period; with num=10, len(data) // 10 indices are validation.
  """
  def __init__(self, data, is_val=True, shuffle=False, num=10):
    self.num_samples = len(data)
    self.is_val = is_val
    self.shuffle = shuffle
    self.num = num

  def __iter__(self):
    last = self.num - 1
    if self.is_val:  # case of validation dataset
      k = [i for i in range(self.num_samples) if i % self.num == last]
    else:  # case of train dataset
      k = [i for i in range(self.num_samples) if i % self.num != last]

    if self.shuffle:
      random.shuffle(k)
    return iter(k)

  def __len__(self):
    # Must equal the number of indices __iter__ yields: DataLoader derives
    # len(loader) from it. Previously this returned the whole dataset length.
    # Exactly num_samples // num indices satisfy i % num == num - 1.
    n_val = self.num_samples // self.num
    return n_val if self.is_val else self.num_samples - n_val

In [None]:
# noise
poisson_noise_n = 0.3  # upper bound of the per-frame random shot-noise scale (see shot_noise)

# train
train_batch_size = 1
val_batch_size = 1

# load dataset
# Both datasets read the same directory: num_sampler splits it by index
# (every 10th sample is validation), and only the training split gets shot noise.
train_dataset = three(data_path = '/content/drive/MyDrive/train', transform = transforms.Compose([ToTensor(), shot_noise(poisson_noise_n)]))
val_dataset = three(data_path = '/content/drive/MyDrive/train', transform = transforms.Compose([ToTensor()]))
train_sampler = num_sampler(train_dataset, is_val=False, shuffle=True)
val_sampler = num_sampler(val_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch_size, sampler=val_sampler)
train_batch = len(train_loader)

# Batches per epoch for each split. The training loop divides the summed per-batch
# losses by these, so derive them from the real split sizes. The old hard-coded
# 450 / 50 was correct only for a 500-sample dataset at batch size 1.
# Indices are counted by iterating the samplers, so the result does not depend on
# num_sampler.__len__. The ceiling division counts the final partial batch, which
# DataLoader keeps by default (drop_last=False). max(1, ...) avoids dividing by
# zero when a split is empty.
n_train = sum(1 for _ in train_sampler)
n_val = sum(1 for _ in val_sampler)
one_epoch_size = max(1, (n_train + train_batch_size - 1) // train_batch_size)
val_size = max(1, (n_val + val_batch_size - 1) // val_batch_size)

In [None]:
"""
# test one image
i = 0
sample = dataset[i]
print(i, sample['amplified'].shape, sample['mag_factor'], sample['mag_factor'].shape)
a = sample['frameA']
print(a)
a= (a+1)*127.5
a = a /255
b= transforms.ToPILImage()(a)
imshow(b)

# check the batch dataset size
for i_batch, sample_batched in enumerate(data_loader):
    print(i_batch, sample_batched['amplified'].size(), sample_batched['mag_factor'].size())
"""

**Hyperparameters & Device Setting**

In [None]:
# for exponential decay
# NOTE(review): decay_steps and lr_decay are not used by any visible cell (no LR scheduler is created),
# and lr_decay = 1.0 would make an exponential decay a no-op anyway.
decay_steps = 3000
lr_decay = 1.0

# for Adam
# beta1 (first-moment decay rate) passed to torch.optim.Adam in the optimizer cell.
# `betal` is a typo of "beta1" but is referenced under this name, so it is kept.
betal = 0.9


**Loss & Optimizer**

In [None]:
import torch
from tqdm import tqdm

# set variable
PATH = '/content/drive/MyDrive/model'  # directory holding the per-epoch checkpoints
load_num = 100  # epoch of the checkpoint to resume from when is_load is True
load_name = '/epoch_{}'.format(load_num)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
is_load = False
num_epoch = 400  # total epochs; reduced by the checkpoint's epoch when resuming
tex_loss_w = 1.0
sha_loss_w = 1.0

# Build model, loss and optimizer unconditionally. The resume path previously used
# the PyTorch tutorial placeholders (TheModelClass, TheOptimizerClass, args, kwargs)
# and never defined `criterion`, so it could not run.
model = magnet().to(device)
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001, betas = (betal,0.999), weight_decay = 0, amsgrad=False)
train_losses = []
val_losses = []

# resume from a checkpoint written by the training loop
if is_load:
  # map_location lets a checkpoint saved on GPU load on whichever device is available now.
  checkpoint = torch.load(PATH + load_name, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  num_epoch -= checkpoint['epoch']
  train_losses = checkpoint['train_loss']
  val_losses = checkpoint['val_loss']


**Training & Evaluation**

In [None]:
# Offset for resumed runs. train_losses holds one entry per completed epoch (it is
# empty on a fresh run), so epoch numbers in the logs, checkpoint names and the
# saved 'epoch' continue from the checkpoint instead of restarting at 1 and
# overwriting earlier epoch_N files.
start_epoch = len(train_losses)
# Append to the existing loss logs when resuming instead of truncating them.
log_mode = 'a' if start_epoch else 'w'

with open('/content/drive/MyDrive/Colab Notebooks/train_loss.txt', log_mode) as f1, \
     open('/content/drive/MyDrive/Colab Notebooks/val_loss.txt', log_mode) as f2:
  for epoch in tqdm(range(num_epoch)):
    epoch_num = start_epoch + epoch + 1  # 1-based epoch number across resumes
    running_loss = 0.0
    model.train()

    # train
    for sample in train_loader:
      amplified, frameA, frameB, frameC, amp_factor = sample['amplified'].to(device), sample['frameA'].to(device), sample['frameB'].to(device), sample['frameC'].to(device), sample['mag_factor'].to(device)
      optimizer.zero_grad()
      # magnet returns (reconstruction, texture_a, texture_c, texture_b, texture_amp, shape_b, shape_c)
      Y, Va, Vb, _, _, Mb, Mb_ = model(amplified, frameA, frameB, frameC, amp_factor)
      # L1 reconstruction of the amplified frame + texture(A) vs texture(C) + shape(B) vs shape(C)
      loss = criterion(Y, amplified) + tex_loss_w * criterion(Va, Vb) + sha_loss_w * criterion(Mb, Mb_)
      loss.backward()
      optimizer.step()
      running_loss += loss.item()

    # evaluation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
      for sample in val_loader:
        val_amp, val_A, val_B, val_C, val_factor = sample['amplified'].to(device), sample['frameA'].to(device), sample['frameB'].to(device), sample['frameC'].to(device), sample['mag_factor'].to(device)
        Y, Va, Vb, _, _, Mb, Mb_ = model(val_amp, val_A, val_B, val_C, val_factor)
        loss = criterion(Y, val_amp) + tex_loss_w * criterion(Va, Vb) + sha_loss_w * criterion(Mb, Mb_)
        val_loss += loss.item()

    # result: average per-batch loss of this epoch
    train_avg = running_loss / one_epoch_size
    val_avg = val_loss / val_size
    print('[epoch: %d] train_loss: %.3f, val_loss: %.3f'% (epoch_num, train_avg, val_avg))
    f1.write('%.3f\n' % train_avg)
    f2.write('%.3f\n' % val_avg)
    # Flush every epoch so the logs on Drive stay complete if the session is interrupted.
    f1.flush()
    f2.flush()
    train_losses.append(train_avg)
    val_losses.append(val_avg)

    # save
    torch.save({'epoch': epoch_num,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_losses,
                'val_loss': val_losses},
               PATH + '/epoch_{}'.format(epoch_num))

    # for re-train
    torch.cuda.empty_cache()

In [None]:
from matplotlib import pyplot as plt
import numpy as np

#Evaluation
# One loss value per completed epoch (including epochs restored from a checkpoint),
# so derive the x axis from the history length. The previous range(0, epoch) was one
# element shorter than train_losses, which made plt.plot raise a ValueError.
iters = range(1, len(train_losses) + 1)
print(len(train_losses))
plt.plot(iters, train_losses, 'k', label='Training loss')
plt.plot(iters, val_losses, 'b', label = 'Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:

"""
import torch
import torch.nn.functional as F
import random

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler

# Dataset 상속
class CustomDataset(Dataset): 
  def __init__(self):
    self.x_data = [[1], [2], [3], [4], [5], [6], [7], [8], [9],[10], [11], [12],[13], [14], [15]]

  # 총 데이터의 개수를 리턴
  def __len__(self): 
    return len(self.x_data)

  # 인덱스를 입력받아 그에 맵핑되는 입출력 데이터를 파이토치의 Tensor 형태로 리턴
  def __getitem__(self, idx): 
    x = torch.FloatTensor(self.x_data[idx])
    return x

class num_sampler(Sampler):
# Sampling a specific number of multiple-th indices from data.
  def __init__(self, data, is_val=True, shuffle=False, num=10):
    self.num_samples = len(data)
    self.is_val = is_val
    self.shuffle = shuffle
    self.num = num

  def __iter__(self):
    k = []
    for i in range(self.num_samples):
      if self.is_val: # case of validation dataset
        if i%self.num == self.num-1:
          k.append(i)
      else: # case of train dataset
        if i%self.num != self.num-1:
          k.append(i)

    if self.shuffle:
      random.shuffle(k)
    return iter(k)

  def __len__(self):
    return self.num_samples


dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=2, sampler=DummySampler(dataset))

# check the batch dataset size
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched)
print('\n')

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched)
    """