In [0]:
from google.colab import drive, files
# Mount Google Drive at /content/drive; the data, sample and checkpoint paths used
# below all live under '/content/drive/My Drive/'. force_remount=True remounts even
# if Drive is already mounted in this runtime.
drive.mount('/content/drive', force_remount = True)

# UTILS

In [0]:
!pip install transformers

In [0]:
import torch.utils.data as data

from PIL import Image
import os
import os.path

# Lower-case filename suffixes accepted as images by make_dataset / ImageFolder.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def is_image_file(filename):
    """Tell whether a path looks like an image, judging only by its extension.

    The comparison is case-insensitive ("photo.JPG" counts).

    Args:
        filename (string): path to a file

    Returns:
        bool: True when the name ends with one of IMG_EXTENSIONS
    """
    return filename.lower().endswith(tuple(IMG_EXTENSIONS))


def find_classes(dir, classes_idx=None):
    """List the class sub-directories of `dir` and number them.

    Args:
        dir (string): dataset root; every immediate sub-directory is a class.
        classes_idx (tuple, optional): ``(start, end)`` slice applied to the
            alphabetically sorted class list. Must be exactly a ``tuple``.

    Returns:
        (list, dict): sorted class names, and a mapping name -> index where the
        indices are 0..len(classes)-1 in sorted order (after slicing).
    """
    classes = sorted(entry for entry in os.listdir(dir)
                     if os.path.isdir(os.path.join(dir, entry)))
    if classes_idx is not None:
        assert type(classes_idx) is tuple
        lo, hi = classes_idx
        classes = classes[lo:hi]
    class_to_idx = dict((name, position) for position, name in enumerate(classes))
    return classes, class_to_idx


def make_dataset(dir, class_to_idx):
    """Collect ``(image_path, class_index)`` pairs for every image under the class folders.

    Only top-level sub-directories of `dir` whose name is a key of `class_to_idx`
    are visited; each is walked recursively and every file accepted by
    is_image_file() is recorded. Directories, walk results and file names are
    visited in sorted order, so the output order is deterministic.

    Args:
        dir (string): dataset root (a leading ``~`` is expanded).
        class_to_idx (dict): class folder name -> integer label.

    Returns:
        list of (string, int) tuples.
    """
    base = os.path.expanduser(dir)
    samples = []
    for class_name in sorted(os.listdir(base)):
        class_dir = os.path.join(base, class_name)
        if class_name not in class_to_idx or not os.path.isdir(class_dir):
            continue
        label = class_to_idx[class_name]
        for root, _, fnames in sorted(os.walk(class_dir)):
            samples.extend((os.path.join(root, fname), label)
                           for fname in sorted(fnames) if is_image_file(fname))
    return samples


def pil_loader(path):
    """Load an image file with PIL and return it as a 3-channel RGB image."""
    # Handing PIL an already-open file object (instead of the path) avoids a
    # ResourceWarning: https://github.com/python-pillow/Pillow/issues/835
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        # convert() returns a new, fully loaded image, so it stays valid after
        # both the file and the source image are closed.
        rgb = picture.convert('RGB')
    return rgb


def accimage_loader(path):
    """Load an image with the accimage backend, falling back to PIL on IOError."""
    import accimage
    try:
        decoded = accimage.Image(path)
    except IOError:
        # accimage could not decode this file; retry through the PIL loader.
        decoded = pil_loader(path)
    return decoded


def default_loader(path):
    """Load `path` with whichever image backend torchvision is configured to use."""
    from torchvision import get_image_backend
    loader = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return loader(path)


