# In[1]:

import matplotlib.pyplot as plt
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import tensorflow_datasets as tfds
import torchvision
from tensorflow.keras.datasets import cifar10

from torch.autograd import Variable
from torch.utils.data.dataset import Dataset

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


# In[2]:


## All hyperparameters should go here. CGAN, CAEGAN, ICAEGAN
img_size = 32 # can use this to modify data size to fit this model
n_epochs = 10 #50? depends on max_samples
print_stride = 1 # report losses / plot samples every `print_stride` epochs
#n_samples = 8000 #80k, 10k is fine
bs = 16 # 64

z_dim = 100
# The models and the training loop read `v_dim` for the latent-noise size, but
# only `z_dim` was defined. NOTE(review): assumed to be the same quantity - confirm.
v_dim = z_dim
# Number of condition classes; CIFAR-10 has 10 labels.
c_dim = 10

learning_rate = 0.0002

# Target values for the discriminator's real / fake batches in the training
# loop. NOTE(review): these were referenced but never defined; 1.0 / 0.0 are
# the plain (unsmoothed) BCE targets, the loop comment mentions .9 / .1 as an
# alternative - confirm the intended values.
D_real_scale = 1.0
D_fake_scale = 0.0

# Experiment id interpolated into the saved generator checkpoint file name.
# NOTE(review): was referenced but never defined; placeholder value.
exp_num = 0



# In[3]:

# Preprocessing: convert to a tensor, resize to `img_size`, then shift/scale
# with mean 0.5 and std 0.5 (the single-element lists broadcast over channels).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(img_size),
     transforms.Normalize([0.5], [0.5])])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=bs,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=bs,
                                         shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Sanity check on ONE batch. The loader shuffles, so every `next(iter(...))`
# call yields a different batch; drawing once keeps the printed shapes, the
# displayed image and the printed label consistent with each other.
sample_images, sample_labels = next(iter(train_loader))
print(sample_images.shape)
print(sample_labels.shape)

first_samp = sample_images[0]
print(torch.min(first_samp), torch.max(first_samp))
# Undo the normalisation (x -> 0.5 * (x + 1)) and move channels last for imshow.
plt.imshow(0.5*(first_samp.permute(1,2,0) + 1))
plt.title(classes[sample_labels[0].item()])

print(sample_labels[0])



#y_train = torch.nn.functional.one_hot(torch.tensor(y_train).to(torch.int64))
#print(y_train.shape)
#print(y_train)

#y_train = torch.squeeze(y_train)
#print(y_train.shape)
#print(torch.sum(y_train, dim = 0))


# In[4]:


