Plotting functions, copied directly from the original code online. I added the `%matplotlib inline` line so the plots render in the notebook.


In [136]:
import numpy as np
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
%matplotlib inline
# Plot image examples.
def plot_img(img, title):
    """Display ``img`` in a brand-new figure, captioned with ``title``.

    Pixels are drawn with nearest-neighbour interpolation (no smoothing),
    axis ticks and frame are hidden, and the layout is tightened. The figure
    is left open; nothing is returned.
    """
    plt.figure()
    plt.imshow(img, interpolation='nearest')
    plt.axis('off')
    plt.title(title)
    # Done last so the layout accounts for the title.
    plt.tight_layout()

def img_stretch(img):
    """Linearly rescale ``img`` so its values span [0, 1].

    :param img: array-like of numbers, any shape.
    :returns: a new float64 ndarray; the input is never modified.

    The minimum maps to 0 and the maximum to (just under) 1. A tiny epsilon in
    the denominator avoids division by zero, so a constant input maps to all
    zeros.
    """
    # np.array (unlike np.asarray) always copies, so the in-place operations
    # below never touch the caller's data. It also accepts plain lists, which
    # the old ``img.astype`` call did not.
    img = np.array(img, dtype=float)
    img -= np.min(img)
    img /= np.max(img) + 1e-12
    return img

def img_tile(imgs, aspect_ratio=1.0, tile_shape=None, border=1,
             border_color=0, stretch=False):
    ''' Tile images in a grid.
    If tile_shape is provided only as many images as specified in tile_shape
    will be included in the output.

    :param imgs: array-like of shape (n, h, w) or (n, h, w, c).
    :param aspect_ratio: used only when tile_shape is None. The grid is
        chosen so that the tiled image's height/width is roughly this value.
    :param tile_shape: optional (rows, cols) of the grid.
    :param border: number of pixels between neighbouring tiles.
    :param border_color: fill value for borders and unused grid cells.
    :param stretch: if True, rescale all images jointly to [0, 1] first.
    :returns: float64 ndarray of shape
        (rows*(h+border)-border, cols*(w+border)-border[, c]).
    :raises ValueError: if imgs is not 3- or 4-dimensional.
    '''
    # Convert first. img_stretch calls .astype, which plain lists lack, so the
    # old "stretch, then np.array" order crashed on list input.
    imgs = np.asarray(imgs)
    if imgs.ndim != 3 and imgs.ndim != 4:
        raise ValueError('imgs has wrong number of dimensions.')
    if stretch:
        imgs = img_stretch(imgs)
    n_imgs = imgs.shape[0]

    # Grid shape
    img_shape = np.array(imgs.shape[1:3])
    if tile_shape is None:
        # Fold the single-image width/height ratio in, so that the whole
        # tiled picture (not just the grid count) gets the requested shape.
        img_aspect_ratio = img_shape[1] / float(img_shape[0])
        aspect_ratio *= img_aspect_ratio
        tile_height = int(np.ceil(np.sqrt(n_imgs * aspect_ratio)))
        tile_width = int(np.ceil(np.sqrt(n_imgs / aspect_ratio)))
        grid_shape = np.array((tile_height, tile_width))
    else:
        assert len(tile_shape) == 2
        grid_shape = np.array(tile_shape)
    n_rows, n_cols = int(grid_shape[0]), int(grid_shape[1])

    # Output size: tiles plus `border` pixels between (not around) them. Any
    # trailing channel axis is carried over unchanged.
    tile_img_shape = np.array(imgs.shape[1:])
    tile_img_shape[:2] = (img_shape[:2] + border) * grid_shape[:2] - border
    tile_img = np.empty(tile_img_shape)
    tile_img[:] = border_color

    # Fill cells in row-major order, stopping when either the images or the
    # grid cells run out.
    for img_idx in range(min(n_imgs, n_rows * n_cols)):
        i, j = divmod(img_idx, n_cols)
        yoff = (img_shape[0] + border) * i
        xoff = (img_shape[1] + border) * j
        tile_img[yoff:yoff+img_shape[0], xoff:xoff+img_shape[1], ...] = imgs[img_idx]
    return tile_img

