<a href="https://colab.research.google.com/github/aviax1/AE1/blob/master/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install wandb

**dependencies**

In [None]:
# Uses snippets from https://github.com/L1aoXingyu/pytorch-beginner/
import torch
import wandb
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
from tensorflow.keras.datasets import mnist
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
from multiprocessing import Process

**initialization: data and hyper-parameters**

In [None]:
# MNIST train/test splits; the labels are used later to select one digit class.
(xtrain, ytrain), (xtest, ytest) = mnist.load_data()

# --- optimisation settings ---
num_epochs = 130
batch_size = 8
learning_rate = 1e-4

# --- network sizes ---
image_size = 28 * 28   # one flattened 28x28 image (784 values)
hidden_size = 72       # width of every hidden layer
lv_size = 64           # latent variable (bottleneck) size

cret = nn.MSELoss()    # reconstruction criterion

**build model**

In [None]:
class autoencoder(nn.Module):
    """Fully connected autoencoder for flattened images.

    Encoder: in -> hidden -> hidden -> latent, with ReLU between layers.
    Decoder: latent -> hidden -> hidden -> hidden -> hidden -> in, with ReLU
    between layers and a final Tanh, so reconstructions lie in (-1, 1).

    The layer order inside both ``nn.Sequential`` containers is the same as
    the original definition, so ``state_dict`` keys (``encoder.0.weight``,
    ``decoder.8.bias``, ...) still match previously saved checkpoints.

    Args:
        in_features: size of one flattened input; defaults to the
            module-level ``image_size``.
        hidden_features: width of every hidden layer; defaults to the
            module-level ``hidden_size``.
        latent_features: size of the latent code; defaults to the
            module-level ``lv_size``.
    """

    def __init__(self, in_features=None, hidden_features=None, latent_features=None):
        super().__init__()
        # Defaults are resolved at construction time (as before), so editing
        # the module-level sizes still affects newly built models.
        n_in = image_size if in_features is None else in_features
        n_hid = hidden_size if hidden_features is None else hidden_features
        n_lat = lv_size if latent_features is None else latent_features

        self.encoder = nn.Sequential(
            nn.Linear(n_in, n_hid), nn.ReLU(True),
            nn.Linear(n_hid, n_hid), nn.ReLU(True),
            nn.Linear(n_hid, n_lat),
        )
        self.decoder = nn.Sequential(
            nn.Linear(n_lat, n_hid), nn.ReLU(True),
            nn.Linear(n_hid, n_hid), nn.ReLU(True),
            nn.Linear(n_hid, n_hid), nn.ReLU(True),
            nn.Linear(n_hid, n_hid), nn.ReLU(True),
            nn.Linear(n_hid, n_in), nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` (shape ``(batch, in_features)``) and decode it back."""
        return self.decoder(self.encoder(x))

**model setup**

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# nn.Module.to() moves the parameters in place and returns the same module.
# (A wandb.watch(model) call was left disabled here.)
model = autoencoder().to(device)

# Adam over all autoencoder parameters, with a small weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)

class DigitDataSet(Dataset):
  """Map-style dataset over an indexable stack of single-channel images.

  Item ``i`` is ``dataset[i, :, :]`` passed through ``ToTensor`` followed by
  ``Normalize`` with mean 0.5 and std 0.5.
  """

  def __init__(self, dataset):
    self.dataset = dataset
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5], [0.5])
    self.transform = transforms.Compose([to_tensor, normalize])

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    # Samplers may hand over a tensor index; numpy indexing wants plain Python.
    key = idx.tolist() if torch.is_tensor(idx) else idx
    image = self.dataset[key, :, :]
    return self.transform(image)

def images_row(dis_images, title, add_to_index=0, images_in_row=5, n_rows=30):
  """Draw up to ``images_in_row`` images into consecutive subplot slots.

  Slots are numbered row-major over an ``n_rows`` x ``images_in_row`` grid of
  the current matplotlib figure; image ``i`` goes to slot
  ``add_to_index + i + 1`` (matplotlib subplot indices are 1-based).

  Args:
    dis_images: batch of images, either flattened ``(N, 784)``, ``(N, 28, 28)``,
      or with a channel axis ``(N, C, 28, 28)`` (only channel 0 is shown).
    title: title drawn above every image.
    add_to_index: zero-based slot offset of the first image.
    images_in_row: how many images to draw; also the number of grid columns.
    n_rows: number of grid rows (was hard-coded to 30; the default keeps that).
  """
  if len(np.shape(dis_images)) == 4:
    # (N, C, H, W): keep the first channel only.
    dis_images = dis_images[:images_in_row, 0, :, :]
  else:
    # (N, 784) or (N, 28, 28): each item is reshaped to 28x28 below.
    dis_images = dis_images[:images_in_row]
  for i, image in enumerate(dis_images):
    plt.subplot(n_rows, images_in_row, add_to_index + i + 1)
    plt.imshow(image.reshape(28, 28))
    plt.title(title)
    plt.gray()

def visual_epoch(epoch_num, model, dataloader):
  """Plot one batch's inputs and their reconstructions as two subplot rows.

  Call number ``epoch_num`` uses subplot slots ``10*epoch_num`` to
  ``10*epoch_num + 9`` (zero-based offsets passed to ``images_row``). The
  first five slots hold the inputs ("org ") and the next five hold the
  reconstructions ("rec "). Previously the input row of call k landed on the
  reconstruction row of call k-1 and overwrote it.

  Args:
    epoch_num: zero-based index of this visualization; selects the grid rows.
    model: the autoencoder; the batch is moved to the device of its parameters.
    dataloader: yields image batches; assumed ``(N, 1, 28, 28)`` as produced
      by ``DigitDataSet`` (TODO confirm for other loaders).
  """
  per_row = 5  # images per subplot row (images_row's default width)
  # Only the first batch is needed; drawing every batch into the same slots
  # just repeated forward passes and piled images onto the same axes.
  batch = next(iter(dataloader), None)
  if batch is None:
    return  # empty loader: nothing to draw
  device = next(model.parameters()).device
  flat = batch.view(batch.size(0), -1).to(device)
  with torch.no_grad():  # inference only; no autograd graph needed
    reconstructed = model(flat)
  first_slot = 2 * per_row * epoch_num
  images_row(batch, "org ", first_slot, per_row)
  # .cpu() before .numpy(): numpy conversion fails for CUDA tensors.
  images_row(reconstructed.cpu().numpy(), "rec ", first_slot + per_row, per_row)

**train model by digit**

In [None]:
def train_by_digit(by_digit, model, optimizer=None):
  """Train ``model`` to reconstruct the training images of one digit.

  Uses ``xtrain[ytrain == by_digit]`` wrapped in ``DigitDataSet``, with the
  module-level ``batch_size``, ``num_epochs`` and ``cret`` (criterion). A
  wandb run is opened at the start and finished at the end, even on error.
  Each epoch does three things:
  - logs the first image of one shuffled batch to wandb;
  - runs one optimisation pass over all batches;
  - prints and logs the mean batch loss of that epoch.
  The final ``state_dict`` is saved to ``./ae_<by_digit>.pth``.

  Args:
    by_digit: label whose images are used for training.
    model: autoencoder to train; batches are moved to its parameters' device.
    optimizer: optimizer over ``model``'s parameters. If omitted, a new Adam
      optimizer (``learning_rate``, weight decay 1e-5) is built for ``model``.
      Previously the module-level optimizer was always used, so any model
      other than the global one was silently never updated.
  """
  if optimizer is None:
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
  # Inputs must live on the same device as the weights (the model may be on CUDA).
  device = next(model.parameters()).device

  wandb.init()
  try:
    print("*****\nstart traning Model for digit " + str(by_digit) + "\n")
    dataloader = DataLoader(DigitDataSet(xtrain[ytrain == by_digit]),
                            batch_size=batch_size, shuffle=True, num_workers=4)
    for epoch in range(num_epochs):
      # Log one sample image; take a single batch instead of walking the
      # whole loader just to use its first element.
      sample_batch = next(iter(dataloader))
      im = sample_batch[0, 0, :, :].reshape(28, 28)
      wandb.log({"img": [wandb.Image(im, caption="original")]})

      total_loss, n_batches = 0.0, 0
      for data in dataloader:
        imgs = data.view(data.size(0), -1).to(device)
        output_imgs = model(imgs)
        loss = cret(output_imgs, imgs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1

      # Mean over the epoch: a single mini-batch of `batch_size` images is
      # too noisy to represent the epoch.
      epoch_loss = total_loss / max(n_batches, 1)
      print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, epoch_loss))
      wandb.log({"loss": epoch_loss})

    torch.save(model.state_dict(), './ae_' + str(by_digit) + '.pth')
    print("\nfinish traning Model Number " + str(by_digit) + "\n")
    print("*****\n")
  finally:
    # Close this digit's run so the next wandb.init() starts a separate run.
    wandb.finish()

# Train and save a separate autoencoder for each digit 2..9.
# A fresh model and optimizer per digit stop ae_<digit>.pth from being a
# fine-tuned continuation of the previous digit's weights. `model` and
# `optimizer` are rebound at module level because train_by_digit may read
# the global optimizer.
for by_digit in range(2, 10):
  model = autoencoder().to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
  train_by_digit(by_digit, model)
  

*****
start traning Model for digit 5

epoch [1/130], loss:0.0784
epoch [2/130], loss:0.0490
epoch [3/130], loss:0.0469
epoch [4/130], loss:0.0703
epoch [5/130], loss:0.0561
epoch [6/130], loss:0.0658
epoch [7/130], loss:0.0570
epoch [8/130], loss:0.0351
epoch [9/130], loss:0.0633
epoch [10/130], loss:0.0463
epoch [11/130], loss:0.0600
epoch [12/130], loss:0.0476
epoch [13/130], loss:0.0532
epoch [14/130], loss:0.0381
epoch [15/130], loss:0.0672
epoch [16/130], loss:0.0449
epoch [17/130], loss:0.0565
epoch [18/130], loss:0.0533
epoch [19/130], loss:0.0389
epoch [20/130], loss:0.0576
epoch [21/130], loss:0.0520
epoch [22/130], loss:0.0544
epoch [23/130], loss:0.0548
epoch [24/130], loss:0.0655
epoch [25/130], loss:0.0510
epoch [26/130], loss:0.0457