class Discriminator(nn.Module):
    """Conditional DCGAN-style discriminator.

    The condition vector is projected by a bias-free linear layer to one
    `img_size x img_size` plane, squashed with tanh, and concatenated to the
    image as an extra channel before a stack of strided convolutions.

    Args:
        n_classes: length of the condition vector. Defaults to the module-level
            `c_dim` when omitted.
        image_size: height/width of the (square) input image. Defaults to the
            module-level `img_size` when omitted.

    Note:
        `conv1` takes 4 input channels, i.e. a 3-channel image plus the single
        condition plane. The four stride-2 convolutions followed by the final
        2x2 convolution reduce a 32x32 input to a 1x1 score map.
    """

    def __init__(self, n_classes=None, image_size=None):
        super(Discriminator, self).__init__()
        # Resolve sizes once so forward() does not depend on notebook globals.
        self.c_dim = c_dim if n_classes is None else n_classes
        self.img_size = img_size if image_size is None else image_size
        # Condition Up-Embedder:
        self.fc1 = nn.Linear(self.c_dim, self.img_size**2, bias = False)
        # Discriminator:
        self.conv1 = nn.Conv2d(4, 128, 4, 2, 1) # input: (bs, 3 + 1, img_size, img_size)
        self.conv2 = nn.Conv2d(128, 256, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 512, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 1024, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(1024)
        self.conv5 = nn.Conv2d(1024, 1, 2, 1, 0)

    def weight_init(self):
        """Apply `normal_init` to every direct child module."""
        for m in self._modules:
            normal_init(self._modules[m])

    def forward(self, x, c):
        """Score a batch of images given their condition vectors.

        Args:
            x: image batch; concatenated with one condition channel and fed to
                a 4-channel convolution, so it must have 3 channels.
            c: condition batch holding `c_dim` values per sample (any shape
                that can be reshaped to `(batch, c_dim)`).

        Returns:
            Sigmoid output of the last convolution, one value in (0, 1) per
            spatial position of the final feature map.
        """
        # Take the batch size from the input itself; the original read a
        # global `mini_batch` that only existed once the training loop had run.
        batch = x.size(0)
        # Tanh: since x is in (-1, 1), the condition plane should be too.
        c = torch.tanh(self.fc1(c.reshape(batch, self.c_dim))).view(batch, 1, self.img_size, self.img_size)
        x = torch.cat((x, c), dim = 1)
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        x = torch.sigmoid(self.conv5(x))
        return x

class Generator(nn.Module):
    """Conditional DCGAN-style generator.

    The noise vector and the condition vector are each expanded by their own
    transposed convolution to a 1024-channel 4x4 map; the two maps are
    concatenated on the channel axis and upsampled by three stride-2
    transposed convolutions (4 -> 8 -> 16 -> 32) to a 3-channel image with
    tanh output.

    Args:
        latent_dim: number of channels of the noise input. Defaults to the
            module-level `v_dim` when omitted.
        n_classes: number of channels of the condition input. Defaults to the
            module-level `c_dim` when omitted.
    """

    def __init__(self, latent_dim=None, n_classes=None):
        super(Generator, self).__init__()
        # Resolve sizes once; fall back to the notebook-level hyperparameters.
        self.v_dim = v_dim if latent_dim is None else latent_dim
        self.c_dim = c_dim if n_classes is None else n_classes
        ## Decoding:
        self.deconv1v = nn.ConvTranspose2d(self.v_dim, 1024, 4, 1, 0, bias = False) # input: (bs, v_dim, 1, 1)
        self.deconv1c = nn.ConvTranspose2d(self.c_dim, 1024, 4, 1, 0, bias = False) # input: (bs, c_dim, 1, 1)

        # NOTE(review): this single BatchNorm is applied to BOTH the noise and
        # the condition branch in forward(), so they share affine parameters
        # and running statistics. Left as is to keep checkpoint keys stable -
        # confirm whether separate norms were intended.
        self.deconv1_bn = nn.BatchNorm2d(1024)
        self.deconv2 = nn.ConvTranspose2d(1024+1024, 512, 4, 2, 1, bias = False)
        self.deconv2_bn = nn.BatchNorm2d(512)
        self.deconv3 = nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False)
        self.deconv3_bn = nn.BatchNorm2d(256)
        self.deconv4 = nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False)
        self.deconv4_bn = nn.BatchNorm2d(128)
        self.deconv5 = nn.ConvTranspose2d(128, 3, 3, 1, 1)

    def weight_init(self):
        """Apply `normal_init` to every direct child module."""
        for m in self._modules:
            normal_init(self._modules[m])

    def forward(self, v, c):
        """Generate images from noise `v` and condition `c`.

        Args:
            v: noise batch of shape (batch, v_dim, 1, 1).
            c: condition batch of shape (batch, c_dim, 1, 1).

        Returns:
            Tensor of shape (batch, 3, 32, 32) with values in [-1, 1].
        """
        v = self.deconv1_bn(self.deconv1v(v))
        c = self.deconv1_bn(self.deconv1c(c))
        # Stack the two 1024-channel 4x4 maps on the channel axis -> 2048 channels.
        x = torch.cat((v, c), dim = 1)
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = F.relu(self.deconv4_bn(self.deconv4(x)))
        x = torch.tanh(self.deconv5(x))
        return x


# In[5]:


def normal_init(m):
    """Re-draw the weights of a (transposed) 2-D convolution from N(0, 0.02).

    Any other module type is left untouched; biases are not modified.
    """
    if not isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        return
    mean, std = 0.0, 0.02
    m.weight.data.normal_(mean, std)

def get_codes(size, hardware = device, hot = True):
    """Draw `size` random class labels in [0, c_dim).

    Args:
        size: number of labels to draw.
        hardware: device on which the labels are created.
        hot: when equal to True the labels are passed through
            `one_hot_embedding`; otherwise the raw `(size, 1)` integer tensor
            is returned.
    """
    sampled = torch.randint(c_dim, size = (size, 1), device = hardware)
    # The explicit `== True` comparison is kept on purpose so that the
    # accepted values of `hot` stay exactly as before.
    return one_hot_embedding(sampled) if hot == True else sampled

def one_hot_embedding(labels, num_classes = None):
    """One-hot encode integer class labels.

    Args:
        labels: tensor (or anything `torch.as_tensor` accepts) of class
            indices, e.g. shape `(N,)` or `(N, 1)`.
        num_classes: width of the encoding. Defaults to the module-level
            `c_dim` when omitted.

    Returns:
        Integer tensor with a trailing axis of length `num_classes`. Singleton
        axes between the leading (batch) axis and the class axis are removed,
        so `(N,)` and `(N, 1)` inputs both give `(N, num_classes)`. The batch
        axis itself is kept even when `N == 1`; the previous unconditional
        `squeeze` collapsed it.
    """
    if num_classes is None:
        num_classes = c_dim
    # `as_tensor` avoids the copy-construct UserWarning that `torch.tensor`
    # raises when it is handed an existing tensor.
    indices = torch.as_tensor(labels).to(torch.int64)
    encoded = torch.nn.functional.one_hot(indices, num_classes = num_classes)
    if encoded.dim() > 2:
        inner = [extent for extent in encoded.shape[1:-1] if extent != 1]
        encoded = encoded.reshape([encoded.shape[0], *inner, encoded.shape[-1]])
    return encoded