def conv_filter_tile(filters):
    """Arrange a bank of convolution filters into one image for display.

    :param filters: ndarray of shape (n_filters, n_channels, height, width).
    :returns: the tiled image produced by ``img_tile``. It is RGB when
        ``n_channels == 3`` and greyscale otherwise. Values are stretched to
        [0, 1] jointly over the whole bank.
    """
    n_filters, n_channels, height, width = filters.shape
    tile_shape = None
    if n_channels == 3:
        # Interpret 3 color channels as RGB: move channels last -> (n, h, w, 3).
        filters = np.transpose(filters, (0, 2, 3, 1))
    else:
        # One greyscale tile per (channel, filter) pair, channel-major. With
        # tile_shape = (n_channels, n_filters), grid row i holds channel i of
        # every filter and grid column j holds filter j. (The old comment
        # claimed the reverse.)
        tile_shape = (n_channels, n_filters)
        filters = np.transpose(filters, (1, 0, 2, 3))
        # reshape, unlike np.resize, raises on a size mismatch instead of
        # silently repeating or truncating data.
        filters = filters.reshape(n_filters * n_channels, height, width)
    filters = img_stretch(filters)
    return img_tile(filters, tile_shape=tile_shape)
    
def scale_to_unit_interval(ndar, eps=1e-8):
    """Scale all values in ``ndar`` to lie in [0, 1].

    The minimum maps to 0 and the maximum to just under 1. ``eps`` guards
    against division by zero for a constant array. A new array is returned
    and the input is left untouched.

    Floating-point inputs keep their dtype. Integer (and boolean) inputs are
    promoted to float64 first: the old in-place ``*=`` on an integer copy
    raised a casting error.
    """
    ndar = np.array(ndar, copy=True)
    if not np.issubdtype(ndar.dtype, np.floating):
        ndar = ndar.astype(np.float64)
    ndar -= ndar.min()
    ndar *= 1.0 / (ndar.max() + eps)
    return ndar


def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0),
                       scale_rows_to_unit_interval=True,
                       output_pixel_vals=True):
    """
    Transform an array with one flattened image per row, into an array in
    which images are reshaped and laid out like tiles on a floor.

    This function is useful for visualizing datasets whose rows are images,
    and also columns of matrices for transforming those rows
    (such as the first layer of a neural net).

    :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can
    be 2-D ndarrays or None;
    :param X: a 2-D array in which every row is a flattened image. A tuple
    of four such arrays is tiled channel by channel into the last axis of
    the output (three colour channels plus alpha). A None entry is filled
    with a constant: 0 for colours, opaque for alpha.

    :type img_shape: tuple; (height, width)
    :param img_shape: the original shape of each image

    :type tile_shape: tuple; (rows, cols)
    :param tile_shape: the number of images to tile (rows, cols)

    :param tile_spacing: (vertical, horizontal) gap in pixels between tiles.

    :param output_pixel_vals: if output should be pixel values (i.e. uint8
    values in [0, 255]) or floats

    :param scale_rows_to_unit_interval: if the values need to be scaled before
    being plotted to [0,1] or not


    :returns: array suitable for viewing as an image.
    (See:`PIL.Image.fromarray`.)
    :rtype: a 2-D array, or a 3-D array with a trailing axis of size 4 when X
    is a tuple. The dtype is uint8 if output_pixel_vals, otherwise the dtype
    of X (for a tuple, of its first non-None channel; float64 if all are None).
    """

    assert len(img_shape) == 2
    assert len(tile_shape) == 2
    assert len(tile_spacing) == 2

    # Output size per axis: the tiles plus the spacing between (not around)
    # them, i.e. (img + spacing) * n_tiles - spacing.
    out_shape = [(ishp + tsp) * tshp - tsp
                 for ishp, tshp, tsp in zip(img_shape, tile_shape, tile_spacing)]

    if isinstance(X, tuple):
        assert len(X) == 4
        if output_pixel_vals:
            out_dtype = 'uint8'
            # colors default to 0, alpha defaults to 255 (opaque)
            channel_defaults = [0, 0, 0, 255]
        else:
            # A tuple has no .dtype (the old code crashed here). Take the dtype
            # from the first channel that is actually present.
            out_dtype = next((np.asarray(ch).dtype for ch in X if ch is not None),
                             np.float64)
            # colors default to 0, alpha defaults to 1 (opaque)
            channel_defaults = [0., 0., 0., 1.]
        out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype=out_dtype)

        for i, channel in enumerate(X):
            if channel is None:
                out_array[:, :, i] = channel_defaults[i]
            else:
                # Tile this channel on its own with a recursive call.
                out_array[:, :, i] = tile_raster_images(
                    channel, img_shape, tile_shape, tile_spacing,
                    scale_rows_to_unit_interval, output_pixel_vals)
        return out_array

    # Single channel.
    H, W = img_shape
    Hs, Ws = tile_spacing
    out_array = np.zeros(out_shape, dtype='uint8' if output_pixel_vals else X.dtype)
    # uint8 output stores value*255 (truncated on assignment). This is only
    # meaningful for values in [0, 1], which scaling guarantees.
    value_scale = 255 if output_pixel_vals else 1
    n_rows, n_cols = tile_shape

    for tile_row in range(n_rows):
        for tile_col in range(n_cols):
            idx = tile_row * n_cols + tile_col
            if idx >= X.shape[0]:
                # Fewer images than grid cells: leave the cell as zeros.
                continue
            this_img = X[idx].reshape(img_shape)
            if scale_rows_to_unit_interval:
                this_img = scale_to_unit_interval(this_img)
            top = tile_row * (H + Hs)
            left = tile_col * (W + Ws)
            out_array[top:top + H, left:left + W] = this_img * value_scale
    return out_array