class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path. A leading ``~`` is expanded.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        classes_idx (tuple, optional): ``(start, end)`` slice applied to the sorted
            list of class sub-directories, so only those classes are used (their
            labels are renumbered from 0). ``None`` keeps every class.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples

    Raises:
        RuntimeError: if no image file is found under the selected class folders.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, classes_idx=None):
        # Expand "~" once, up front: find_classes() lists `root` exactly as given
        # (make_dataset() expands it itself), so a "~/..." root used to fail there.
        root = os.path.expanduser(root)
        self.classes_idx = classes_idx
        classes, class_to_idx = find_classes(root, self.classes_idx)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Number of (image, label) samples found under `root`."""
        return len(self.imgs)


In [0]:
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf
import glob
import scipy.misc
import math
import sys

import glob
# Directory of PNG samples to score. Directories scored in earlier runs:
#filelist = glob.glob('/content/drive/My Drive/cs236/fakes/10platent/10pfakes/*.png')
#filelist = glob.glob('/content/drive/My Drive/cs236/fakes/fakesis/10p400/*.png')
#filelist = glob.glob('/content/drive/My Drive/cs236/fakes/cifarfake-c/*.png')
#filelist = glob.glob('/content/drive/My Drive/cs236/fakes/acwgan-gp2/*.png')
# glob returns files in arbitrary (filesystem) order; sort so the assignment of
# images to splits inside get_inception_score is reproducible between runs.
filelist = sorted(glob.glob('/content/drive/My Drive/cs236/reals/cifarreal2/*.png'))
print(type(filelist))
print(len(filelist))
if not filelist:
    raise RuntimeError('No .png files matched; check the directory passed to glob above.')


def _load_rgb(fname):
    """Read one image file as an (H, W, 3) uint8 array, closing the file handle."""
    # convert('RGB') forces three channels: PNGs may be stored as RGBA, palette
    # or greyscale, while get_inception_score only checks rank 3, not the
    # channel count (the Inception graph presumably expects RGB).
    with Image.open(fname) as im:
        return np.array(im.convert('RGB'))


imgs = [_load_rgb(fname) for fname in filelist]
print(type(imgs))
print(imgs[0].shape)

# Where _init_inception() downloads the Inception tarball and extracts the graph.
MODEL_DIR = '/tmp/imagenet'
# Inception (2015-12-05) frozen-graph tarball fetched when it is not in MODEL_DIR yet.
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# Class-probability tensor; set by _init_inception() and fed by get_inception_score().
softmax = None

def get_inception_score(images, splits=10):
  """Inception Score of a list of images.

  Each element of `images` must be a rank-3 numpy array (presumably H x W x 3
  RGB) with pixel values in the 0..255 range. The global `softmax` tensor must
  already be built by _init_inception().

  Images are run through the network one at a time (a '.' is printed per run).
  The predictions are then cut into `splits` contiguous chunks. For each chunk
  the score is exp(mean over images of KL(p(y|x) || p(y))), where p(y) is the
  chunk's mean prediction.

  Returns:
    (mean, std) of the per-split scores.
  """
  assert(type(images) == list)
  assert(type(images[0]) == np.ndarray)
  assert(len(images[0].shape) == 3)
  assert(np.max(images[0]) > 10)
  assert(np.min(images[0]) >= 0.0)
  # Give every image a leading batch axis and a float32 dtype for the graph input.
  batches = [np.expand_dims(im.astype(np.float32), 0) for im in images]
  per_run = 1  # images fed per sess.run call
  with tf.Session() as sess:
    n_runs = int(math.ceil(float(len(batches)) / float(per_run)))
    outputs = []
    for run in range(n_runs):
      sys.stdout.write(".")
      sys.stdout.flush()
      lo = run * per_run
      hi = min(lo + per_run, len(batches))
      feed = np.concatenate(batches[lo:hi], 0)
      outputs.append(sess.run(softmax, {'ExpandDims:0': feed}))
    probs = np.concatenate(outputs, 0)
    total = probs.shape[0]
    split_scores = []
    for k in range(splits):
      chunk = probs[(k * total // splits):((k + 1) * total // splits), :]
      marginal = np.expand_dims(np.mean(chunk, 0), 0)
      kl = chunk * (np.log(chunk) - np.log(marginal))
      split_scores.append(np.exp(np.mean(np.sum(kl, 1))))
    return np.mean(split_scores), np.std(split_scores)

# Builds the global `softmax` tensor. Not called on import: the cell invokes it
# via the `if softmax is None:` guard below. Uses the TF 1.x graph/session API
# (tf.gfile, tf.GraphDef, tf.Session).
def _init_inception():
  global softmax
  if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(MODEL_DIR, filename)
  # Download the tarball only when it is not cached in MODEL_DIR yet.
  if not os.path.exists(filepath):
    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
          filename, float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()
    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
  # Extracted on every call, even when the tarball was already present.
  # NOTE(review): extractall trusts the member paths inside the archive.
  tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
  # Import the frozen Inception GraphDef into the default graph (no name prefix).
  with tf.gfile.FastGFile(os.path.join(
      MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')
  # Works with an arbitrary minibatch size.
  # Intent: every op output whose first dimension is 1 gets that dimension set
  # to None, so batches other than 1 could be fed.
  # NOTE(review): Tensor.set_shape merges with the already-known static shape,
  # so this may leave a known 1 in place; get_inception_score feeds one image
  # per run, so this is not exercised — confirm before feeding larger batches.
  with tf.Session() as sess:
    pool3 = sess.graph.get_tensor_by_name('pool_3:0')
    ops = pool3.graph.get_operations()
    for op_idx, op in enumerate(ops):
        for o in op.outputs:
            shape = o.get_shape()
            shape = [s.value for s in shape]
            new_shape = []
            for j, s in enumerate(shape):
                if s == 1 and j == 0:
                    new_shape.append(None)
                else:
                    new_shape.append(s)
            o.set_shape(tf.TensorShape(new_shape))
    # Recompute the class logits from pool_3 using the graph's own logits
    # weight matrix, then turn them into probabilities p(y|x).
    w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
    logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
    softmax = tf.nn.softmax(logits)

# Build the Inception graph once per kernel; afterwards `softmax` stays set.
if softmax is None:
  _init_inception()

# Score the images loaded above. As the last expression of the cell, the
# returned (mean, std) tuple is displayed as the cell output.
get_inception_score(imgs, splits=10)

<type 'list'>
10000


In [0]:
from transformers import *
import pdb
from torch.nn.utils import weight_norm as wn
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch


def bert_encoder():
    """Factory: build a BERTEncoder (loads the pretrained bert-base-uncased weights)."""
    text_encoder = BERTEncoder()
    return text_encoder


def class_embedding(n_classes, embedding_dim):
    """Learnable lookup table holding one `embedding_dim`-vector per class id."""
    table = nn.Embedding(num_embeddings=n_classes, embedding_dim=embedding_dim)
    return table


def unconditional(n_classes, embedding_dim):
    """Return an ``nn.Embedding(n_classes, embedding_dim)`` lookup table.

    NOTE(review): this is currently identical to class_embedding(); the name
    suggests a label-independent embedding was intended — confirm.
    """
    table = nn.Embedding(num_embeddings=n_classes, embedding_dim=embedding_dim)
    return table


class Embedder(nn.Module):
    """Abstract base for modules mapping (class_labels, captions) to fixed-size vectors.

    Attributes:
        embed_size (int): dimensionality of the vectors produced by forward().
    """

    def __init__(self, embed_size):
        super(Embedder, self).__init__()
        self.embed_size = embed_size

    def forward(self, class_labels, captions):
        """Subclasses must implement the actual embedding."""
        raise NotImplementedError


class BERTEncoder(Embedder):
    '''
    Pretrained BERT (bert-base-uncased) used to embed text to a 768 dimensional vector.

    Each caption is tokenised without [CLS]/[SEP], truncated to ``max_len``
    tokens and right-padded with id 0. The embedding is the mean of BERT's last
    hidden states over the caption's real (non-padding) tokens.
    '''

    def __init__(self):
        super(BERTEncoder, self).__init__(embed_size=768)
        self.pretrained_weights = 'bert-base-uncased'
        self.tokenizer = BertTokenizer.from_pretrained(self.pretrained_weights)
        self.model = BertModel.from_pretrained(self.pretrained_weights)
        self.max_len = 50

    def tokenize(self, text_batch):
        '''
        :param list text_batch: list of strings
        :return: torch.LongTensor of shape (batch_size, longest_caption), padded with 0
        '''
        text_token_ids = [
            # dtype=long: an empty caption would otherwise become a float tensor
            # and break pad_sequence / BERT's embedding lookup.
            torch.tensor(self.tokenizer.encode(string_, add_special_tokens=False, max_length=self.max_len),
                         dtype=torch.long)
            for string_ in text_batch]
        padded_input = pad_sequence(text_token_ids, batch_first=True, padding_value=0)
        return padded_input

    def forward(self, class_labels, captions):
        '''
        :param class_labels : torch.LongTensor, class ids (unused by this encoder)
        :param list captions: list of strings, sentences to embed
        :return: torch.tensor embeddings: embeddings of shape (batch_size,embed_size=768)
        '''

        padded_input = self.tokenize(captions)
        device = next(self.parameters()).device
        padded_input = padded_input.to(device)
        # Padding (id 0, the [PAD] id of the bert-base-uncased vocabulary) must
        # neither be attended to nor included in the average. For batches with
        # no padding the mask is all ones, which matches BertModel's default.
        attention_mask = (padded_input != 0).long()
        hidden = self.model(padded_input, attention_mask=attention_mask)[0]
        weights = attention_mask.unsqueeze(-1).type_as(hidden)
        summed = (hidden * weights).sum(dim=1)
        # clamp avoids 0/0 for a caption that produced no tokens at all.
        counts = weights.sum(dim=1).clamp(min=1.0)
        return summed / counts


I1130 19:32:52.937437 140251756554112 file_utils.py:39] PyTorch version 1.3.1+cu100 available.
  "You should really be using Python3!!! "


In [0]:
import numpy as np

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import collections
import time
import cPickle as pickle

# Full history per (prefix-stripped) series label: {label: {iteration: value}}.
_since_beginning = collections.defaultdict(dict)
# Values recorded by plot() since the last flush(): {name: {iteration: value}}.
_since_last_flush = collections.defaultdict(dict)

# One-element list so tick()/plot() can share the counter without `global`.
_iter = [0]
def tick():
    """Advance the shared iteration counter that plot() uses as the x-coordinate."""
    _iter[0] = _iter[0] + 1

def plot(name, value):
    """Buffer `value` for series `name` at the current iteration until the next flush()."""
    step = _iter[0]
    _since_last_flush[name][step] = value
# Number of leading characters flush() strips from each series name to get its
# short label. Only the length matters, so the lowercase "drive" is harmless:
# it equals len(OUTPUT_PATH) for both '/content/drive/My Drive/conwgan-gp/' and
# '/content/drive/My Drive/conwgan-iw/' used by the training cells.
k = len("/content/drive/My drive/conwgan-gp/")
def flush():
    """Emit everything recorded by plot() since the previous flush.

    For each buffered series:
      * appends its mean value to a one-line summary printed at the end,
      * merges the values into the full history (keyed by the name with the
        first `k` characters stripped),
      * redraws the full history and saves it to ``name + '.jpg'``.
    Then clears the buffer and pickles the whole history to 'log.pkl' in the
    current working directory.
    """
    prints = []

    for name, vals in _since_last_flush.items():
        name2 = name[k:]
        # list(...) so np.mean gets a sequence on Python 3 too (a dict view is
        # treated as a single opaque object there).
        prints.append("{}\t{}".format(name2, np.mean(list(vals.values()))))
        _since_beginning[name2].update(vals)

        # sorted() works for both Python 2 key lists and Python 3 key views;
        # np.sort on a Python 3 dict view fails.
        x_vals = sorted(_since_beginning[name2].keys())
        y_vals = [_since_beginning[name2][x] for x in x_vals]

        plt.clf()
        plt.plot(x_vals, y_vals)
        plt.xlabel('iteration')
        plt.ylabel(name2)
        plt.savefig(name + '.jpg')

    # A single-argument print(...) is valid both as a Python 2 print statement
    # and under `from __future__ import print_function`, which an earlier cell
    # runs (IPython keeps __future__ flags for later cells, so the old
    # `print "..."` statement was a SyntaxError after that cell).
    print("iter {}\t{}".format(_iter[0], "\t".join(prints)))
    _since_last_flush.clear()

    with open('log.pkl', 'wb') as f:
        pickle.dump(dict(_since_beginning), f, pickle.HIGHEST_PROTOCOL)

# MODELS

In [0]:
from torch import nn
from torch.autograd import grad
import torch
#Models taken and improved from https://github.com/igul222/improved_wgan_training/blob/master/gan_64x64.py
# Default base channel width for the models below. GoodDiscriminator and
# Classifier also read the global DIM as the input image height/width.
# Both constants are overwritten by the training cell (DIM = 32, OUTPUT_DIM = 32*32*3).
DIM=64
# Flattened size of one generated image (3 channels x 64 x 64); read as a
# global by FCGenerator at construction time.
OUTPUT_DIM=64*64*3

class MyConvo2d(nn.Module):
    """nn.Conv2d wrapper with 'same' padding ((kernel_size - 1) // 2) for odd kernels.

    Args:
        input_dim, output_dim: input / output channel counts.
        kernel_size: square kernel size.
        he_init: stored flag; weights_init() uses it to pick Kaiming (True) or
            Xavier (False) uniform initialisation of ``self.conv.weight``.
        stride: convolution stride (default 1).
        bias: whether the convolution has a bias term.
    """
    def __init__(self, input_dim, output_dim, kernel_size, he_init = True,  stride = 1, bias = True):
        super(MyConvo2d, self).__init__()
        self.he_init = he_init
        self.padding = (kernel_size - 1) // 2
        # `stride` used to be accepted but ignored (always 1); it is now honoured.
        # Every existing caller uses the default, so their behaviour is unchanged.
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=stride, padding=self.padding, bias = bias)

    def forward(self, input):
        return self.conv(input)

class ConvMeanPool(nn.Module):
    """Same-padding convolution followed by a 2x2 mean pool (halves H and W)."""
    def __init__(self, input_dim, output_dim, kernel_size, he_init = True):
        super(ConvMeanPool, self).__init__()
        self.he_init = he_init
        self.conv = MyConvo2d(input_dim, output_dim, kernel_size, he_init = he_init)

    def forward(self, input):
        x = self.conv(input)
        # Averaging the four stride-2 phases is a non-overlapping 2x2 mean pool
        # (for even H and W).
        top_left = x[:, :, 0::2, 0::2]
        bottom_left = x[:, :, 1::2, 0::2]
        top_right = x[:, :, 0::2, 1::2]
        bottom_right = x[:, :, 1::2, 1::2]
        return (top_left + bottom_left + top_right + bottom_right) / 4

class MeanPoolConv(nn.Module):
    """2x2 mean pool (halves H and W) followed by a same-padding convolution."""
    def __init__(self, input_dim, output_dim, kernel_size, he_init = True):
        super(MeanPoolConv, self).__init__()
        self.he_init = he_init
        self.conv = MyConvo2d(input_dim, output_dim, kernel_size, he_init = he_init)

    def forward(self, input):
        # Average the four stride-2 phases: a non-overlapping 2x2 mean pool
        # (for even H and W).
        phases = [input[:, :, r::2, c::2] for r, c in ((0, 0), (1, 0), (0, 1), (1, 1))]
        pooled = (phases[0] + phases[1] + phases[2] + phases[3]) / 4
        return self.conv(pooled)

class DepthToSpace(nn.Module):
    """Move channel blocks into space: (B, C, H, W) -> (B, C/b^2, H*b, W*b), b = block_size.

    Follows TensorFlow's ``depth_to_space`` channel ordering (NHWC, "DCR"):
    output[:, d, h*b + i, w*b + j] = input[:, (i*b + j)*(C/b^2) + d, h, w].
    This differs from ``torch.nn.functional.pixel_shuffle``, which reads input
    channel d*b^2 + i*b + j.
    """
    def __init__(self, block_size):
        super(DepthToSpace, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size*block_size

    def forward(self, input):
        # Work in NHWC so the channel axis is last.
        output = input.permute(0, 2, 3, 1)
        (batch_size, input_height, input_width, input_depth) = output.size()
        output_depth = int(input_depth / self.block_size_sq)
        output_width = int(input_width * self.block_size)
        output_height = int(input_height * self.block_size)
        # Split channels into block_size^2 groups of output_depth; group k = i*b + j.
        t_1 = output.reshape(batch_size, input_height, input_width, self.block_size_sq, output_depth)
        # One chunk per block row i, each holding the b column offsets j.
        spl = t_1.split(self.block_size, 3)
        # Fold the column offset j into the width axis: column w*b + j.
        stacks = [t_t.reshape(batch_size,input_height,output_width,output_depth) for t_t in spl]
        # Fold the row offset i into the height axis: row h*b + i.
        output = torch.stack(stacks,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size,output_height,output_width,output_depth)
        # Back to NCHW.
        output = output.permute(0, 3, 1, 2)
        return output


class UpSampleConv(nn.Module):
    """2x nearest-neighbour upsampling followed by a same-padding convolution.

    The input is tiled four times along the channel axis and passed through
    DepthToSpace(2), which copies every pixel into all four cells of its
    2x2 output block.
    """
    def __init__(self, input_dim, output_dim, kernel_size, he_init = True, bias=True):
        super(UpSampleConv, self).__init__()
        self.he_init = he_init
        self.conv = MyConvo2d(input_dim, output_dim, kernel_size, he_init = self.he_init, bias=bias)
        self.depth_to_space = DepthToSpace(2)

    def forward(self, input):
        tiled = torch.cat([input] * 4, 1)
        upsampled = self.depth_to_space(tiled)
        return self.conv(upsampled)


class ResidualBlock(nn.Module):
    """Pre-activation residual block: [norm -> ReLU -> conv] x 2, plus a shortcut.

    Args:
        input_dim, output_dim: input / output channel counts.
        kernel_size: kernel of the two main-path convolutions (the shortcut is 1x1).
        resample: 'down' (mean-pool, halves H/W), 'up' (nearest upsample,
            doubles H/W) or None (same resolution).
        hw: spatial size of the input; only used, and then required, when
            resample == 'down' (to size the LayerNorm modules).

    Raises:
        Exception: for any other `resample` value.
    """
    def __init__(self, input_dim, output_dim, kernel_size, resample=None, hw=None):
        super(ResidualBlock, self).__init__()

        if resample not in ('down', 'up', None):
            raise Exception('invalid resample value')

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.kernel_size = kernel_size
        self.resample = resample
        self.bn1 = None
        self.bn2 = None
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        if resample == 'down':
            # Created but NOT applied in forward(): normalisation is skipped in
            # 'down' (critic) blocks. Kept so existing checkpoints keep the same
            # parameters.
            self.bn1 = nn.LayerNorm([input_dim, hw, hw])
            self.bn2 = nn.LayerNorm([input_dim, hw, hw])
        elif resample == 'up':
            self.bn1 = nn.BatchNorm2d(input_dim)
            # conv_1 (UpSampleConv) maps input_dim -> output_dim channels.
            self.bn2 = nn.BatchNorm2d(output_dim)
        else:
            self.bn1 = nn.BatchNorm2d(input_dim)
            # conv_1 keeps input_dim channels here, so bn2 must normalise
            # input_dim channels (BatchNorm2d(output_dim) crashed whenever
            # input_dim != output_dim; identical when they are equal).
            self.bn2 = nn.BatchNorm2d(input_dim)

        if resample == 'down':
            self.conv_shortcut = MeanPoolConv(input_dim, output_dim, kernel_size = 1, he_init = False)
            self.conv_1 = MyConvo2d(input_dim, input_dim, kernel_size = kernel_size, bias = False)
            self.conv_2 = ConvMeanPool(input_dim, output_dim, kernel_size = kernel_size)
        elif resample == 'up':
            self.conv_shortcut = UpSampleConv(input_dim, output_dim, kernel_size = 1, he_init = False)
            self.conv_1 = UpSampleConv(input_dim, output_dim, kernel_size = kernel_size, bias = False)
            self.conv_2 = MyConvo2d(output_dim, output_dim, kernel_size = kernel_size)
        else:
            self.conv_shortcut = MyConvo2d(input_dim, output_dim, kernel_size = 1, he_init = False)
            self.conv_1 = MyConvo2d(input_dim, input_dim, kernel_size = kernel_size, bias = False)
            self.conv_2 = MyConvo2d(input_dim, output_dim, kernel_size = kernel_size)

    def forward(self, input):
        # Identity shortcut only when neither channels nor resolution change.
        if self.input_dim == self.output_dim and self.resample is None:
            shortcut = input
        else:
            shortcut = self.conv_shortcut(input)

        output = input
        if self.resample != 'down':
            output = self.bn1(output)
        output = self.relu1(output)
        output = self.conv_1(output)
        if self.resample != 'down':
            output = self.bn2(output)
        output = self.relu2(output)
        output = self.conv_2(output)

        return shortcut + output

class ReLULayer(nn.Module):
    """Fully-connected layer followed by ReLU: relu(W x + b)."""
    def __init__(self, n_in, n_out):
        super(ReLULayer, self).__init__()
        self.n_in, self.n_out = n_in, n_out
        self.linear = nn.Linear(n_in, n_out)
        self.relu = nn.ReLU()

    def forward(self, input):
        return self.relu(self.linear(input))

class FCGenerator(nn.Module):
    """MLP generator: 128-d input -> four ReLU layers of width FC_DIM -> tanh output.

    The output width is the module-level OUTPUT_DIM read at construction time.
    """
    def __init__(self, FC_DIM=512):
        super(FCGenerator, self).__init__()
        self.relulayer1 = ReLULayer(128, FC_DIM)
        self.relulayer2 = ReLULayer(FC_DIM, FC_DIM)
        self.relulayer3 = ReLULayer(FC_DIM, FC_DIM)
        self.relulayer4 = ReLULayer(FC_DIM, FC_DIM)
        self.linear = nn.Linear(FC_DIM, OUTPUT_DIM)
        self.tanh = nn.Tanh()

    def forward(self, input):
        hidden = input
        for block in (self.relulayer1, self.relulayer2, self.relulayer3, self.relulayer4):
            hidden = block(hidden)
        return self.tanh(self.linear(hidden))

class GoodGenerator(nn.Module):
    """ResNet generator: 128-d noise -> 4x4 map -> three 2x 'up' blocks -> tanh RGB image.

    Three upsampling blocks take 4x4 to 32x32, so each sample is 3 x 32 x 32,
    returned flattened to shape (batch, 3*32*32) with values in [-1, 1].

    Args:
        dim: base channel width.
        output_dim: accepted for backward compatibility; not used to build the
            network (the output size is fixed by the architecture above).
    """
    def __init__(self, dim=DIM,output_dim=OUTPUT_DIM):
        super(GoodGenerator, self).__init__()

        self.dim = dim

        self.ln1 = nn.Linear(128, 4*4*8*self.dim)
        #self.rb1 = ResidualBlock(8*self.dim, 8*self.dim, 3, resample = None)
        self.rb2 = ResidualBlock(8*self.dim, 4*self.dim, 3, resample = 'up')
        self.rb3 = ResidualBlock(4*self.dim, 2*self.dim, 3, resample = 'up')
        self.rb4 = ResidualBlock(2*self.dim, 1*self.dim, 3, resample = 'up')
        self.bn  = nn.BatchNorm2d(self.dim)

        self.conv1 = MyConvo2d(1*self.dim, 3, 3)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, input):
        output = self.ln1(input.contiguous())
        output = output.view(-1, 8*self.dim, 4, 4)
        #output = self.rb1(output)
        output = self.rb2(output)
        output = self.rb3(output)
        output = self.rb4(output)

        output = self.bn(output)
        output = self.relu(output)
        output = self.conv1(output)
        output = self.tanh(output)
        # Flatten per sample. The old view(-1, OUTPUT_DIM) depended on the
        # global OUTPUT_DIM: when it was 64*64*3 (this cell's value) it silently
        # merged four 3x32x32 samples into one row.
        output = output.view(output.size(0), -1)
        return output

class GoodDiscriminator(nn.Module):
    """ResNet critic with two heads over a shared trunk.

    Returns ``(wgan_score, class_logits)``: a (batch,) WGAN critic score and
    (batch, num_class) auxiliary AC-GAN logits.

    Args:
        dim: base channel width.
        num_class: number of auxiliary classification logits.
        image_size: height/width of the square 3-channel input; must be divisible
            by 8 (three 'down' blocks). Defaults to the module-level DIM read at
            construction time, which is the previous behaviour (it used the
            global DIM as the image size, separately from `dim`).
    """
    def __init__(self, dim=DIM, num_class=2, image_size=None):
        super(GoodDiscriminator, self).__init__()

        self.dim = dim
        self.num_class = num_class
        self.image_size = DIM if image_size is None else image_size
        # Spatial size left after three halving 'down' blocks.
        feat_hw = self.image_size // 8
        self.conv1 = MyConvo2d(3, self.dim, 3, he_init = False)
        self.rb1 = ResidualBlock(self.dim, 2*self.dim, 3, resample = 'down', hw=self.image_size)
        self.rb2 = ResidualBlock(2*self.dim, 4*self.dim, 3, resample = 'down', hw=self.image_size // 2)
        self.rb3 = ResidualBlock(4*self.dim, 8*self.dim, 3, resample = 'down', hw=self.image_size // 4)
        #self.rb4 = ResidualBlock(8*self.dim, 8*self.dim, 3, resample = 'down', hw=int(DIM/8))
        self.ln1 = nn.Linear(feat_hw*feat_hw*8*self.dim, 1)

        self.ln2 = nn.Linear(feat_hw*feat_hw*8*self.dim, self.num_class)

    def forward(self, input):
        # Modules pickled before `image_size` existed fall back to the global
        # DIM, exactly as the old forward did.
        size = getattr(self, 'image_size', DIM)
        output = input.contiguous()
        output = output.view(-1, 3, size, size)
        output = self.conv1(output)
        output = self.rb1(output)
        output = self.rb2(output)
        output = self.rb3(output)
        #output = self.rb4(output)
        output = output.view(output.size(0), -1)
        output_wgan = self.ln1(output)
        output_wgan = output_wgan.view(-1)
        output_congan = self.ln2(output)
        return output_wgan, output_congan

  class Classifier(nn.Module):
    def __init__(self, dim=DIM):
        super(Classifier, self).__init__()

        self.dim = dim
        self.conv1 = MyConvo2d(3, self.dim, 3, he_init = False)
        self.rb1 = ResidualBlock(self.dim, 2*self.dim, 3, resample = 'down', hw=DIM)
        self.rb2 = ResidualBlock(2*self.dim, 4*self.dim, 3, resample = 'down', hw=int(DIM/2))
        self.rb3 = ResidualBlock(4*self.dim, 8*self.dim, 3, resample = 'down', hw=int(DIM/4))
        #self.rb4 = ResidualBlock(8*self.dim, 8*self.dim, 3, resample = 'down', hw=int(DIM/8))
        self.ln1 = nn.Linear(4*4*8*self.dim + 10, 1)

        self.sigmoid  = nn.Sigmoid()

    def forward(self, input, label):
        output = input.contiguous()
        output = output.view(-1, 3, DIM, DIM)
        output = self.conv1(output)
        output = self.rb1(output)
        output = self.rb2(output)
        output = self.rb3(output)
        #output = self.rb4(output)
        output = output.view(-1, 4*4*8*self.dim)
        output = torch.cat((output, label), 1)
        output = self.ln1(output)
        
        realfake = self.sigmoid(output)
        return realfake

# TRAIN ACGAN/ACWGAN

In [0]:
import os, sys
sys.path.append(os.getcwd())

#Models taken and improved from https://github.com/igul222/improved_wgan_training/blob/master/gan_64x64.py
import time
import functools
import argparse

import numpy as np
#import sklearn.datasets

#import libs as lib
#import libs.plot
#from tensorboardX import SummaryWriter

import pdb
#import gpustat

#from models.conwgan import *

import torch
import torchvision
from torch import nn
from torch import autograd
from torch import optim
from torchvision import transforms, datasets
from torch.autograd import grad
from timeit import default_timer as timer

import torch.nn.init as init

# Root passed to torchvision's CIFAR10 dataset (downloaded there if missing).
DATA_DIR = '/content/drive/My Drive/'

# CIFAR-10 class names, indexed by label id; gen_rand_noise_with_label embeds them with BERT.
cifar_text_labels = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
encoder = BERTEncoder()


NUM_CLASSES = 10

# When True, resume from the pickled generator/discriminator saved at START_ITER.
load_model = False
START_ITER = 0 
# Samples, checkpoints and plot images are written under this directory.
OUTPUT_PATH = '/content/drive/My Drive/conwgan-gp/'


DIM = 32 # Model width; GoodDiscriminator also reads the global DIM as the image height/width
CRITIC_ITERS = 5 # How many critic updates per training iteration
GENER_ITERS = 1 # How many generator updates per training iteration
N_GPUS = 1 # Number of GPUs (not referenced in the code visible here)
BATCH_SIZE = 100# Batch size (must be a multiple of N_GPUS); also the number of images in each saved sample grid
END_ITER = 100000 # How many iterations to train for
LAMBDA = 10 # Gradient penalty lambda hyperparameter
OUTPUT_DIM = 32*32*3 # Number of values in each image (3 channels x 32 x 32)
ACGAN_SCALE = 1. # How to scale the critic's ACGAN loss relative to WGAN loss
ACGAN_SCALE_G = 1. # How to scale generator's ACGAN loss relative to WGAN loss



def weights_init(m):
    """Initialise one module; intended for ``model.apply(weights_init)``.

    * MyConvo2d: conv weight Kaiming-uniform when ``m.he_init`` else
      Xavier-uniform; conv bias set to 0.
    * nn.Linear: weight Xavier-uniform, bias set to 0.
    Any other module type is left untouched.
    """
    if isinstance(m, MyConvo2d):
        conv = m.conv
        if conv.weight is not None:
            initializer = init.kaiming_uniform_ if m.he_init else init.xavier_uniform_
            initializer(conv.weight)
        if conv.bias is not None:
            init.constant_(conv.bias, 0.0)
    if isinstance(m, nn.Linear):
        if m.weight is not None:
            init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)

def load_data(path_to_folder, classes):
    """Build a shuffled CIFAR-10 training DataLoader with pixels scaled to [-1, 1].

    Args:
        path_to_folder: directory where torchvision stores / downloads CIFAR-10.
        classes: unused; kept so existing calls keep working.

    Returns:
        torch.utils.data.DataLoader yielding (images, labels) batches of
        BATCH_SIZE; the final incomplete batch is dropped.
    """
    dataset = datasets.CIFAR10(
          root=path_to_folder, download=True,
          transform=transforms.Compose([
              # Resize is the replacement for the deprecated transforms.Scale
              # alias, which recent torchvision releases no longer ship.
              transforms.Resize(32),
              transforms.ToTensor(),
              # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
          ]))
    dataset_loader = torch.utils.data.DataLoader(dataset,batch_size=BATCH_SIZE, shuffle=True, num_workers=5, drop_last=True, pin_memory=True)
    return dataset_loader

def calc_gradient_penalty(netD, real_data, fake_data, lambda_gp=None):
    """WGAN-GP gradient penalty on random interpolates of real and fake samples.

    Args:
        netD: critic; ``netD(x)`` must return ``(scores, aux)`` where `scores`
            holds one value per sample.
        real_data: real batch of shape (B, C, H, W).
        fake_data: generated batch with the same number of elements (e.g. the
            flat (B, C*H*W) generator output); reshaped to ``real_data``'s shape.
        lambda_gp: penalty weight; defaults to the global LAMBDA.

    Returns:
        Scalar tensor ``lambda_gp * mean_b((||grad_x D(x_b)||_2 - 1)^2)``, built
        with create_graph=True so it can be back-propagated into the critic.
    """
    if lambda_gp is None:
        lambda_gp = LAMBDA
    batch_size = real_data.size(0)
    # One mixing coefficient per sample, broadcast over C/H/W. Drawn from the
    # CPU generator (as before) and then moved to the data's device. Shape and
    # device now come from real_data instead of the hard-coded
    # BATCH_SIZE x 3 x 32 x 32 and the global `device`.
    alpha = torch.rand(batch_size, 1, 1, 1).to(real_data.device)
    fake_data = fake_data.view_as(real_data)
    interpolates = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
    interpolates.requires_grad_(True)

    disc_interpolates, _ = netD(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp

def generate_image(netG, noise=None):
    """Run the generator on `noise` and return images rescaled to [0, 1].

    Args:
        netG: generator returning flattened 3x32x32 images in [-1, 1]
            (GoodGenerator ends with tanh).
        noise: (N, 128) input. When None, a class-sorted batch is built:
            BATCH_SIZE // NUM_CLASSES consecutive samples of each class.

    Returns:
        Tensor of shape (N, 3, 32, 32) without autograd history.
    """
    if noise is None:
        rand_label = np.repeat(np.arange(NUM_CLASSES), BATCH_SIZE // NUM_CLASSES)
        noise = gen_rand_noise_with_label(rand_label)
    # The generator call itself must be inside no_grad (previously only an
    # assignment was), otherwise sampling built and kept an autograd graph.
    with torch.no_grad():
        samples = netG(noise)
    samples = samples.view(-1, 3, 32, 32)

    # Map [-1, 1] to [0, 1] for torchvision.utils.save_image.
    samples = samples * 0.5 + 0.5

    return samples

def gen_rand_noise_with_label(label=None):
    """Build a (BATCH_SIZE, 128) float noise batch conditioned on class labels.

    Standard-normal noise whose first NUM_CLASSES entries are overwritten by the
    first NUM_CLASSES dimensions of the BERT embedding of each label's class name.

    Args:
        label: sequence of BATCH_SIZE integer class ids; random when None.

    Returns:
        torch.FloatTensor on `device`.
    """
    if label is None:
        label = np.random.randint(0, NUM_CLASSES, BATCH_SIZE)
    noise = np.random.normal(0, 1, (BATCH_SIZE, 128))

    captions = [cifar_text_labels[per_label] for per_label in label]
    # The text embedding is only a conditioning input; nothing back-propagates
    # into BERT, so skip building an autograd graph for it.
    with torch.no_grad():
        embedding = encoder(label, captions)
    # .cpu() so this still works if the encoder has been moved to a GPU.
    embedding = embedding.cpu().numpy()

    noise[:, :NUM_CLASSES] = embedding[:, :NUM_CLASSES]

    noise = torch.from_numpy(noise).float()
    noise = noise.to(device)

    return noise


cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
# Labels 0,1,...,NUM_CLASSES-1,0,1,... paired with fixed_noise for the periodic sample grids.
fixed_label = [c % NUM_CLASSES for c in range(BATCH_SIZE)]
fixed_noise = gen_rand_noise_with_label(fixed_label)

if load_model:
    # Whole pickled modules are loaded; map_location lets a checkpoint saved on
    # a GPU runtime be resumed on a CPU-only one (and vice-versa).
    aG = torch.load(OUTPUT_PATH + "generator" + str(START_ITER) + ".pt", map_location=device)
    aD = torch.load(OUTPUT_PATH + "discriminator" + str(START_ITER) + ".pt", map_location=device)
else:
    aG = GoodGenerator(DIM,OUTPUT_DIM)
    aD = GoodDiscriminator(DIM, NUM_CLASSES)

    aG.apply(weights_init)
    aD.apply(weights_init)

# Move the models before building the optimizers so the optimizers are
# created over the parameters in their final location.
aG = aG.to(device)
aD = aD.to(device)

LR = 1e-4
optimizer_g = torch.optim.Adam(aG.parameters(), lr=LR, betas=(0,0.9))
optimizer_d = torch.optim.Adam(aD.parameters(), lr=LR, betas=(0,0.9))

aux_criterion = nn.CrossEntropyLoss() # nn.NLLLoss()

# Not used by the training loop below; kept in case other cells refer to them.
one = torch.FloatTensor([1])
mone = one * -1
one = one.to(device)
mone = mone.to(device)

#Reference: https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py

dataloader = load_data(DATA_DIR, cifar_text_labels)
dataiter = iter(dataloader)
for iteration in range(START_ITER+1, END_ITER):
    start_time = time.time()
    #---------------------TRAIN C------------------------
    # Disabled classifier (aC) update, kept for reference. NOTE: if re-enabled,
    # its `k = ...` assignment overwrites the global `k` that flush() uses to
    # strip the output-path prefix from plot names.
    '''
    for p in aC.parameters():
        p.requires_grad_(True)
    for l in range(CITER):
      f_label = np.random.randint(0, NUM_CLASSES, BATCH_SIZE)
      noise = gen_rand_noise_with_label(f_label)
      with torch.no_grad():
          noisev = noise  # totally freeze G, training D
      fake_data = aG(noisev).detach()
      batch = next(dataiter, None)
      if batch is None:
          dataiter = iter(dataloader)
          batch = dataiter.next()
      real_data = batch[0] #batch[1] contains labels
      real_data.requires_grad_(True)

      real_data = real_data.to(device)
      p_fake = aC(fake_data)
      c_err_fake = c_criterion(p_fake, zone)
      p_real = aC(real_data)
      c_err_real = c_criterion(p_real, one)
      k = c_err_real + c_err_fake
      (c_err_real + c_err_fake).backward()
      optimizer_c.step()
    for p in aC.parameters():
        p.requires_grad_(False)
    '''
    #---------------------TRAIN G------------------------
    # Freeze the critic so the generator update does not accumulate gradients in aD.
    for p in aD.parameters():
        p.requires_grad_(False)

    gen_cost = None
    for i in range(GENER_ITERS):
        aG.zero_grad()
        f_label = np.random.randint(0, NUM_CLASSES, BATCH_SIZE)
        noise = gen_rand_noise_with_label(f_label)
        fake_data = aG(noise)
        gen_cost, gen_aux_output = aD(fake_data)
        '''
        probs = aC(fake_data)
        #print(torch.max(probs))
        sniw = 1
        if iteration > 2005:
          iw = probs / (1 - probs)
          iw = iw.clamp(0.01, 0.99)
          sum_iw = torch.sum(iw)
          sniw = iw / sum_iw
        '''

        aux_label = torch.from_numpy(f_label).long()
        aux_label = aux_label.to(device)
        aux_errG = aux_criterion(gen_aux_output, aux_label).mean()
        gen_cost = -gen_cost.mean()
        g_cost = ACGAN_SCALE_G*aux_errG + gen_cost
        g_cost.backward()
        # Step once per generator iteration: with the step outside the loop,
        # the zero_grad() at the top of each iteration discarded every
        # iteration's gradients except the last when GENER_ITERS > 1.
        optimizer_g.step()
    #---------------------TRAIN D------------------------
    for p in aD.parameters():
        p.requires_grad_(True)
    for i in range(CRITIC_ITERS):
        aD.zero_grad()

        f_label = np.random.randint(0, NUM_CLASSES, BATCH_SIZE)
        noise = gen_rand_noise_with_label(f_label)
        # The generator is fixed during the critic update: no graph through aG.
        with torch.no_grad():
            fake_data = aG(noise)
        batch = next(dataiter, None)
        if batch is None:
            # Epoch exhausted: start a new (reshuffled) pass over the dataset.
            dataiter = iter(dataloader)
            batch = next(dataiter)
        real_data = batch[0].to(device)
        real_label = batch[1].to(device)

        # train with real data
        disc_real, aux_output = aD(real_data)
        aux_errD_real = aux_criterion(aux_output, real_label)
        errD_real = aux_errD_real.mean()
        disc_real = disc_real.mean()

        # train with fake data
        disc_fake, aux_output = aD(fake_data)
        disc_fake = disc_fake.mean()

        gradient_penalty = calc_gradient_penalty(aD, real_data, fake_data)
        #aux_errD_fake = aux_criterion(aux_output, f_label)
        #errD_fake = aux_errD_fake.mean()
        disc_cost = disc_fake - disc_real + gradient_penalty
        disc_acgan = errD_real #+ errD_fake
        (disc_cost + ACGAN_SCALE*disc_acgan).backward()
        w_dist = disc_fake  - disc_real
        optimizer_d.step()
        #for p in aD.parameters():
        #    p.data.clamp_(-0.01, 0.01)
        #------------------VISUALIZATION----------
        if i == CRITIC_ITERS-1:
            plot(OUTPUT_PATH + 'Disc Cost', disc_cost.cpu().data.numpy())
            plot(OUTPUT_PATH + 'AC Disc Cost', disc_acgan.cpu().data.numpy())

            plot(OUTPUT_PATH + 'time', time.time() - start_time)
            plot(OUTPUT_PATH + 'Gen Cost', gen_cost.cpu().data.numpy())
            plot(OUTPUT_PATH + 'AC Gen Cost', aux_errG.cpu().data.numpy())
            plot(OUTPUT_PATH + 'wasserstein distance', w_dist.cpu().data.numpy())
    if iteration % 100==99:
        gen_images = generate_image(aG, fixed_noise)
        torchvision.utils.save_image(gen_images, OUTPUT_PATH + 'samples_{}.png'.format(iteration), nrow=10, padding=2)
        #----------------------Save model----------------------
        torch.save(aG, OUTPUT_PATH + "generator" + str(iteration) + ".pt")
        torch.save(aD, OUTPUT_PATH + "discriminator" + str(iteration) + ".pt")
    if (iteration < 50) or (iteration % 100 == 99):
        flush()
    tick()



# TRAIN CLASSIFIER

In [0]:
import os, sys
sys.path.append(os.getcwd())


import time
import functools
import argparse

import numpy as np
#import sklearn.datasets

#import libs as lib
#import libs.plot
#from tensorboardX import SummaryWriter

import pdb
#import gpustat

#from models.conwgan import *

import torch
import torchvision
from torch import nn
from torch import autograd
from torch import optim
from torchvision import transforms, datasets
from torch.autograd import grad
from timeit import default_timer as timer

import torch.nn.init as init

DATA_DIR = '/content/drive/My Drive/'
# Caption text for each CIFAR-10 class index (index i -> caption i).
cifar_text_labels = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# Text encoder applied to (labels, captions) when building conditional noise.
encoder = BERTEncoder()

NUM_CLASSES = 10

# Checkpoint location and which generator iteration to reload.
START_ITER = 9699  # starting iteration
OUTPUT_PATH = '/content/drive/My Drive/conwgan-iw/'

DIM = 32  # Model dimensionality

def weights_init(m):
    """Initialise one module's parameters in place (use via ``Module.apply``).

    * ``MyConvo2d``: the wrapped ``m.conv`` weight gets Kaiming-uniform init when
      ``m.he_init`` is truthy, otherwise Xavier-uniform; its bias is zeroed.
    * ``nn.Linear``: Xavier-uniform weight, zero bias.
    Any other module type is left untouched.
    """
    if isinstance(m, MyConvo2d):
        if m.conv.weight is not None:
            conv_init = init.kaiming_uniform_ if m.he_init else init.xavier_uniform_
            conv_init(m.conv.weight)
        if m.conv.bias is not None:
            init.constant_(m.conv.bias, 0.0)
    if isinstance(m, nn.Linear):
        weight, bias = m.weight, m.bias
        if weight is not None:
            init.xavier_uniform_(weight)
        if bias is not None:
            init.constant_(bias, 0.0)

def load_data(path_to_folder, classes, batch_size=None):
    """Build a shuffled DataLoader over the CIFAR-10 training split.

    Images are converted to tensors and normalised per channel with mean 0.5 and
    std 0.5, which maps pixel values from [0, 1] to [-1, 1].

    Args:
        path_to_folder: root directory for the CIFAR-10 files (downloaded there
            if missing).
        classes: unused; kept so existing call sites keep working.
        batch_size: rows per batch. Defaults to the notebook-level BATCH_SIZE.

    Returns:
        torch.utils.data.DataLoader with drop_last=True, so every batch has
        exactly ``batch_size`` rows.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    dataset = datasets.CIFAR10(
        root=path_to_folder, download=True,
        transform=transforms.Compose([
            # transforms.Scale was removed from torchvision; Resize is its
            # replacement (CIFAR-10 images are already 32x32).
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=5,
        drop_last=True, pin_memory=True)



def generate_image(netG, noise=None):
    """Sample images from the generator and rescale them for saving/display.

    Args:
        netG: generator module, called as ``netG(noise)``.
        noise: latent batch. When None, labels 0..9 repeated 10 times each
            (100 labels) are turned into noise via gen_rand_noise_with_label.

    Returns:
        Tensor of shape (N, 3, 32, 32), where N is inferred from the generator
        output. Values are mapped with ``x * 0.5 + 0.5``, i.e. [-1, 1] -> [0, 1].
        The tensor does not require grad.
    """
    if noise is None:
        rand_label = np.repeat(np.arange(10), 100 // 10)
        noise = gen_rand_noise_with_label(rand_label)

    # Pure sampling: run the forward pass itself under no_grad. The original
    # only wrapped an assignment, so the generator still built a graph.
    with torch.no_grad():
        samples = netG(noise)

    # Infer the batch size instead of assuming the global BATCH_SIZE, so noise
    # batches of any size work.
    samples = samples.view(-1, 3, 32, 32)
    return samples * 0.5 + 0.5

def gen_rand_noise_with_label(label=None):
    """Build a batch of label-conditioned generator latents.

    Each row is 128-dim standard-normal noise whose first NUM_CLASSES entries
    are overwritten with the first NUM_CLASSES dimensions of
    ``encoder(label, captions)``. The caption for a row is that row's
    ``cifar_text_labels`` entry.

    Args:
        label: sequence of integer class indices, one per output row. When
            None, BATCH_SIZE labels are drawn uniformly from [0, NUM_CLASSES).

    Returns:
        float32 tensor of shape (len(label), 128) on ``device``.
    """
    if label is None:
        label = np.random.randint(0, NUM_CLASSES, BATCH_SIZE)
    # Size the batch from the labels rather than the global BATCH_SIZE; the
    # original only worked when len(label) == BATCH_SIZE.
    noise = np.random.normal(0, 1, (len(label), 128))

    captions = [cifar_text_labels[per_label] for per_label in label]
    embedding = encoder(label, captions)
    # .cpu() is a no-op for CPU tensors and keeps .numpy() safe if the encoder
    # returns a GPU tensor.
    embedding = embedding.detach().cpu().numpy()

    noise[:, :NUM_CLASSES] = embedding[:, :NUM_CLASSES]

    noise = torch.from_numpy(noise).float()
    return noise.to(device)


cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")

# Pre-trained conditional generator (kept fixed here). map_location lets a
# checkpoint saved on another device load onto the current one.
aG = torch.load(OUTPUT_PATH + "generator" + str(START_ITER) + ".pt", map_location=device)
aC = Classifier(DIM)
aC.apply(weights_init)
LR = 1e-4
optimizer_c = torch.optim.Adam(aC.parameters(), lr=LR, betas=(0, 0.9))

# Real-vs-generated objective: target 1 for real images, 0 for generated ones.
c_criterion = nn.BCELoss()

# Targets sized to the batch instead of a hard-coded 100. The loader uses
# drop_last=True, so every real batch has BATCH_SIZE rows, like the fake batch.
one = torch.ones([BATCH_SIZE, 1], dtype=torch.float32)
zone = torch.zeros([BATCH_SIZE, 1], dtype=torch.float32)
mone = one * -1
aG = aG.to(device)
aC = aC.to(device)
one = one.to(device)
mone = mone.to(device)
zone = zone.to(device)

dataloader = load_data(DATA_DIR, TRAINING_CLASS)
dataiter = iter(dataloader)

CLASSIFIER_PATH = OUTPUT_PATH + "classifier10b.pt"
SAVE_EVERY = 100  # checkpoint periodically instead of every iteration

for p in aC.parameters():
    p.requires_grad_(True)

for iteration in range(1, 5000):
    # Generated batch with random class labels; no graph needed through G.
    f_label = np.random.randint(0, NUM_CLASSES, BATCH_SIZE)
    noise = gen_rand_noise_with_label(f_label)
    with torch.no_grad():
        fake_data = aG(noise)

    # Next real batch, restarting the loader once an epoch is exhausted.
    batch = next(dataiter, None)
    if batch is None:
        dataiter = iter(dataloader)
        batch = next(dataiter)  # DataLoader iterators no longer expose .next()
    real_data = batch[0].to(device)
    real_labels = batch[1].to(device)

    y = torch.as_tensor(f_label, dtype=torch.long, device=device)
    fake_onehot = torch.nn.functional.one_hot(y, num_classes=NUM_CLASSES).float()
    real_onehot = torch.nn.functional.one_hot(real_labels, num_classes=NUM_CLASSES).float()

    # Clear gradients from the previous step; without this, they accumulate
    # across iterations.
    optimizer_c.zero_grad()
    p_fake = aC(fake_data, fake_onehot)
    c_err_fake = c_criterion(p_fake, zone)
    p_real = aC(real_data, real_onehot)
    c_err_real = c_criterion(p_real, one)
    k = c_err_real + c_err_fake
    k.backward()
    optimizer_c.step()

    if iteration % SAVE_EVERY == 0:
        torch.save(aC, CLASSIFIER_PATH)

torch.save(aC, CLASSIFIER_PATH)
print("DONE")

In [0]:
import os, sys
sys.path.append(os.getcwd())


import time
import functools
import argparse

import numpy as np
#import sklearn.datasets

#import libs as lib
#import libs.plot
#from tensorboardX import SummaryWriter

import pdb
#import gpustat

#from models.conwgan import *

import torch
import torchvision
from torch import nn
from torch import autograd
from torch import optim
from torchvision import transforms, datasets
from torch.autograd import grad
from timeit import default_timer as timer

import torch.nn.init as init

DATA_DIR = '/content/drive/My Drive/'
# Checkpoint directory read below. It is defined here too (same value as the
# training cell) so this cell does not depend on another cell's state.
OUTPUT_PATH = '/content/drive/My Drive/conwgan-iw/'

# Caption text for each CIFAR-10 class index (index i -> caption i).
cifar_text_labels = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
# Text encoder applied to (labels, captions) when building conditional noise.
encoder = BERTEncoder()

NUM_CLASSES = 10

START_ITER = 9699  # starting iteration

N_GPUS = 1  # Number of GPUs
BATCH_SIZE = 1000  # Batch size. Must be a multiple of N_GPUS

def weights_init(m):
    """Initialise one module's parameters in place (use via ``Module.apply``).

    * ``MyConvo2d``: the wrapped ``m.conv`` weight gets Kaiming-uniform init when
      ``m.he_init`` is truthy, otherwise Xavier-uniform; its bias is zeroed.
    * ``nn.Linear``: Xavier-uniform weight, zero bias.
    Any other module type is left untouched.
    """
    if isinstance(m, MyConvo2d):
        if m.conv.weight is not None:
            conv_init = init.kaiming_uniform_ if m.he_init else init.xavier_uniform_
            conv_init(m.conv.weight)
        if m.conv.bias is not None:
            init.constant_(m.conv.bias, 0.0)
    if isinstance(m, nn.Linear):
        weight, bias = m.weight, m.bias
        if weight is not None:
            init.xavier_uniform_(weight)
        if bias is not None:
            init.constant_(bias, 0.0)

def load_data(path_to_folder, classes, batch_size=None):
    """Build a shuffled DataLoader over the CIFAR-10 training split.

    Images are converted to tensors and normalised per channel with mean 0.5 and
    std 0.5, which maps pixel values from [0, 1] to [-1, 1].

    Args:
        path_to_folder: root directory for the CIFAR-10 files (downloaded there
            if missing).
        classes: unused; kept so existing call sites keep working.
        batch_size: rows per batch. Defaults to the notebook-level BATCH_SIZE.

    Returns:
        torch.utils.data.DataLoader with drop_last=True, so every batch has
        exactly ``batch_size`` rows.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    dataset = datasets.CIFAR10(
        root=path_to_folder, download=True,
        transform=transforms.Compose([
            # transforms.Scale was removed from torchvision; Resize is its
            # replacement (CIFAR-10 images are already 32x32).
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=5,
        drop_last=True, pin_memory=True)



def generate_image(netG, noise=None):
    """Sample images from the generator and rescale them for saving/display.

    Args:
        netG: generator module, called as ``netG(noise)``.
        noise: latent batch. When None, labels 0..9 repeated 10 times each
            (100 labels) are turned into noise via gen_rand_noise_with_label.

    Returns:
        Tensor of shape (N, 3, 32, 32), where N is inferred from the generator
        output. Values are mapped with ``x * 0.5 + 0.5``, i.e. [-1, 1] -> [0, 1].
        The tensor does not require grad.
    """
    if noise is None:
        rand_label = np.repeat(np.arange(10), 100 // 10)
        noise = gen_rand_noise_with_label(rand_label)

    # Pure sampling: run the forward pass itself under no_grad. The original
    # only wrapped an assignment, so the generator still built a graph.
    with torch.no_grad():
        samples = netG(noise)

    # Infer the batch size instead of assuming the global BATCH_SIZE (1000 in
    # this cell), so smaller noise batches work.
    samples = samples.view(-1, 3, 32, 32)
    return samples * 0.5 + 0.5

# Encoder output for the ten CIFAR-10 class captions, truncated to its first
# 10 dimensions.
# NOTE(review): this module-level `embedding` is not read by the visible code
# below; gen_rand_noise_with_label computes its own embedding per call.
label = np.arange(10)
captions = [cifar_text_labels[idx] for idx in label]
embedding = encoder(label, captions).detach().numpy()[:, :10]

def gen_rand_noise_with_label(label=None):
    """Build a batch of label-conditioned generator latents.

    Each row is 128-dim standard-normal noise whose first NUM_CLASSES entries
    are overwritten with the first NUM_CLASSES dimensions of
    ``encoder(label, captions)``. The caption for a row is that row's
    ``cifar_text_labels`` entry.

    Args:
        label: sequence of integer class indices, one per output row. When
            None, BATCH_SIZE labels are drawn uniformly from [0, NUM_CLASSES).

    Returns:
        float32 tensor of shape (len(label), 128) on ``device``.
    """
    if label is None:
        label = np.random.randint(0, NUM_CLASSES, BATCH_SIZE)
    # Attach label information into the noise. Size the batch from the labels
    # rather than the global BATCH_SIZE; the original only worked when
    # len(label) == BATCH_SIZE.
    noise = np.random.normal(0, 1, (len(label), 128))

    captions = [cifar_text_labels[per_label] for per_label in label]
    embedding = encoder(label, captions)
    # .cpu() is a no-op for CPU tensors and keeps .numpy() safe if the encoder
    # returns a GPU tensor.
    embedding = embedding.detach().cpu().numpy()

    noise[:, :NUM_CLASSES] = embedding[:, :NUM_CLASSES]

    noise = torch.from_numpy(noise).float()
    return noise.to(device)


cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")

# map_location loads the checkpoints onto the current device even if they were
# saved on another.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True,
# which rejects whole pickled modules; pass weights_only=False there.
aG = torch.load(OUTPUT_PATH + "generator" + str(START_ITER) + ".pt", map_location=device)
aC = torch.load(OUTPUT_PATH + "classifier10b.pt", map_location=device)

# Evaluation only: no parameter updates, so build no autograd graphs at all.
# The loop reports, for generated (E_t) and real (E_r) batches, the mean
# classifier output and the mean of its log.
# NOTE(review): aC is left in whatever train/eval mode it was saved in.
with torch.no_grad():
    for iteration in range(1, 15):
        f_label = np.random.randint(0, NUM_CLASSES, BATCH_SIZE)
        noise = gen_rand_noise_with_label(f_label)
        fake_data = aG(noise)

        # Reuses `dataloader`/`dataiter` from the classifier-training cell,
        # restarting the iterator when it is exhausted.
        batch = next(dataiter, None)
        if batch is None:
            dataiter = iter(dataloader)
            batch = next(dataiter)  # DataLoader iterators no longer expose .next()
        real_data = batch[0].to(device)  # batch[1] contains labels
        real_labels = batch[1].to(device)

        y = torch.as_tensor(f_label, dtype=torch.long, device=device)
        o_h = torch.nn.functional.one_hot(y, num_classes=NUM_CLASSES).float()
        p_fake = aC(fake_data, o_h)
        print("E_t(w): ", p_fake.mean().item())
        print("E_t(log(w)): ", torch.log(p_fake).mean().item())

        o_h = torch.nn.functional.one_hot(real_labels, num_classes=NUM_CLASSES).float()
        p_real = aC(real_data, o_h)
        print("E_r(w): ", p_real.mean().item())
        print("E_r(log(w)): ", torch.log(p_real).mean().item())
print("DONE")

# SAMPLING

In [0]:

# Reload the generator checkpoint and the trained real-vs-generated classifier.
# map_location puts them on `device` even if they were saved on another device.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True,
# which rejects whole pickled modules; pass weights_only=False there.
aG = torch.load(OUTPUT_PATH + "generator" + str(START_ITER) + ".pt", map_location=device)
aC = torch.load(OUTPUT_PATH + "classifier10b.pt", map_location=device)
cifar_text_labels = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
# Coarser caption for each CIFAR-10 class index, in the same order as
# cifar_text_labels.
coarse_labels = ["plane", "car", "bird", "pet", "animal", "pet", "frog", "animal", "boat", "car"]
def gen_rand_noise_with_coarse_label(label=None):
    """Build label-conditioned generator latents using coarse class captions.

    Same layout as gen_rand_noise_with_label, except each row's caption is taken
    from ``coarse_labels``. Each row is 128-dim standard-normal noise whose
    first NUM_CLASSES entries are overwritten with the first NUM_CLASSES
    dimensions of ``encoder(label, captions)``.

    Args:
        label: sequence of integer class indices, one per output row. When
            None, BATCH_SIZE labels are drawn uniformly from [0, NUM_CLASSES).

    Returns:
        float32 tensor of shape (len(label), 128) on ``device``.
    """
    if label is None:
        label = np.random.randint(0, NUM_CLASSES, BATCH_SIZE)
    # Attach label information into the noise. Size the batch from the labels:
    # the sampling loop passes 10 labels, which the old BATCH_SIZE-sized noise
    # could not accept.
    noise = np.random.normal(0, 1, (len(label), 128))

    captions = [coarse_labels[per_label] for per_label in label]
    embedding = encoder(label, captions)
    # .cpu() is a no-op for CPU tensors and keeps .numpy() safe if the encoder
    # returns a GPU tensor.
    embedding = embedding.detach().cpu().numpy()

    noise[:, :NUM_CLASSES] = embedding[:, :NUM_CLASSES]

    noise = torch.from_numpy(noise).float()
    return noise.to(device)
# Importance-weighted sampling: for each class, generate a small batch and keep
# one image, drawn with probability proportional to w = p / (1 - p). Here p is
# the classifier's column-0 output.
SAMPLES_DIR = '/content/drive/My Drive/cs236/fakes/acwgan-gp-c-iw/'
SAMPLES_PER_CLASS = 10  # candidates generated per class per round
PROB_EPS = 1e-6         # keeps p / (1 - p) finite when p saturates at 0 or 1
os.makedirs(SAMPLES_DIR, exist_ok=True)

j = 10000  # total number of images to write
k = 0      # images written so far; also the output file index
print("GENNING")
while k < j:
    for i in range(NUM_CLASSES):
        f_label = np.repeat(i, SAMPLES_PER_CLASS)
        noise = gen_rand_noise_with_coarse_label(f_label)
        y = torch.as_tensor(f_label, dtype=torch.long, device=device)
        o_h = torch.nn.functional.one_hot(y, num_classes=NUM_CLASSES).float()

        with torch.no_grad():
            # Score the raw generator output, which is what the classifier saw
            # during training. generate_image() returns [0, 1]-rescaled images,
            # which would put the classifier off its training distribution.
            fake = aG(noise)
            probs = aC(fake, o_h)[:, 0]
        samples = fake.view(-1, 3, 32, 32) * 0.5 + 0.5  # [0, 1] for saving

        probs = probs.clamp(PROB_EPS, 1 - PROB_EPS)
        iw = probs / (1 - probs)
        # Normalise in float64: float32-normalised weights can sum to slightly
        # more than 1, which NumPy's samplers reject.
        weights = iw.double().cpu().numpy()
        weights /= weights.sum()
        best = np.random.choice(len(weights), p=weights)

        torchvision.utils.save_image(samples[best], SAMPLES_DIR + 'real_samples%d.png' % k)
        k += 1
    print(k)
print("DONE")