# In[6]:


def print_g_sample():
    """Plot a 3x3 grid of generator samples, one for each of the classes 0..8.

    Uses the module-level generator `G` with fresh random noise and runs
    without gradient tracking. `G` is used in whatever mode it is currently
    in; this function does not switch it to eval mode.
    """
    n_samples = 9
    with torch.no_grad():
        class_ids = torch.arange(n_samples, device = device)
        codes = one_hot_embedding(class_ids).view(n_samples, c_dim, 1, 1).float()
        noise = torch.randn((n_samples, v_dim, 1, 1), device = device)
        # Map the tanh output from [-1, 1] to [0, 1] for display.
        images = .5 * (G(noise, codes).cpu() + 1)
        for idx, image in enumerate(images):
            plt.subplot(3, 3, idx + 1)
            # channels-last layout for imshow
            plt.imshow(image.permute(1, 2, 0))
        plt.show()


# In[7]:




# Build both networks, re-initialise their convolution weights, then move them
# to the target device. Construction and initialisation order is kept
# (G before D) so the sequence of random draws is unchanged.
G = Generator()
D = Discriminator()
G.weight_init()
D.weight_init()
G = G.to(device)  # Module.to returns the module itself
D = D.to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
BCE_loss = nn.BCELoss()

# Adam settings shared by both optimizers.
learning_rate = 0.0002
beta_1, beta_2 = 0.5, 0.999

G_optimizer = optim.Adam(G.parameters(), lr = learning_rate, betas = (beta_1, beta_2))
D_optimizer = optim.Adam(D.parameters(), lr = learning_rate, betas = (beta_1, beta_2))


# In[8]:


# Adversarial training loop: for every batch, one discriminator update followed
# by one generator update. Per-epoch mean losses are collected in
# G_loss_tracker / D_loss_tracker every `print_stride` epochs.
G_loss_tracker, D_loss_tracker = [], []
for epoch in range(1, n_epochs+1):

    D_losses = []
    G_losses = []

    for X, code in train_loader:
        mini_batch = X.size(0)
        X = X.to(device)
        code = one_hot_embedding(code.to(device)).float()

        ## Discriminator Training
        D_optimizer.zero_grad(set_to_none=True)

        # BCE targets for the real and the fake half of the batch.
        y_real = torch.ones((mini_batch,1,1,1), device = device)*D_real_scale # Sometimes .9, .1
        y_fake = torch.ones((mini_batch,1,1,1), device = device)*D_fake_scale

        rand_v = torch.randn((mini_batch, v_dim, 1, 1), device = device)
        rand_c = get_codes(mini_batch).view(mini_batch, c_dim, 1, 1).float()

        D_real_out = D(X, code)
        D_real_loss = BCE_loss(D_real_out, y_real)

        # Detach the fake batch: the discriminator update must not
        # back-propagate into (and leave gradients on) the generator.
        X_fake = G(rand_v, rand_c)
        D_fake_out = D(X_fake.detach(), rand_c)
        D_fake_loss = BCE_loss(D_fake_out, y_fake)

        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_optimizer.step()

        ## Generator Training
        # Clear the GENERATOR's gradients. The original cleared D's gradients
        # here, so G's gradients were never reset and accumulated over every
        # iteration of the run.
        G_optimizer.zero_grad(set_to_none=True)

        rand_v = torch.randn((mini_batch, v_dim, 1, 1), device = device)
        rand_c = get_codes(mini_batch).view(mini_batch, c_dim, 1, 1).float()
        X_fake = G(rand_v, rand_c)
        D_out = D(X_fake, rand_c)
        y_targ = torch.ones((mini_batch,1,1,1), device = device) # G gets low loss when D scores X_fake near 1
        G_loss = BCE_loss(D_out, y_targ)

        ## Loss combination
        # This backward pass also leaves gradients on D's parameters; they are
        # discarded by D_optimizer.zero_grad() at the start of the next batch.
        model_loss = G_loss
        model_loss.backward()
        G_optimizer.step()

        D_losses.append(D_loss.item())
        G_losses.append(G_loss.item())

    if epoch % print_stride == 0:
        mean_D_loss = torch.mean(torch.FloatTensor(D_losses))
        mean_G_loss = torch.mean(torch.FloatTensor(G_losses))
        print('Epoch {} - loss_D: {:.3f}, loss_G: {:.3f}'.format((epoch),
                                                               mean_D_loss,
                                                               mean_G_loss))

        G_loss_tracker.append(mean_G_loss)
        D_loss_tracker.append(mean_D_loss)
        print_g_sample()



# Persist the trained generator weights. The trailing `return D, G` was
# removed: `return` outside a function is a SyntaxError, and D and G remain
# available as notebook-level variables.
torch.save(G.state_dict(), f'c_gan_cifar_{exp_num}_G.pt')



