<a href="https://colab.research.google.com/github/aviax1/AE1/blob/master/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
!pip install wandb

**dependencies**

In [30]:
# Adapted from a snippet in https://github.com/L1aoXingyu/pytorch-beginner/
import torch
import wandb
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
from tensorflow.keras.datasets import mnist
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np

**initial setup (data and hyperparameters)**

In [31]:

# MNIST arrays as returned by Keras; only the training split is used by the
# cells shown in this notebook.
(xtrain, ytrain), (xtest, ytest) = mnist.load_data()

# Hyperparameters and shared objects read by the model and training cells.
num_epochs = 200        # number of passes over each digit's training images
batch_size = 32         # images per optimisation step
image_size = 784        # length of one flattened input vector
hidden_size = 64        # width of the hidden linear layers
lv_size = 64            # width of the latent variable (encoder output)
learning_rate = 1e-5    # step size passed to the optimizer
cret = nn.MSELoss()     # criterion: mean-squared reconstruction error

**build model**

In [33]:
class autoencoder(nn.Module):
    """Fully connected autoencoder with a Tanh output layer.

    The encoder maps ``in_features -> in_features -> hidden -> hidden -> latent``
    and the decoder maps ``latent -> hidden -> hidden -> hidden -> in_features``,
    with ReLU between the linear layers.

    Args:
        in_features: length of a flattened input vector. Defaults to the
            notebook-level ``image_size``.
        hidden_features: width of the hidden layers. Defaults to the
            notebook-level ``hidden_size``.
        latent_features: width of the encoder output. Defaults to the
            notebook-level ``lv_size``.
    """

    def __init__(self, in_features=None, hidden_features=None, latent_features=None):
        super().__init__()
        # Only fall back to the notebook globals for sizes that were not given,
        # so the existing ``autoencoder()`` call keeps its behaviour while the
        # class no longer *requires* those globals to exist.
        if in_features is None:
            in_features = image_size
        if hidden_features is None:
            hidden_features = hidden_size
        if latent_features is None:
            latent_features = lv_size

        # Layer order is kept as-is so state_dict keys (encoder.0, encoder.2, ...)
        # stay compatible with previously saved checkpoints.
        self.encoder = nn.Sequential(
            nn.Linear(in_features, in_features),
            nn.ReLU(True),
            nn.Linear(in_features, hidden_features),
            nn.ReLU(True),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(True),
            nn.Linear(hidden_features, latent_features),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_features, hidden_features),
            nn.ReLU(True),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(True),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(True),
            nn.Linear(hidden_features, in_features),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` and return its reconstruction (values in [-1, 1] via Tanh)."""
        return self.decoder(self.encoder(x))

**model setting**

In [34]:
# Build the network, move it to the GPU when one is available, and create
# the Adam optimizer that train_by_digit() steps.
model = autoencoder()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
# wandb.watch() raises unless a run is active. Nothing above this cell calls
# wandb.init(), so start a run here when none exists (e.g. on a fresh kernel).
if wandb.run is None:
    wandb.init()
wandb.watch(model)
class DigitDataSet(Dataset):
    """Dataset wrapper that returns one transformed image per index.

    ``dataset`` is indexed as ``dataset[idx, :, :]``, so it must support
    three-axis indexing (presumably an (N, H, W) image array - confirm
    against the caller). Each item is passed through ``ToTensor`` followed by
    ``Normalize([0.5], [0.5])``.
    """

    def __init__(self, dataset):
        to_tensor = transforms.ToTensor()
        normalise = transforms.Normalize([0.5], [0.5])
        self.dataset = dataset
        self.transform = transforms.Compose([to_tensor, normalise])

    def __len__(self):
        """Number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the transformed image stored at ``idx``."""
        # Tensor indices are converted to plain Python values before indexing.
        index = idx.tolist() if torch.is_tensor(idx) else idx
        image = self.dataset[index, :, :]
        return self.transform(image)

def images_row(dis_images, title, add_to_index=0, images_in_row=5):
    """Draw up to ``images_in_row`` images into consecutive subplot cells.

    Args:
        dis_images: either a 2-D batch (one flattened image per row) or a
            batch indexed as ``[n, 0, :, :]``; each image is reshaped to 28x28.
        title: title placed above every drawn image.
        add_to_index: offset added to the subplot index, which selects where
            in the 30-row grid the images are placed.
        images_in_row: number of grid columns and maximum images drawn.
    """
    is_flat_batch = len(np.shape(dis_images)) == 2
    if is_flat_batch:
        shown = dis_images[:images_in_row, :]
    else:
        # Take the first entry along axis 1 of every image.
        shown = dis_images[:images_in_row, 0, :, :]

    for offset, picture in enumerate(shown):
        # Subplot indices are 1-based, hence the trailing + 1.
        plt.subplot(30, images_in_row, offset + add_to_index + 1)
        plt.imshow(picture.reshape(28, 28))
        plt.title(title)
        plt.gray()

def visual_epoch(epoch_num, model, dataloader):
    """Plot a row of original images and a row of their reconstructions.

    Only the first batch yielded by ``dataloader`` is drawn: every batch maps
    to the same subplot cells, so drawing further batches would only paint
    over the previous ones.

    Args:
        epoch_num: selects the grid position; originals start at subplot
            offset ``5*epoch_num`` and reconstructions at ``5*(epoch_num+1)``.
        model: module applied to the flattened image batch.
        dataloader: iterable yielding image batches (tensors).
    """
    # Run the forward pass on whatever device the model lives on; feeding CPU
    # tensors to a CUDA model raises a device-mismatch error.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    for input_imgs in dataloader:
        flat_imgs = input_imgs.view(input_imgs.size(0), -1).to(model_device)
        # No gradients are needed for visualisation.
        with torch.no_grad():
            output_imgs = model(flat_imgs)
        images_row(input_imgs, "org ", 5*epoch_num)
        # .cpu() is required before .numpy() when the model runs on the GPU.
        images_row(output_imgs.cpu().numpy(), "rec ", 5*(epoch_num+1))
        break

**train model by digit**

In [37]:
def train_by_digit(by_digit):
    """Train the notebook-level ``model`` on the images of one digit class.

    Starts a wandb run, trains for ``num_epochs`` epochs on the training
    images whose label equals ``by_digit``, logs the loss once per epoch and
    saves the weights to ``./ae_<by_digit>.pth``.

    NOTE(review): the loss printed and logged each epoch is the loss of that
    epoch's last batch only, not an epoch average - confirm that is intended.

    NOTE(review): ``model`` and ``optimizer`` are shared globals, so each call
    continues from the weights left by the previous digit rather than from a
    fresh network - confirm that this is intended.

    Args:
        by_digit: label value used to select the training images.

    Raises:
        ValueError: if no training image carries the label ``by_digit``.
    """
    digit_images = xtrain[ytrain == by_digit]
    if len(digit_images) == 0:
        # Without this check the failure surfaces later as an unrelated error
        # (no batches, hence no loss to report).
        raise ValueError("no training images found for digit " + str(by_digit))

    wandb.init()
    print("*****\nstart traning Model for digit " +str(by_digit) +"\n")
    dataloader = DataLoader(DigitDataSet(digit_images), batch_size=batch_size, shuffle=True, num_workers=4)

    # The model may have been moved to the GPU; batches must follow it or the
    # forward pass fails with a device-mismatch error.
    model_device = next(model.parameters()).device

    for epoch in range(num_epochs):
        for input_imgs in dataloader:
            # Flatten every image to a vector; the input doubles as the target.
            imgs = input_imgs.view(input_imgs.size(0), -1).to(model_device)
            output_imgs = model(imgs)
            loss = cret(output_imgs, imgs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Report the last batch's loss of the epoch as a plain Python float.
        last_loss = loss.item()
        print('epoch [{}/{}], loss:{:.4f}' .format(epoch + 1, num_epochs, last_loss))
        wandb.log({"loss": last_loss})

    torch.save(model.state_dict(), './ae_'+str(by_digit)+'.pth')
    print("\nfinish traning Model Number " +str(by_digit) +"\n")
    print("*****\n")

# Train on every digit class from 2 through 9.
# NOTE(review): digits 0 and 1 are skipped here - confirm this is intended.
for by_digit in range(2, 10):
    train_by_digit(by_digit)

*****
start traning Model for digit 2

epoch [1/200], loss:0.1139
epoch [2/200], loss:0.0947
epoch [3/200], loss:0.1063
epoch [4/200], loss:0.1032
epoch [5/200], loss:0.1023
epoch [6/200], loss:0.0942
epoch [7/200], loss:0.1022
epoch [8/200], loss:0.0956
epoch [9/200], loss:0.1059
epoch [10/200], loss:0.0883
epoch [11/200], loss:0.0968
epoch [12/200], loss:0.0910
epoch [13/200], loss:0.0918
epoch [14/200], loss:0.0876
epoch [15/200], loss:0.1011
epoch [16/200], loss:0.0798
epoch [17/200], loss:0.0926
epoch [18/200], loss:0.0922
epoch [19/200], loss:0.0943
epoch [20/200], loss:0.0820
epoch [21/200], loss:0.0919
epoch [22/200], loss:0.0982
epoch [23/200], loss:0.0978
epoch [24/200], loss:0.0891
epoch [25/200], loss:0.0906
epoch [26/200], loss:0.0957
epoch [27/200], loss:0.0807
epoch [28/200], loss:0.0861
epoch [29/200], loss:0.0897
epoch [30/200], loss:0.0943
epoch [31/200], loss:0.0906
epoch [32/200], loss:0.0916
epoch [33/200], loss:0.0983
epoch [34/200], loss:0.0770
epoch [35/200], lo

KeyboardInterrupt: ignored