In [31]:
import torch
import torch.nn as nn
import numpy as np
import math
from torch.autograd import Variable
import torchvision
import torchvision.models as models
from torchvision import transforms
import torch.nn.functional as F
from torchvision import datasets
import argparse
import os
import matplotlib.pyplot as plt
import time
import codecs
import csv

## Slecting random seed of Torch and Numpy respectively

In [32]:
manualSeed = 144

np.random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed(manualSeed)
torch.cuda.manual_seed_all(manualSeed)

## Check if we have a CUDA-capable device; if so, use it

In [33]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Will train on {}'.format(device))

Will train on cpu


In [34]:
def show_img(x):
    if x.size(0) == 1:   
        new_img = torch.cat([x, x, x], dim=0)
        img = new_img.permute(1, 2, 0).cpu().detach().numpy()
        plt.imshow(img, cmap ='binary')
    if x.size(0) == 3:
        img = img.permute(1, 2, 0).cpu().detach().numpy()
        plt.imshow(img)

## Encoder

In [35]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        
        self.block0 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.GELU(),
            # nn.MaxPool2d(2, 2),
            )
        
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.GELU(),
            # nn.MaxPool2d(2, 2),
            )
        
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.GELU(),
            # nn.MaxPool2d(2, 2),
            )
        
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.GELU(),
            # nn.MaxPool2d(2, 2),
            )

    def forward(self, x):
        out = self.block0(x)
        out = self.block1(out)
        out = self.block2(out)
        representation = self.block3(out)
        
        return representation

In [36]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        
        self.block0 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.GELU(),
            # nn.Upsample(scale_factor=2, mode='nearest'),
            )
        
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.GELU(),
            # nn.Upsample(scale_factor=2, mode='nearest'),
            )
        
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(16),
            nn.GELU(),
            # nn.Upsample(scale_factor=2, mode='nearest'),
            )
        
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1),
            nn.GELU(),
            # nn.Upsample(scale_factor=2, mode='nearest'),
            )

    def forward(self, repres):
        out = self.block0(repres)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        
        return out

## Loss function

In [37]:
class L2Dis(nn.Module):
    def __init__(self):
        super(L2Dis,self).__init__()

    def forward(self, label, target):
        N = target.size(0)
        size = target.size(2) * target.size(3)
        loss = torch.sum((label-target)**2)/N/size
        return loss

In [38]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        m.bias.data.fill_(0)

In [39]:
class NetWork(nn.Module):
    def __init__(self):
        super(NetWork,self).__init__()
        
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.loss = L2Dis()
        
    def forward(self, x):
        repres = self.encoder(x)
        img = self.decoder(repres)
        loss = self.loss(img, x.detach())
        return loss, img

## Save checkpoint.

In [40]:
import csv
import os

def save(net, epoch):
    print('Saving..')
    state = {
        'net': net,
        'epoch': epoch,
        'rng_state': torch.get_rng_state()
    }
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/ckpt.t7' + '_' + str(epoch), _use_new_zipfile_serialization=False)

## Load model

In [41]:
if os.path.isdir('checkpoint'):
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7_181')
    
    net = checkpoint['net']
    net = net.to(device)
    
    start_epoch = checkpoint['epoch'] + 1
    rng_state = checkpoint['rng_state']
    torch.set_rng_state(rng_state)
    print('==> Finished')
    
else:
    print('==> Building model..')
    start_epoch = 0
    net = NetWork()
    if torch.cuda.device_count()>1:
        print("Let's use", torch.cuda.device_count(), 'GPUs!')
        net = nn.DataParallel(net)
    net = net.to(device)
    net.apply(weights_init)
    print('==> Finished')

==> Building model..
==> Finished


In [42]:
# @torchsnooper.snoop()
def train(epoch, train_data):
  net.train()
  L2loss = 0

  Start = time.time()

  optimizer.zero_grad()

  loss, img = net(train_data)

  loss.mean().backward()
  optimizer.step()
  
  L2loss += loss.mean().item()
  C_loss = L2loss / 500


  End = time.time()
  T = End - Start
  H = T // 3600
  M = (T % 3600) // 60
  S = (T % 3600) % 60
  print('using time is: ', H, 'H', M, 'M', S, 'S' )
  
  print("C_loss: ", C_loss)
  imgSet = [inputs[0], img[0]]
  fig = plt.figure(figsize=(128,32))
  for idx in np.arange(2):
    ax = fig.add_subplot(1, 2, idx+1)
    show_img(imgSet[idx])
  plt.show()

In [43]:
def adjust_learning_Crate(optimizer, epoch, num, Clr, warmup=0):
    lr = Clr*0.5*(1+np.cos(np.pi*(epoch - warmup)/(num - warmup)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [44]:
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

In [None]:
if __name__ == "__main__":
    n_epochs = 100
    Clr = 0.1

    epochStart = time.time()
    for epoch in range(start_epoch, n_epochs + 1):
      print('Epoch: %d' % epoch)

      adjust_learning_Crate(optimizer, epoch, n_epochs, Clr)
      clr = optimizer.param_groups[0]['lr']
      print('classifier training lr:', clr)
  
      inputs = torch.randn(500, 1, 600, 600)
      index = torch.randperm(inputs.size(0))
      batch_size = 1
      num = 500 // batch_size

      for indic in range(num):
        train_data = inputs[indic*batch_size:(1+indic)*batch_size]
        train(epoch, train_data)
      
      # if epoch%10==9:
      #   save(net, epoch) 

Epoch: 0
classifier training lr: 0.1




using time is:  0.0 H 0.0 M 7.525115013122559 S
C_loss:  0.0031973159313201903
