In [1]:
# http://pytorch.org/
# Environment bootstrap cell: probes the interpreter's wheel tag and the CUDA
# runtime visible to the dynamic linker, then installs PyTorch with pip.
from os.path import exists
# NOTE(review): `wheel.pep425tags` is an internal module of the `wheel` package
# and is reportedly absent from newer releases -- confirm before relying on it.
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# Wheel tag "<impl><version>-<abi>"; only consumed by the commented-out
# pinned install command further down.
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
# Shell pipeline: list the linker cache, keep the cudart entries and rewrite a
# trailing ".<X>.<Y>" version suffix into the tag "cu<X><Y>".
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
# cuda_output[0] is only evaluated when the NVIDIA device node exists;
# otherwise the accelerator tag falls back to 'cpu'.
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

# Previous pinned install (uses `accelerator` and `platform`), kept for reference.
#!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision
# Installs whatever torch version pip currently resolves (no version pin).
!pip install torch
import torch



# Stochastic autoencoder 3
This notebook implements part of the [eager model](https://docs.google.com/drawings/d/1czjcBtDQGS8X6bnIbYU4wmFvv1AfZt5wwSRk9oyQGw0/edit). Here we continue where we left off in [part 2](https://colab.research.google.com/drive/1XZvLnmtu4QlHAXGCbmO2TFPTYFBrkfk1#scrollTo=o5DfRieOB3rq&uniqifier=1).

Let's put together a cell with train(), up(), down() functions

## Basics

In [0]:
import math
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import pdb
import numpy as np
import random
from scipy.ndimage.filters import gaussian_filter
from scipy import stats
from scipy.stats import norm
import os

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

from skimage.draw import line_aa
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Use matplotlib's 'classic' style sheet for every figure drawn below.
plt.style.use('classic')

# Device string used by the `.to(device)` calls in this notebook:
# "cuda" when torch reports an available GPU, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

# TODO: Use torch.normal(mean, std=1.0, out=None) 
class NormalDistributionTable(object):
    """Lookup table of discretised normal PDFs over the unit interval.

    The table holds ``table_resolution`` rows, one per mean in
    ``np.linspace(0, 1, table_resolution)``.  Each row is the normal PDF
    evaluated at ``resolution`` bin positions ``0, 1/resolution, ...`` and
    rescaled so the row sums to 1.

    Args:
        resolution: number of bins each value is expanded into.
        var: passed to ``scipy.stats.norm.pdf`` as its *scale* argument
            (i.e. it acts as a standard deviation despite its name).
        table_resolution: number of distinct means stored in the table.
    """

    def __init__(self, resolution, var=0.07, table_resolution=100):
        self.resolution = resolution
        self.var = var
        self.table_resolution = table_resolution

        # arange(n) / n instead of arange(0, 1, 1/n): a fractional step can
        # make arange return n+1 elements, which would break the reshape in
        # to_pdf().  This form always yields exactly `resolution` bins.
        bin_positions = np.arange(self.resolution) / self.resolution
        means = np.linspace(0, 1, self.table_resolution)
        table = np.stack([norm.pdf(bin_positions, mean, self.var) for mean in means])

        # Normalise every row so that it sums to one.
        self.gaussians = torch.from_numpy(table)
        self.gaussians = self.gaussians / self.gaussians.sum(dim=1, keepdim=True)

    def lookup(self, mean):
        """Return the table row (length ``resolution``) for a scalar ``mean`` in [0, 1].

        Raises:
            AssertionError: if ``mean`` lies outside [0, 1].
        """
        assert mean >= 0 and mean <= 1, "mean must be between 0 and 1"
        index = math.floor(mean * self.table_resolution)
        if index == self.table_resolution:
            # mean == 1.0 would index one past the end; use the last row.
            index = self.table_resolution - 1
        return self.gaussians[index]

    def to_pdf(self, images):
        """Replace every element of ``images`` by its table row.

        Args:
            images: tensor with values in [0, 1]; any shape with at least one
                dimension, on any device.

        Returns:
            Tensor shaped ``images.shape[:-1] + (images.shape[-1] * resolution,)``
            gathered from the table (so it lives on the table's device and has
            the table's dtype).

        Raises:
            AssertionError: if any element lies outside [0, 1] (NaN included).
        """
        # Promote to float64 before scaling so the bin index matches what
        # lookup() computes from a Python float, even for float32 inputs.
        means = images.detach().reshape(-1).to(device="cpu", dtype=torch.float64)
        assert bool(((means >= 0) & (means <= 1)).all()), "mean must be between 0 and 1"

        # Same rule as lookup(): floor(mean * table_resolution), with 1.0
        # mapped onto the last row.  One gather replaces the per-element loop.
        indices = torch.floor(means * self.table_resolution).long()
        indices = indices.clamp(max=self.table_resolution - 1)
        images_pdf = self.gaussians[indices]

        pdf_shape = tuple(images.shape[:-1]) + (images.shape[-1] * self.resolution, )
        return images_pdf.view(pdf_shape)


def generate_images(width, height, count=100):
    """Generate ``count`` images, each containing one random anti-aliased line.

    Every image is a ``(width, height)`` array with a line drawn between two
    random points and then blurred with a Gaussian filter (sigma 0.5).

    Args:
        width: size of the first image axis.
        height: size of the second image axis.
        count: number of images to generate.

    Returns:
        Tensor of shape ``(count, width, height)`` moved to ``device``.
    """
    images = []
    # Previously this loop was hard-coded to 100 iterations and ignored `count`.
    for _ in range(count):
        image = np.zeros((width, height))
        # Rows are drawn from the first axis (width) and columns from the
        # second (height) so the end points always fall inside the canvas,
        # also for non-square sizes.
        rr, cc, val = line_aa(random.randint(0, width-1), random.randint(0, height-1),
                              random.randint(0, width-1), random.randint(0, height-1))
        image[rr, cc] = val
        images.append(gaussian_filter(image, 0.5))

    # np.stack cannot stack an empty list; return an empty batch instead.
    stacked = np.stack(images) if images else np.zeros((0, width, height))
    return torch.as_tensor(stacked).to(device)

def generate_moving_line(width, height, count=100):
  """Build an animation made of two runs of ``int(count/2)`` blurred line frames.

  Within each run the end points passed to ``line_aa`` shift by one per frame.
  Every frame is a ``(width, height)`` canvas smoothed with a Gaussian filter
  (sigma 0.5).

  Returns:
      Tensor holding all frames (first run followed by second run) on ``device``.
  """
  frames_per_run = int(count/2)

  def render(r0, c0, r1, c1):
    # Draw one anti-aliased line on an empty canvas and blur it.
    canvas = np.zeros((width, height))
    rows, cols, intensity = line_aa(r0, c0, r1, c1)
    canvas[rows, cols] = intensity
    return gaussian_filter(canvas, 0.5)

  first_run = [render(2, 3-step, width-2, height-1-step)
               for step in range(frames_per_run)]
  second_run = [render(width-1-step, 2-step, 4-step, height-2-step)
                for step in range(frames_per_run)]

  return torch.as_tensor(first_run + second_run).to(device)

    
def show_image(image, vmin=None, vmax=None, title=None, print_values=False):
    """Display a 2-D tensor as a plasma-coloured image.

    Args:
        image: tensor to show; it is copied to the CPU and converted to NumPy.
        vmin, vmax: colour-scale limits forwarded to ``imshow``.
        title: optional figure title.
        print_values: when true, also print the raw array after plotting.
    """
    pixels = image.cpu().numpy()
    _, axes = plt.subplots(figsize=(20, 8))
    if title:
        plt.title(title)
    axes.imshow(pixels, vmin=vmin, vmax=vmax, interpolation='none', cmap=plt.cm.plasma)
    plt.show()
    if print_values:
        print(pixels)
        
def sample_from_pdf1(pdf):
    """Draw one sample from a discretised PDF using ``scipy.stats.rv_discrete``.

    Negative entries are treated as zero.  The drawn bin is mapped to
    ``bin / resolution`` and then multiplied by the total mass of the
    (clipped) PDF.

    Returns:
        A length-1 array with the sample, or ``[0]`` when the PDF has no
        positive mass.
    """
    assert pdf.shape == (resolution, )

    weights = pdf.copy()
    weights[weights < 0] = 0
    total = sum(weights)
    if not total > 0:
        return [0]

    support = np.arange(resolution)
    distribution = stats.rv_discrete(name='custm', values=(support, weights / total))
    sample = distribution.rvs(size=1) / resolution
    # apply scale (conflates value and confidence!)
    return sample * total

def sample_from_pdf(pdf):
    """Draw one sample from a discretised PDF by inverting its cumulative sum.

    A uniform random threshold is drawn and the first bin whose cumulative
    (normalised) mass strictly exceeds it is selected.  The result is
    ``bin * sum(pdf) / resolution``, i.e. the bin position scaled by the total
    mass of the PDF.

    Args:
        pdf: 1-D array of shape ``(resolution,)``.

    Returns:
        A one-element list with the sample, or ``[0]`` when the PDF has no
        positive mass.
    """
    assert pdf.shape == (resolution, )

    sum_pdf = sum(pdf)
    if not sum_pdf > 0:
        return [0]

    threshold = random.random()

    # Fall back to the last bin if rounding keeps the cumulative sum from
    # ever exceeding the threshold.
    chosen = resolution - 1
    cumulative = 0
    for index in range(resolution):
        cumulative += pdf[index] / sum_pdf
        # Strict comparison: random.random() may return exactly 0.0, in which
        # case a non-strict test selected "bin -1" and produced a negative
        # sample.  It also guarantees that empty leading bins are skipped.
        if cumulative > threshold:
            chosen = index
            break

    # apply scale (conflates value and confidence!)
    return [chosen * sum_pdf / resolution]


def sample_from_images__(images__):
    """Sample one value per pixel from a batch of per-pixel PDFs.

    Args:
        images__: tensor of shape ``(image count, height, width * resolution)``
            where each group of ``resolution`` consecutive values along the
            last axis is the PDF of one pixel.

    Returns:
        CPU tensor of shape ``(image count, height, width)`` holding one
        ``sample_from_pdf`` draw per pixel.
    """
    assert len(images__.shape) == 3

    image_count, height, flat_width = images__.shape
    width = flat_width // resolution

    # (image count, height, width*resolution) -> (image count*height*width, resolution).
    # reshape() also accepts non-contiguous input, and the data is moved to the
    # host once instead of once per pixel.
    pixel_pdfs = images__.detach().reshape(image_count, height, width, resolution)
    pixel_pdfs = pixel_pdfs.reshape(image_count * height * width, resolution).cpu().numpy()

    # one sampled value per distribution: (image count*height*width, 1)
    sampled_pixels = torch.Tensor([sample_from_pdf(pixel_pdf) for pixel_pdf in pixel_pdfs])

    # back to (image count, height, width)
    return sampled_pixels.view(image_count, height, width)


def averaged_sample_from_images__(images__, count=10):
    """Average ``count`` independent ``sample_from_images__`` draws element-wise."""
    draws = []
    for _ in range(count):
        draws.append(sample_from_images__(images__))
    return torch.stack(draws).mean(dim=0)


def aggregate_to_pdf(mu_bar, image_count, samples_per_image, iH, iW, resolution):
  """Turn repeated samples of each pixel into a normalised per-pixel histogram.

  Args:
      mu_bar: tensor of shape ``(image_count * samples_per_image, iH, iW)``;
          the samples belonging to one image are stored consecutively.
          Values are clamped to [0, 1] before binning.
      image_count: number of distinct images.
      samples_per_image: number of samples available per image.
      iH, iW: image height and width.
      resolution: number of histogram bins per pixel.

  Returns:
      CPU float tensor of shape ``(image_count, iH, iW * resolution)`` where
      each group of ``resolution`` values is an L1-normalised histogram.
  """
  # mu_bar                          (image_count * samples_per_image, iH, iW)
  mu_bar = mu_bar.detach().clamp(0, 1)
  # mu_bar_per_image                (image_count,  samples_per_image, iH, iW)
  mu_bar_per_image = mu_bar.reshape(image_count, samples_per_image, iH, iW)
  # samples_per_pixel               (image_count * iH * iW, samples_per_image)
  samples_per_pixel = mu_bar_per_image.permute(0, 2, 3, 1).reshape(image_count * iH * iW, samples_per_image)

  # A value of exactly 1.0 scales to `resolution`, one past the last bin, so
  # the index is clamped to keep saturated samples in the top bin.
  histogram_indices = (samples_per_pixel * resolution).long().clamp(max=resolution - 1).cpu()

  # mu_bar_flattened__              (image_count * iH * iW, resolution)
  # scatter_add_ counts every sample into its bin in one vectorised step.
  mu_bar_flattened__ = torch.zeros((image_count * iH * iW, resolution))
  mu_bar_flattened__.scatter_add_(1, histogram_indices, torch.ones(histogram_indices.shape))

  # mu_bar__                        (image_count, iH, iW * resolution)
  mu_bar__ = mu_bar_flattened__.view((image_count, iH, iW, resolution))
  mu_bar__ = torch.nn.functional.normalize(mu_bar__, p=1, dim=3)
  mu_bar__ = mu_bar__.view((image_count, iH, iW * resolution))

  return mu_bar__

# Input:  (samples, feature maps, height, width), where the number of feature
#         maps is a perfect square a * a (e.g. 9 maps -> a = 3).
# Output: (samples, height * a, width * a); every spatial position becomes an
#         a x a tile holding that position's features.
def flatten_feature_maps(f):
    features_last = f.permute(0, 2, 3, 1)  # (samples, height, width, features)
    samples, height, width, feature_count = features_last.shape

    a = int(feature_count ** 0.5)  # side length of the square tile
    assert a * a == feature_count, "Feature map count must be a perfect square"

    # Split the feature axis into (a, a), then interleave tile rows with image
    # rows and tile columns with image columns.
    tiles = features_last.view(samples, height, width, a, a)
    tiles = tiles.permute(0, 1, 3, 2, 4).contiguous()  # samples, height, a, width, a
    return tiles.view(samples, height * a, width * a)
  
# Input:  (samples, height * a, width * a), i.e. a grid of a x a tiles.
# Output: (samples, a * a feature maps, height, width); each tile is unrolled
#         into the feature axis.
def unflatten_feature_maps(f, a):
    shape = f.shape
    samples = shape[0]
    height = int(shape[1] / a)
    width = int(shape[2] / a)

    # Separate tile rows/columns from image rows/columns ...
    tiles = f.view(samples, height, a, width, a)
    # ... and group the two tile axes together at the end.
    tiles = tiles.permute(0, 1, 3, 2, 4).contiguous()  # samples, height, width, a, a

    features_last = tiles.view(samples, height, width, a * a)
    return features_last.permute(0, 3, 1, 2)  # move features in front of height/width

class EMA:
  """Exponential moving average with a fixed weight ``mu`` for the new value."""

  def __init__(self, mu):
    super().__init__()
    self.mu = mu

  def forward(self, x, last_average):
    """Blend ``x`` (weight ``mu``) with ``last_average`` (weight ``1 - mu``)."""
    weight_of_previous = 1-self.mu
    return self.mu*x + weight_of_previous*last_average
  

# Number of bins used for every discretised PDF; read as a global by the
# sampling helpers above and used as the default of UnitStack.
resolution = 10
# Spread of the table Gaussians; NormalDistributionTable forwards it to
# norm.pdf as the scale argument.
var = 0.05
# Shared value -> PDF lookup table used when converting images and hidden
# states into PDFs.
normal_distribution_table = NormalDistributionTable(resolution=resolution, var=var)

## Autoencoder

In [0]:
class AutoEncoder(nn.Module):
  """Convolutional autoencoder with an ``a * a``-channel bottleneck.

  The encoder applies two 3x3 convolutions and three 2x2 max-pools, so the
  bottleneck is 1/8 of the input size in each spatial dimension; a sigmoid
  keeps the code in (0, 1).  The decoder mirrors this with three stride-2
  transposed convolutions (each doubles the spatial size) and a final sigmoid.

  Args:
      a: square root of the number of bottleneck feature maps.

  Attributes:
      encoder_output: bottleneck activations of the most recent ``forward``
          call, or ``None`` before the first call.
  """

  def __init__(self, a=3):
    super(AutoEncoder, self).__init__()
    self.a = a
    self.encoder = nn.Sequential(                                             # b, 1, w, h
      nn.Conv2d(1, 2 * a * a, 3, stride=1, padding=1),                        # b, 2 * a * a, w, h
      nn.ReLU(True),
      nn.MaxPool2d(2, stride=2),                                              # b, 2 * a * a, w/2, h/2
      nn.Conv2d(2 * a * a, a * a, 3, stride=1, padding=1),                    # b, a * a, w/2, h/2
      nn.ReLU(True),
      nn.MaxPool2d(2, stride=2),                                              # b, a * a, w/4, h/4
      nn.MaxPool2d(2, stride=2),                                              # b, a * a, w/8, h/8
      nn.Sigmoid(),
    )
    self.decoder = nn.Sequential(
      nn.ConvTranspose2d(a * a, 2 * a * a, 3, stride=2, padding=1, output_padding=1), # b, 2 * a * a, w/4, h/4
      nn.ReLU(True),
      nn.ConvTranspose2d(2 * a * a, 2 * a * a, 3, stride=2, padding=1, output_padding=1), # b, 2 * a * a, w/2, h/2
      nn.ReLU(True),
      nn.ConvTranspose2d(2 * a * a, 1, 3, stride=2, padding=1, output_padding=1),     # b, 1, w, h
      nn.Sigmoid()
    )

    self.encoder_output = None

  def forward(self, x):
    """Encode and decode ``x``; the bottleneck is kept in ``encoder_output``.

    Args:
        x: input whose last two dimensions are height and width.

    Returns:
        Reconstruction with the same shape as ``x``.

    Raises:
        AssertionError: if height or width is not a multiple of 8.
    """
    # Three 2x poolings followed by three 2x up-samplings only reproduce the
    # input size when both spatial dimensions are divisible by 8; a multiple
    # of 4 (e.g. 12) would come back with a different size (8).
    height, width = x.shape[-2], x.shape[-1]
    assert height % 8 == 0 and width % 8 == 0, "Width and height must be a multiple of 8"
    x = self.encoder_output = self.encoder(x)
    x = self.decoder(x)
    return x

## Unit

In [0]:
class Unit:
  """One layer of the stack: an autoencoder plus PDF conversion in both directions.

  ``train()`` samples images from the incoming PDFs and fits (or just runs)
  the autoencoder, ``up()`` converts the resulting bottleneck into PDFs for
  the next unit, and ``down()`` decodes a hidden state sampled from the next
  unit's PDFs back into per-pixel PDFs.

  Args:
      unit_index: position of the unit in its stack; also selects the
          checkpoint file name.
      samples_per_image: number of samples drawn per input image in ``train()``.
      a: square root of the bottleneck feature-map count of the autoencoder.
      resolution: number of histogram bins produced by ``down()``.
  """

  def __init__(self, unit_index, samples_per_image=25, a=3, resolution=10):
    self.unit_index = unit_index
    self.samples_per_image = samples_per_image
    self.a = a
    self.model = AutoEncoder(a=a).to(device)
    self.ema = EMA(0.5)
    self.image_count = None
    self.image_size = None
    self.resolution = resolution
    self.trained = False

    if os.path.exists(self.save_path()):
      # map_location lets a checkpoint written on a GPU runtime load on a
      # CPU-only one (and vice versa).
      # NOTE: torch.load unpickles the file; only load checkpoints you trust.
      self.model.load_state_dict(torch.load(self.save_path(), map_location=device))
      self.model.eval()
      self.trained = True

  def up(self):
    """Convert the last bottleneck activations into smoothed PDFs.

    The bottleneck is flattened into one 2-D map per sample, expanded into
    PDFs with the shared lookup table, and then averaged along the first
    axis with the unit's exponential moving average.

    Raises:
        RuntimeError: if the model has not produced an encoder output yet.
    """
    if self.model.encoder_output is None:
      raise RuntimeError("must call train() before up()")

    h1 = self.model.encoder_output
    h1_flattened = flatten_feature_maps(h1)
    h1__ = normal_distribution_table.to_pdf(h1_flattened)

    # Running average over consecutive entries; each entry is overwritten
    # with the average up to and including itself.
    last_average = h1__[0].clone()
    for index in range(1, h1__.shape[0]):
      h1__[index] = last_average = self.ema.forward(h1__[index], last_average)

    return h1__

  def down(self, u2_bar__):
    """Decode PDFs coming from above into per-pixel PDFs of this unit's input.

    Args:
        u2_bar__: PDFs over this unit's flattened hidden state.

    Returns:
        Tensor of shape ``(image_count, image_size, image_size * resolution)``.

    Raises:
        RuntimeError: if the model has not produced an encoder output yet.
    """
    if self.model.encoder_output is None:
      raise RuntimeError("must call train() before down()")

    sampled_h1 = sample_from_images__(u2_bar__)
    unflattened_sampled_h1 = unflatten_feature_maps(sampled_h1, self.a).to(device)
    h1 = self.model.encoder_output

    # Gate the stored bottleneck with the hidden state sampled from above.
    merged_h1 = unflattened_sampled_h1 * h1

    decoded_mu1 = self.model.decoder(merged_h1)
    decoded_mu1 = decoded_mu1[:, 0, :, :]

    mu1_bar__ = aggregate_to_pdf(mu_bar=decoded_mu1, image_count=self.image_count, samples_per_image=self.samples_per_image, iH=self.image_size, iW=self.image_size, resolution=self.resolution)

    return mu1_bar__

  def train(self, mu1__, num_epochs=3000):
    """Sample images from ``mu1__`` and fit the autoencoder on them.

    If the unit is already trained (checkpoint loaded or a previous call),
    the model is only run forward.  Otherwise it is optimised and the
    weights are written to ``save_path()``.

    Args:
        mu1__: PDFs of shape ``(image_count, image_size, image_size * resolution)``.
        num_epochs: maximum number of optimisation steps.

    Returns:
        Tuple ``(mu1, output)``: the sampled input images and the model's
        reconstruction of them, both ``(image_count * samples_per_image, H, W)``.
    """
    self.image_count, self.image_size, _ = mu1__.shape

    # Repeat every image's PDFs samples_per_image times, then draw one
    # sample image from each copy.
    mu1_duplicated__ = torch.stack([mu__.clone() for mu__ in mu1__ for _ in range(self.samples_per_image)])
    mu1 = sample_from_images__(mu1_duplicated__)

    model_input = mu1[:, None, :, :].to(device)

    if self.trained:
      output = self.model(model_input)
    else:
      output = self._fit(model_input, num_epochs)
      self.trained = True
      torch.save(self.model.state_dict(), self.save_path())

    return mu1, output[:,0,:,:]

  def _fit(self, model_input, num_epochs):
    """Optimise the autoencoder to reconstruct ``model_input``; return the last output."""
    learning_rate = 1e-3
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate,
                                 weight_decay=1e-5)

    # At least 1 so that num_epochs < 10 does not cause a modulo by zero.
    log_interval = max(1, num_epochs // 10)

    epoch = 0
    while True:
      output = self.model(model_input)
      loss = criterion(output, model_input)
      # ===================backward====================
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      loss_value = loss.item()
      if epoch % log_interval == 0:
        print('epoch [{}/{}], loss:{:.4f}'
              .format(epoch+1, num_epochs, loss_value))

      # Stop early once the loss is small (but never before epoch index
      # 1000), or after num_epochs steps in total.
      converged = loss_value < 0.01 and epoch > 1000
      if converged or epoch + 1 >= num_epochs:
        return output

      epoch += 1

  def save_path(self):
    """File name of this unit's checkpoint."""
    return f"unit_{self.unit_index}.pt"


class UnitStack:
  """Ordered stack of units processed bottom-up and then top-down."""

  def __init__(self, resolution=resolution):
    self.resolution = resolution
    self.units = []

  def append_unit(self):
    """Create a unit on top of the stack and return it.

    The first unit draws 10 samples per image, every later unit draws 1.
    """
    position = len(self.units)
    samples_per_image = 10 if position == 0 else 1

    unit = Unit(position, samples_per_image=samples_per_image, a=4, resolution=self.resolution)
    self.units.append(unit)
    return unit

  def process(self, mu1__):
    """Run the whole stack on ``mu1__``, starting at the first unit."""
    return self.process_unit(0, mu1__)

  def process_unit(self, unit_index, mu1__):
    """Train and propagate from unit ``unit_index`` upwards, then back down.

    Each unit is trained on the PDFs coming from below and its ``up()``
    output feeds the next unit.  The topmost unit's own ``up()`` output is
    then passed back through ``down()`` of every visited unit in reverse.

    Returns:
        The ``down()`` output of the unit at ``unit_index``.
    """
    signal__ = mu1__
    visited = []

    # Upward pass.
    while True:
      unit = self.units[unit_index]

      print(f"mu{unit_index}__        :", signal__.shape)
      unit.train(signal__, num_epochs=3000)

      signal__ = unit.up()
      print(f"h{unit_index}__         :", signal__.shape)

      visited.append((unit_index, unit))
      if unit_index < len(self.units) - 1:
        unit_index += 1
      else:
        print("No next unit")
        break

    # Downward pass, topmost unit first.
    for unit_index, unit in reversed(visited):
      signal__ = unit.down(signal__)
      print(f"mu{unit_index}_bar__    :", signal__.shape)

    return signal__



## Example

In [1]:
# Example: push a small moving-line animation through a stack of three units
# and compare the inputs with what comes back down.
image_size = 16
image_count = image_size

# Seed every RNG in use.  The PDF sampling helpers draw from Python's `random`
# module, so seeding only NumPy and torch left the run non-reproducible.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Distinct frames of the animation and their per-pixel PDFs.
images = generate_moving_line(image_size, image_size, count=image_count).float()
mu1__ = normal_distribution_table.to_pdf(images)

unit_stack = UnitStack(resolution=resolution)
unit_stack.append_unit()
unit_stack.append_unit()
unit_stack.append_unit()
mu1_bar__ = unit_stack.process(mu1__)

# For each frame show the original image, the PDFs returned by the stack and
# one sample drawn directly from the input PDFs.
sampled_images = sample_from_images__(mu1__)
for i in range(sampled_images.shape[0]):
  show_image(images[i].detach(), title=f"images {i}", vmin=0, vmax=1)
  show_image(mu1_bar__[i].detach(), title=f"mu1_bar__ {i}", vmin=0, vmax=1)
  show_image(sampled_images[i].detach(), title=f"sampled_images {i}", vmin=0, vmax=1)

NameError: ignored

In [0]:
# import os
# import glob

# files = glob.glob('./*.pt')
# for f in files:
#     os.remove(f)