In [99]:
import argparse
import time
import numpy as np
import sys
import pickle
import torch
from google.colab import files
import torch.nn as nn

The `files.upload()` call lets me copy local data files onto Colab.

In [15]:
# Opens Colab's file picker. The chosen local files are saved into the working
# directory (see the "Saving ..." output below), so later cells can open them
# by name.
uploaded = files.upload()

Saving data_sample_small to data_sample_small
Saving test_batch_small to test_batch_small


`unpickle` is taken directly from the original code. It is not needed for the smaller test cases, since I preprocessed those locally.

In [None]:
def unpickle(file):
    """Load one pickled image batch (presumably CIFAR-10 python format) and normalise it.

    :param file: path to a pickled dict with keys 'data' (one flattened
        3x32x32 image per row) and 'labels'.
    :returns: dict with two entries:
        'x': float32 array of shape (N, 3, 32, 32) with values mapped by
        (v - 127.5) / 128, so 0..255 lands in roughly [-1, 1].
        'y': uint8 label array.

    Warning: unpickling can execute arbitrary code, so only load trusted files.
    """
    with open(file, 'rb') as fo:
        d = pickle.load(fo, encoding='latin1')
    # -1 lets batches of any size load (the old code hard-coded 10000 rows).
    # np.cast was removed in NumPy 2.0, so cast with astype instead.
    x = ((-127.5 + np.asarray(d['data']).reshape((-1, 3, 32, 32))) / 128.).astype(np.float32)
    y = np.array(d['labels']).astype(np.uint8)
    return {'x': x, 'y': y}

Load the preprocessed training and test data.

In [101]:
# Load the locally preprocessed train/test dicts (read below via their
# 'x' and 'y' keys). `with` closes both files; the old code left them open.
# NOTE(review): unpickling can execute arbitrary code, so only load trusted files.
with open("data_sample_small", "rb") as f1, open("test_batch_small", "rb") as f2:
    tr, test = pickle.load(f1), pickle.load(f2)

Set some parameters that are used later.

In [102]:
# Hyper-parameters shared by the cells below.
seed = seed_data = 1  # NumPy RNG seeds (see the seeding cell)
batch_size = 10       # examples per mini-batch
lr = 0.0003           # base learning rate for both Adam optimisers
count = 50            # labelled examples kept per class

Set the seeds and create the random number generators.

In [103]:
# fixed random seeds
# Seed PyTorch too. Weight initialisation and dropout use torch's RNG, so
# without this, runs were not reproducible despite the NumPy seeds below.
torch.manual_seed(seed)
# Three NumPy generators:
#   rng_data   - picks the labelled subset
#   rng        - reshuffles the training data every epoch
#   theano_rng - draws generator noise (a NumPy RandomState; the name is
#                kept from the Theano original)
rng_data = np.random.RandomState(seed_data)
rng = np.random.RandomState(seed)
theano_rng = np.random.RandomState(seed_data+seed)

Divvy up the data into training, test and unlabelled sets.

