In [None]:
import os
import cv2
import torch
import numpy as np
import pandas as pd
from PIL import Image
import torch.nn as nn
from zipfile import ZipFile
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
import random
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
import time
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline
import torchvision
import zipfile
from IPython import display
from torch.autograd import Function

In [None]:
# Recursively collect every path under ../input whose file name contains '.png'.
files = []
for root, _subdirs, names in os.walk('../input'):
    files.extend(os.path.join(root, name) for name in names if '.png' in name)

In [None]:
class myDataset(Dataset):
    """Dataset of PNG images discovered recursively under a root directory.

    Every item is the tensor returned by ``TF.to_tensor`` for the RGB version
    of the image, after an optional random rotation and after padding the
    left and top edges so both sides are multiples of 32.

    Args:
        path: Root directory walked recursively for files ending in ``.png``.
        augment: When True (the default, matching the previous behaviour) each
            sample is rotated by a random integer angle in [-60, 60] degrees
            with probability 0.5. Pass False for deterministic samples.
    """

    def __init__(self, path='../input', augment=True):
        super(myDataset, self).__init__()
        self.augment = augment
        found = []
        for r, d, f in os.walk(path):
            for file in f:
                # Test the extension, not a substring: "'.png' in file" would
                # also pick up names such as "notes.png.txt".
                if file.lower().endswith('.png'):
                    found.append(os.path.join(r, file))
        # os.walk order depends on the filesystem; sorting keeps index-based
        # train/validation splits reproducible across machines and runs.
        self.files = sorted(found)

    def __getitem__(self, idx):
        # Close the file handle promptly and force 3 channels: PNGs may be
        # grayscale, palette or RGBA, which would not match a 3-channel model.
        with Image.open(self.files[idx]) as img:
            img = img.convert('RGB')
        return self.transform(img)

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _pad_amount(size, multiple=32):
        """Return how many pixels must be added to ``size`` to reach a multiple of ``multiple``.

        Returns 0 when ``size`` is already aligned (the previous
        ``32 - size % 32`` formula added a full extra 32 pixels in that case).
        """
        return (-size) % multiple

    def transform(self, img):
        """Optionally rotate ``img``, pad it to multiples of 32 and convert it to a tensor."""
        if self.augment and random.random() > 0.5:
            angle = random.randint(-60, 60)
            img = TF.rotate(img, angle)
        width, height = img.size
        dw = self._pad_amount(width)
        dh = self._pad_amount(height)
        # Padding tuple is (left, top, right, bottom): only left/top are padded.
        img = TF.pad(img, (dw, dh, 0, 0))
        return TF.to_tensor(img)

In [None]:
# Build the dataset over the default '../input' directory.
ds = myDataset()

In [None]:
# Sanity check: load and transform the sample at index 3 and display the
# resulting tensor (requires the dataset to contain at least four images).
ds[3]

In [None]:
class SignFunction(Function):
    """Binarise a tensor to -1/+1 with a pass-through gradient.

    forward(input, is_training=True):
        * training: draws uniform noise with the shape of ``input`` and
          sets each element to +1 where ``(1 - input) / 2 <= noise`` and to -1
          where ``(1 - input) / 2 > noise``.
        * otherwise: returns ``input.sign()``.
    backward: returns the incoming gradient unchanged for ``input`` and
        ``None`` for ``is_training``.
    """

    def __init__(self):
        super(SignFunction, self).__init__()

    @staticmethod
    def forward(ctx, input, is_training=True):
        # Deterministic path first.
        if not is_training:
            return input.sign()

        noise = input.new(input.size()).uniform_()
        threshold = (1 - input) / 2
        binary = input.clone()
        binary[threshold <= noise] = 1
        binary[threshold > noise] = -1
        return binary

    @staticmethod
    def backward(ctx, grad_output):
        # Straight pass-through for the tensor input; the flag gets no gradient.
        return grad_output, None
        
class Sign(nn.Module):
    """Module wrapper that applies ``SignFunction`` using this module's train/eval flag."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # self.training selects the stochastic (True) or deterministic (False) path.
        stochastic = self.training
        return SignFunction.apply(x, stochastic)
class Binarizer(nn.Module):
    """1x1 convolution (no bias) followed by tanh and the ``Sign`` module.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.sign = Sign()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        # tanh squashes the projection into [-1, 1] before it is binarised.
        squashed = F.tanh(self.conv1(x))
        return self.sign(squashed)

In [None]:
class autoencoder(nn.Module):
    """Convolutional autoencoder with a binarised bottleneck.

    * ``enc``: Conv(3->32, k=8, s=4, p=2) -> ReLU -> BN, Conv(32->64, k=2, s=2) -> ReLU -> BN
    * ``binarizer``: ``Binarizer(64, 128)``
    * ``dec``: ConvT(128->32, k=8, s=4, p=2) -> BN -> ReLU, ConvT(32->3, k=2, s=2) -> BN -> ReLU
    """

    def __init__(self):
        super().__init__()
        # Registration order (enc, dec, binarizer) is kept as in the original.
        encoder_layers = [
            nn.Conv2d(3, 32, kernel_size=8, stride=4, padding=2),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        ]
        self.enc = nn.Sequential(*encoder_layers)
        decoder_layers = [
            nn.ConvTranspose2d(128, 32, kernel_size=8, stride=4, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=2, stride=2),
            nn.BatchNorm2d(3),
            nn.ReLU(),
        ]
        self.dec = nn.Sequential(*decoder_layers)
        self.binarizer = Binarizer(64, 128)

    def forward(self, x):
        """Encode ``x``, binarise the code and decode it again."""
        code = self.binarizer(self.enc(x))
        return self.dec(code)

In [None]:
# Split / loader configuration.
batch_size = 1  # presumably 1 because padded images can differ in size - TODO confirm
validation_split = 0.1
shuffle_dataset = True
random_seed = 42

ds = myDataset()

# Build the index list; the first `split` (optionally shuffled) indices form
# the validation subset, the remainder the training subset.
dataset_size = len(ds)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
val_indices = indices[:split]
train_indices = indices[split:]

# One sampler + loader per subset, both drawing from the same dataset object.
train_sampler = SubsetRandomSampler(train_indices)
validation_sampler = SubsetRandomSampler(val_indices)
train_loader = DataLoader(ds, batch_size=batch_size, sampler=train_sampler)
validation_loader = DataLoader(ds, batch_size=batch_size, sampler=validation_sampler)

In [None]:
# Model, loss, optimiser and step-wise learning-rate schedule.
model = autoencoder().float()
criterion = nn.SmoothL1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
exp_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[50, 80, 100],
    gamma=0.1,
)
# Move the model to the GPU when one is available.
model = model.cuda() if torch.cuda.is_available() else model

In [None]:
def train(epoch=1):
    """Train the global ``model`` from ``epoch`` up to and including ``stop_epoch``.

    Uses these notebook-level names: ``model``, ``criterion``, ``optimizer``,
    ``exp_lr_scheduler``, ``train_loader``, ``train_indices``, ``val_indices``,
    ``batch_size``, ``stop_epoch``, the history lists (``train_losses``,
    ``val_losses``, ``train_accuracy``, ``val_accuracy``, ``epoch_data``) and
    the functions ``validate`` and ``visualize``.

    After every epoch the per-sample mean train/validation loss is appended to
    the history lists as plain Python numbers and ``visualize()`` is called.

    Args:
        epoch: 1-based epoch number to start from.
    """
    # Feed batches to whatever device the model lives on.
    device = next(model.parameters()).device
    while epoch <= stop_epoch:
        total_loss = 0.0
        total_accuracy = 0
        model.train()
        # Read the LR actually in use from the optimizer; scheduler.get_lr()
        # is not meant to be called outside scheduler.step().
        current_lr = optimizer.param_groups[0]['lr']
        print('Epoch: {}\tLR: {:.5f}'.format(epoch, current_lr))
        for batch_idx, data in enumerate(train_loader):
            # Autoencoder: the input is its own reconstruction target, so one
            # device copy is enough.
            data = data.to(device)
            target = data
            # forward
            output = model(data)
            # backward + optimize
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # statistics (accuracy is a placeholder, always 0)
            accuracy = 0
            total_accuracy += accuracy
            # .item() detaches the value: summing the loss tensors themselves
            # keeps autograd history alive and stores tensors in the history
            # lists that are later plotted.
            total_loss += loss.item()
        # Step the scheduler once per epoch, after the optimizer updates;
        # stepping it first skips the first value of the schedule.
        exp_lr_scheduler.step()
        mean_train_loss = (total_loss * batch_size) / len(train_indices)
        print('Train Loss: \t' + str(mean_train_loss))
        vloss, vaccuracy = validate()
        train_losses.append(mean_train_loss)
        val_losses.append((vloss * batch_size) / len(val_indices))
        train_accuracy.append((total_accuracy * batch_size) / len(train_indices))
        val_accuracy.append((vaccuracy * batch_size) / len(val_indices))
        epoch_data.append(epoch)
        visualize()
        epoch = 1 + epoch

In [None]:
def validate():
    """Evaluate the global ``model`` on ``validation_loader``.

    Runs in evaluation mode and without gradient tracking, so validation data
    does not update BatchNorm running statistics and the model's layers use
    their deterministic eval behaviour. The model's previous train/eval mode
    is restored before returning.

    Returns:
        (total_loss, total_acc): the sum of the per-batch ``criterion`` values
        as a Python float, and a placeholder accuracy that is always 0.
    """
    total_loss = 0
    total_acc = 0
    # Feed batches to whatever device the model lives on.
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch_idx, data in enumerate(validation_loader):
                # Autoencoder: the input doubles as the reconstruction target.
                data = data.to(device)
                target = data
                output = model(data)
                loss = criterion(output, target).item()
                total_loss += loss
                accuracy = 0
                total_acc += accuracy
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)
    return total_loss, total_acc
def visualize():
    """Redraw the loss/accuracy history curves, replacing the previous cell output.

    Reads the notebook-level lists ``epoch_data``, ``train_losses``,
    ``val_losses``, ``train_accuracy`` and ``val_accuracy``; each legend entry
    shows the most recent value of its series.
    """
    curves = (
        (train_losses, "Train Loss {:.5f}"),
        (val_losses, "Validation Loss {:.5f}"),
        (train_accuracy, "Train Accuracy {:.5f}"),
        (val_accuracy, "Validation Accuracy {:.5f}"),
    )
    plt.figure(figsize=(15, 7))
    for series, template in curves:
        plt.plot(epoch_data, series, label=template.format(series[-1]))
    # Clear the previous output right before the new figure is shown.
    display.clear_output(wait=False)
    plt.legend()
    plt.show()

In [None]:
# Fresh per-epoch history buffers that train() fills and visualize() plots.
train_losses, val_losses = [], []
train_accuracy, val_accuracy = [], []
epoch_data = []
stop_epoch = 52

# Time the full training run and print the elapsed wall-clock seconds.
start = time.time()
train()
end = time.time()
print(end - start)