In [104]:
# Labelled training data and test data.
trainx = tr['x']
trainy = tr['y']
testx = test['x']
testy = test['y']
# Two independent copies of the training inputs serve as the unlabelled pools;
# the training loop reshuffles each one separately.
trainx_unl, trainx_unl2 = trainx.copy(), trainx.copy()
# Whole mini-batches per pass; a trailing partial batch is dropped.
nr_batches_train, nr_batches_test = (arr.shape[0] // batch_size for arr in (trainx, testx))

The generator and discriminator models. They are based on the Theano implementation.

In [107]:
# Generator Code
class Generator(nn.Module):
    """Map a batch of 100-d noise vectors to 3x32x32 images in [-1, 1].

    Architecture: Linear -> BatchNorm -> ReLU, reshaped to 512x4x4, followed
    by three stride-2 transposed convolutions (4 -> 8 -> 16 -> 32) and a final
    tanh.

    :param ngpu: stored as ``self.ngpu``; ``forward`` does not use it.
    """
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.l1 = nn.Linear(100, 4*4*512, bias=True)
        nn.init.normal_(self.l1.weight, mean=0.0, std=0.05)
        self.l2 = nn.BatchNorm1d(4*4*512)
        # forward() reshapes to 512 x 4 x 4 here.
        # kernel 5 / stride 2 / padding 2 / output_padding 1 doubles H and W.
        self.l3 = nn.ConvTranspose2d(512, 256, 5, stride=2, padding=2, output_padding=1)  # 256 x 8 x 8
        nn.init.normal_(self.l3.weight, mean=0.0, std=0.05)
        self.l4 = nn.BatchNorm2d(256)
        self.l5 = nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2, output_padding=1)  # 128 x 16 x 16
        nn.init.normal_(self.l5.weight, mean=0.0, std=0.05)
        self.l6 = nn.BatchNorm2d(128)
        # Initialise BEFORE wrapping with weight_norm. Afterwards, `.weight` is
        # recomputed from weight_g and weight_v on every forward pass, so the
        # old post-wrap normal_() had no lasting effect. weight_norm derives
        # weight_g/weight_v from the existing weight, which preserves this init.
        # NOTE(review): weight_norm's default dim=0 is the *input*-channel axis
        # of a ConvTranspose2d weight (in, out, kH, kW); a per-output-channel
        # norm would need dim=1. Confirm against the reference.
        out_conv = nn.ConvTranspose2d(128, 3, 5, stride=2, padding=2, output_padding=1)  # 3 x 32 x 32
        nn.init.normal_(out_conv.weight, mean=0.0, std=0.05)
        self.l7 = nn.utils.weight_norm(out_conv)

    def forward(self, input):
        x = self.l1(input)
        x = self.l2(x)
        x = nn.functional.relu(x)
        x = x.view(x.size(0), 512, 4, 4)
        x = self.l3(x)
        x = self.l4(x)
        x = nn.functional.relu(x)
        x = self.l5(x)
        x = self.l6(x)
        x = nn.functional.relu(x)
        x = self.l7(x)
        x = torch.tanh(x)
        return x

class Discriminator(nn.Module):
    """Weight-normalised convolutional classifier: 3x32x32 image -> 10 logits.

    Spatial sizes: 32 -> 16 (l3, stride 2) -> 8 (l6, stride 2) -> 6 (l7, no
    padding) -> 1 (l10 pooling), giving a 192-d feature vector for l11.

    :param ngpu: stored as ``self.ngpu``; ``forward`` does not use it.
    """
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.l1 = nn.utils.weight_norm(nn.Conv2d(3, 96, 3, stride=1, padding=1))
        self.l2 = nn.utils.weight_norm(nn.Conv2d(96, 96, 3, stride=1, padding=1))
        self.l3 = nn.utils.weight_norm(nn.Conv2d(96, 96, 3, stride=2, padding=1))
        self.l4 = nn.utils.weight_norm(nn.Conv2d(96, 192, 3, stride=1, padding=1))
        self.l5 = nn.utils.weight_norm(nn.Conv2d(192, 192, 3, stride=1, padding=1))
        self.l6 = nn.utils.weight_norm(nn.Conv2d(192, 192, 3, stride=2, padding=1))
        self.l7 = nn.utils.weight_norm(nn.Conv2d(192, 192, 3, stride=1, padding=0))
        self.l8 = nn.utils.weight_norm(nn.Conv2d(192, 192, 1, stride=1, padding=0))
        self.l9 = nn.utils.weight_norm(nn.Conv2d(192, 192, 1, stride=1, padding=0))
        # Pools the 6x6 map down to 1x1.
        # NOTE(review): the Theano reference presumably used global *mean*
        # pooling here; MaxPool2d differs. Confirm.
        self.l10 = nn.MaxPool2d(6)
        self.l11 = nn.utils.weight_norm(nn.Linear(192, 10, bias=True))

    def forward(self, input):
        # Every dropout follows the module's train/eval mode. F.dropout
        # defaults to training=True, so previously dropout stayed active after
        # disc.eval() and test-time predictions were random.
        x = nn.functional.dropout(input, p=0.2, training=self.training)
        x = self.l1(x)
        x = nn.functional.leaky_relu(x, negative_slope=0.2)
        x = self.l2(x)
        x = nn.functional.leaky_relu(x, negative_slope=0.2)
        x = self.l3(x)
        x = nn.functional.leaky_relu(x, negative_slope=0.2)
        x = nn.functional.dropout(x, p=0.5, training=self.training)
        x = self.l4(x)
        x = nn.functional.leaky_relu(x, negative_slope=0.2)
        x = self.l5(x)
        x = nn.functional.leaky_relu(x, negative_slope=0.2)
        x = self.l6(x)
        x = nn.functional.leaky_relu(x, negative_slope=0.2)
        x = nn.functional.dropout(x, p=0.5, training=self.training)
        x = self.l7(x)
        x = nn.functional.leaky_relu(x, negative_slope=0.2)
        x = self.l8(x)
        x = nn.functional.leaky_relu(x, negative_slope=0.2)
        x = self.l9(x)
        x = nn.functional.leaky_relu(x, negative_slope=0.2)
        x = self.l10(x)
        x = x.view(x.size(0), -1)
        x = self.l11(x)
        return x

This hook allows for feature matching at the second to last layer.

In [111]:
# Latest forward-hook outputs, keyed by the name passed to get_activation.
activation = {}


def get_activation(name):
    """Build a forward hook that records a module's output.

    The returned callable has PyTorch's forward-hook signature
    ``(module, inputs, output)``. Each call stores ``output.detach()`` in
    ``activation[name]``, replacing any previous entry. Because the tensor is
    detached, it carries no autograd history.
    """
    def _record(_module, _inputs, output):
        activation[name] = output.detach()

    return _record

In [112]:
# One latent batch: batch_size rows of 100 values (Generator.l1 takes 100 inputs).
noise_dim = (batch_size, 100)
# Uniform [0, 1) noise. The training loop redraws it before every use, but
# this draw still advances theano_rng's stream.
noise = theano_rng.uniform(size=noise_dim)
# ngpu=0 is only stored on the models; forward() never reads it.
gen, disc = Generator(0), Discriminator(0)

In [113]:

# Record l10's output (the pooled features) into activation['l10'] on every
# discriminator forward pass.
l10_hook = get_activation('l10')
disc.l10.register_forward_hook(l10_hook)


<torch.utils.hooks.RemovableHandle at 0x7f0f1487f5c0>

Weight initializations are from the original implementation.

In [None]:
def weights_init(m):
    """Re-initialise one module in place; intended for ``model.apply(weights_init)``.

    Conv / ConvTranspose layers get weights ~ N(0, 0.05). For modules wrapped
    with ``nn.utils.weight_norm``, the effective weight is
    ``weight_g * weight_v / ||weight_v||`` and ``weight`` is recomputed from
    those on every forward pass. Writing ``m.weight`` (as the old code did)
    therefore had no lasting effect. Instead, ``weight_v`` is sampled and
    ``weight_g`` is set to its norm, so the effective weight equals the sample.

    BatchNorm layers get scale (gamma) = 1 and shift (beta) = 0. The old
    N(0, 0.05) scale shrank normalised activations to ~5% with random signs.
    NOTE(review): a unit scale is assumed to match the reference
    implementation. Confirm.

    Other modules, and conv biases, are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:  # also covers ConvTranspose*
        if hasattr(m, 'weight_g') and hasattr(m, 'weight_v'):
            v = m.weight_v.data
            g = m.weight_g.data
            nn.init.normal_(v, 0.0, 0.05)
            # weight_g has singleton axes everywhere except the normalised
            # dimension, so the norm is taken over exactly those axes.
            reduce_dims = [d for d, size in enumerate(g.shape) if size == 1]
            g.copy_(v.pow(2).sum(dim=reduce_dims, keepdim=True).sqrt())
        else:
            nn.init.normal_(m.weight.data, 0.0, 0.05)
    elif 'BatchNorm' in classname:
        # affine=False batch norms have no weight/bias to initialise.
        if m.weight is not None:
            nn.init.constant_(m.weight.data, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

# Re-initialise both networks in place: discriminator first, then generator.
for _model in (disc, gen):
    _model.apply(weights_init)

Labelled-data selection is exactly as in the original paper. Adam is used as the optimiser.

In [116]:
# select labeled data: shuffle once with rng_data, then keep the first `count`
# examples of each of the 10 classes (fewer if a class has fewer examples).
inds = rng_data.permutation(trainx.shape[0])
trainx, trainy = trainx[inds], trainy[inds]
per_class = [(trainx[trainy == c][:count], trainy[trainy == c][:count]) for c in range(10)]
txs = np.concatenate([cls_x for cls_x, _ in per_class], axis=0)
tys = np.concatenate([cls_y for _, cls_y in per_class], axis=0)

# initialize optimisers (Adam with identical settings for both networks)
Doptim = torch.optim.Adam(disc.parameters(), lr=lr, betas=(0.9, 0.999))
Goptim = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.9, 0.999))

The training loop mirrors the original implementation: each Theano function was inlined and replaced by its equivalent PyTorch code.

In [None]:
for epoch in range(1200):
    begin = time.time()
    # Learning rate is constant for 800 epochs, then decays linearly towards 0.
    # It is always derived from the configured base `lr`. The old code
    # overwrote `lr` (compounding the decay every epoch) and later rebound it
    # to a tensor, which was then written into the optimisers.
    epoch_lr = lr * min(3 - epoch / 400, 1)
    for optimiser in (Doptim, Goptim):
        for param_group in optimiser.param_groups:
            param_group['lr'] = epoch_lr

    # TODO: replace this manual shuffling with a DataLoader.
    # Repeat the labelled subset, reshuffled each time, until it is at least
    # as long as the unlabelled pool. Then reshuffle both unlabelled copies.
    trainx = []
    trainy = []
    for _ in range(int(np.ceil(trainx_unl.shape[0] / float(txs.shape[0])))):
        inds = rng.permutation(txs.shape[0])
        trainx.append(txs[inds])
        trainy.append(tys[inds])
    trainx = np.concatenate(trainx, axis=0)
    trainy = np.concatenate(trainy, axis=0)
    trainx_unl = trainx_unl[rng.permutation(trainx_unl.shape[0])]
    trainx_unl2 = trainx_unl2[rng.permutation(trainx_unl2.shape[0])]

    # train
    loss_lab_sum = 0.
    loss_unl_sum = 0.
    train_err_sum = 0.
    # Running average of the discriminator parameters, used for testing below.
    # NOTE(review): it restarts from the current weights at every epoch.
    disc_avg_updates = None
    for t in range(nr_batches_train):
        ran_from = t * batch_size
        ran_to = (t + 1) * batch_size

        # ---- discriminator step ----
        x_lab = torch.tensor(trainx[ran_from:ran_to])
        # Labels are stored as uint8; gather and comparisons need int64.
        labels = torch.tensor(trainy[ran_from:ran_to]).long()
        x_unl = torch.tensor(trainx_unl[ran_from:ran_to])
        noise = torch.tensor(theano_rng.uniform(size=noise_dim)).float()
        disc.train()
        gen.train()
        output_before_softmax_lab = disc(x_lab)
        output_before_softmax_unl = disc(x_unl)
        # Detached: this step only updates the discriminator, so there is no
        # need to backpropagate into the generator.
        output_before_softmax_gen = disc(gen(noise).detach())
        # Logit of each example's true class. gather keeps the autograd graph.
        # The old torch.tensor([...]) rebuild detached it, so the supervised
        # loss never pushed the correct-class logit up.
        l_lab = output_before_softmax_lab.gather(1, labels.unsqueeze(1)).squeeze(1)
        l_unl = torch.logsumexp(output_before_softmax_unl, 1)
        l_gen = torch.logsumexp(output_before_softmax_gen, 1)
        batch_loss_lab = -torch.mean(l_lab) + torch.mean(torch.logsumexp(output_before_softmax_lab, 1))
        batch_loss_unl = (-0.5 * torch.mean(l_unl)
                          + 0.5 * torch.mean(nn.functional.softplus(l_unl))
                          + 0.5 * torch.mean(nn.functional.softplus(l_gen)))
        batch_train_err = torch.mean(torch.ne(torch.argmax(output_before_softmax_lab, dim=1), labels).float())
        Doptim.zero_grad()
        (batch_loss_lab + batch_loss_unl).backward()
        Doptim.step()

        # Accumulate plain floats. The old code added each batch's value to
        # itself (so only the last batch, doubled, survived) and kept autograd
        # graphs alive.
        loss_lab_sum += batch_loss_lab.item()
        loss_unl_sum += batch_loss_unl.item()
        train_err_sum += batch_train_err.item()

        # Exponential moving average of the discriminator parameters (rate 1e-4).
        with torch.no_grad():
            if disc_avg_updates is None:
                disc_avg_updates = [p.detach().clone() for p in disc.parameters()]
            for p, a in zip(disc.parameters(), disc_avg_updates):
                a.add_(p - a, alpha=0.0001)

        # ---- generator step (matching mean discriminator outputs) ----
        noise = torch.tensor(theano_rng.uniform(size=noise_dim)).float()
        x_unl = torch.tensor(trainx_unl2[ran_from:ran_to])
        disc.train()
        gen.train()
        # Real-data statistics are a fixed target, so no gradient is needed.
        with torch.no_grad():
            output_unl = disc(x_unl)
        # TODO: true feature matching should use the l10 features recorded in
        # activation['l10'], but the hook detaches them, so the loss and
        # optimiser would need redesigning for that.
        output_gen = disc(gen(noise))
        m1 = torch.mean(output_unl, dim=0)
        m2 = torch.mean(output_gen, dim=0)
        loss_gen = torch.mean(torch.abs(m1 - m2))  # feature matching loss
        Goptim.zero_grad()
        Doptim.zero_grad()
        loss_gen.backward()
        Goptim.step()

    loss_lab = loss_lab_sum / nr_batches_train
    loss_unl = loss_unl_sum / nr_batches_train
    train_err = train_err_sum / nr_batches_train

    # test, using the averaged discriminator weights
    gen.eval()
    disc.eval()
    with torch.no_grad():
        # clone() is essential: param.data shares storage with the live
        # parameter. Without it, the old "backup" was overwritten by the
        # averages and the restore below was a no-op, so the averaged weights
        # silently replaced the trained ones.
        original_params = [param.detach().clone() for param in disc.parameters()]
        for param, avg_param in zip(disc.parameters(), disc_avg_updates):
            param.copy_(avg_param)
        test_err = 0.
        for t in range(nr_batches_test):
            x_lab = torch.tensor(testx[t*batch_size:(t+1)*batch_size])
            labels = torch.tensor(testy[t*batch_size:(t+1)*batch_size]).long()
            output_before_softmax = disc(x_lab)
            test_err += torch.mean(torch.ne(torch.argmax(output_before_softmax, dim=1), labels).float()).item()
        for param, o_param in zip(disc.parameters(), original_params):
            param.copy_(o_param)
    test_err /= nr_batches_test

    # report
    print("Iteration %d, time = %ds, loss_lab = %.4f, loss_unl = %.4f, train err = %.4f, test err = %.4f" % (epoch, time.time()-begin, loss_lab, loss_unl, train_err, test_err))
    sys.stdout.flush()

    # Plot a grid of generated samples for this epoch.
    gen.eval()
    with torch.no_grad():
        noise = torch.tensor(theano_rng.uniform(size=noise_dim)).float()
        sample_x = gen(noise).numpy()
        img_bhwc = np.transpose(sample_x[:100], (0, 2, 3, 1))
        img_tile_o = img_tile(img_bhwc, aspect_ratio=1.0, border_color=1.0, stretch=True)
        plot_img(img_tile_o, title='CIFAR10 samples')
        # Render this epoch's figure now, instead of keeping up to 1200
        # figures open until the cell finishes.
        plt.show()


In [None]:
    # Code to save params
    #np.savez('disc_params.npz', *[p.get_value() for p in disc_params])
    #np.savez('gen_params.npz', *[p.get_value() for p in gen_params])