In [None]:
# Basic setup and imports for question 4

from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from collections import Counter
from torch.utils.data.sampler import WeightedRandomSampler
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from PIL import Image

%matplotlib inline
import matplotlib.pyplot as plt

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fix the seed of every random number generator used below (NumPy, PyTorch CPU
# and PyTorch CUDA) so that repeated runs start from the same state.
np.random.seed(250)
torch.manual_seed(250)
torch.cuda.manual_seed(250)

# Prefix prepended to the dataset folder name; empty means the folder is
# resolved relative to the current working directory.
ROOT = ""

In [None]:
# Define variables

# --- runtime ---------------------------------------------------------------
CUDA = False          # request the GPU; still gated on torch.cuda.is_available() later
DATA_PATH = './data'  # not referenced by the cells visible in this notebook section

# --- data / model shape ----------------------------------------------------
classes = 4           # number of conditioning labels
channels = 1          # images are single-channel (greyscale)
img_size = 48         # images are resized to this size before training
latent_dim = 100      # length of the generator's noise vector

# --- optimisation ----------------------------------------------------------
batch_size = 128
epochs = 25
lr = 0.0001
log_interval = 100    # log every `log_interval` batches

In [None]:
class CGANImageDataset(Dataset):
    '''Simple custom dataset to load in the 4 tree classes without the background class.

    Every ``.tif`` file found directly inside ``root/<image_dirs[k]>`` becomes one
    sample whose integer label is ``k`` (the position of its folder in
    ``image_dirs``).

    Args:
        root: Directory that contains the class folders.
        image_dirs: Class folder names; their order defines the label ids.
        transform: Optional callable applied to the opened PIL image.

    Attributes:
        all_images: File names (without folder) of every sample.
        labels: Integer label of every sample, aligned with ``all_images``.
    '''
    def __init__(self, root, image_dirs, transform=None):
        self.root = root
        self.image_dirs = list(image_dirs)
        self.transform = transform
        self.all_images = []
        self.labels = []
        for label, dir_name in enumerate(self.image_dirs):
            dir_path = os.path.join(root, dir_name)
            # List each folder once, and sort so that the index -> file mapping does
            # not depend on the file system's listing order.
            file_names = sorted(name for name in os.listdir(dir_path) if name.endswith(".tif"))
            self.all_images.extend(file_names)
            self.labels.extend([label] * len(file_names))

    def __len__(self):
        '''Return the number of samples.'''
        return len(self.all_images)

    def __getitem__(self, idx):
        '''Return ``(image, label)`` for sample ``idx``; ``image`` is transformed if a transform was given.'''
        label = self.labels[idx]
        # Use the folder the file was actually found in rather than rebuilding
        # a "class_<n>" name from the label.
        image_path = os.path.join(self.root, self.image_dirs[label], self.all_images[idx])
        image = Image.open(image_path)

        if self.transform:
            image = self.transform(image)

        return image, label

# Pre-processing: greyscale -> tensor in [0, 1] -> Normalize(0.5, 0.5) maps that
# onto [-1, 1] (the range of the generator's tanh output) -> resize to img_size.
cgan_class_dirs = ["class_1", "class_2", "class_3", "class_4"]
cgan_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
    transforms.Resize(img_size),
])
CGAN_data = CGANImageDataset(ROOT + 'images', cgan_class_dirs, cgan_transform)

# Class-balanced sampling: every sample is drawn with probability inversely
# proportional to the size of its class.
cgan_y_train = CGAN_data.labels
assert len(cgan_y_train) > 0, "no .tif images found under " + ROOT + "images"
num_class_gan_train = Counter(cgan_y_train)
# Index the counts by label id (Counter returns 0 for a missing label) so that an
# empty class folder cannot shift the weights onto the wrong classes.
cgan_class_sample_count = np.array([num_class_gan_train[c] for c in range(len(cgan_class_dirs))])
# An empty class gets weight inf here, but that entry is never looked up below
# because no sample carries its label.
cgan_weight = 1. / torch.tensor(cgan_class_sample_count).float()
cgan_samples_weight = cgan_weight[torch.tensor(cgan_y_train, dtype=torch.long)]
sampler = WeightedRandomSampler(cgan_samples_weight, len(cgan_samples_weight))
CGAN_dataloader = DataLoader(CGAN_data, batch_size=batch_size, sampler=sampler, drop_last=True)

# Prints an example of a batch
images, labels = next(iter(CGAN_dataloader))
figure = plt.figure(figsize=(6, 6))
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.2, hspace=0.8)
cols, rows = 8, 8
# Never index past the end of the batch if batch_size is set below cols * rows.
for i in range(min(cols * rows, len(images))):
    figure.add_subplot(rows, cols, i+1)
    plt.title(labels[i].item())
    plt.axis("off")
    plt.imshow(images[i,:].squeeze(), cmap="gray")
plt.show()

In [None]:
# Use CUDA only when it was requested in the config cell AND a GPU is available.
CUDA = CUDA and torch.cuda.is_available()
print("PyTorch version: {}".format(torch.__version__))
if CUDA:
    print("CUDA version: {}\n".format(torch.version.cuda))
    # The original referenced an undefined name `seed` here, which raised a
    # NameError as soon as CUDA was enabled. 250 is the seed used by the setup cell.
    torch.cuda.manual_seed(250)
# NOTE: this replaces the `device` chosen in the setup cell, so with CUDA = False
# everything runs on the CPU even when a GPU is present.
device = torch.device("cuda:0" if CUDA else "cpu")
cudnn.benchmark = True

In [None]:
class Generator(nn.Module):
    """Fully connected conditional generator.

    Turns a noise vector of length ``latent_dim`` together with an integer class
    label into an image tensor of shape ``(channels, img_size, img_size)`` whose
    values are squashed by ``tanh``.
    """

    def __init__(self, classes, channels, img_size, latent_dim):
        super().__init__()
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        # Each label is looked up as a learned vector with one entry per class.
        self.label_embedding = nn.Embedding(self.classes, self.classes)

        n_pixels = int(np.prod(self.img_shape))
        widths = [self.latent_dim + self.classes, 128, 256, 512, 1024]
        modules = []
        for depth, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Only the first hidden block goes without batch normalisation.
            modules.extend(self._create_layer(n_in, n_out, depth > 0))
        modules.append(nn.Linear(widths[-1], n_pixels))
        modules.append(nn.Tanh())
        self.model = nn.Sequential(*modules)

    def _create_layer(self, size_in, size_out, normalize=True):
        """Return the modules of one hidden block: Linear, optional BatchNorm1d, LeakyReLU(0.2)."""
        linear = nn.Linear(size_in, size_out)
        activation = nn.LeakyReLU(0.2, inplace=True)
        if normalize:
            return [linear, nn.BatchNorm1d(size_out), activation]
        return [linear, activation]

    def forward(self, noise, labels):
        """Generate one image per row of ``noise``, conditioned on the matching entry of ``labels``."""
        label_vectors = self.label_embedding(labels)
        generator_input = torch.cat([label_vectors, noise], dim=-1)
        flat_images = self.model(generator_input)
        return flat_images.view(flat_images.size(0), *self.img_shape)

In [None]:
class Discriminator(nn.Module):
    """Fully connected conditional discriminator.

    Scores a flattened image together with the embedding of its class label and
    returns one sigmoid output per sample. ``latent_dim`` is stored but not
    used by the network itself.
    """

    def __init__(self, classes, channels, img_size, latent_dim):
        super().__init__()
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        self.label_embedding = nn.Embedding(self.classes, self.classes)
        self.adv_loss = torch.nn.BCELoss()

        n_inputs = self.classes + int(np.prod(self.img_shape))
        # (size_in, size_out, drop_out, act_func) for each linear stage, in order.
        stage_specs = (
            (n_inputs, 1024, False, True),
            (1024, 512, True, True),
            (512, 256, True, True),
            (256, 128, False, False),
            (128, 1, False, False),
        )
        stages = []
        for spec in stage_specs:
            stages += self._create_layer(*spec)
        self.model = nn.Sequential(*stages, nn.Sigmoid())

    def _create_layer(self, size_in, size_out, drop_out=True, act_func=True):
        """Return the modules of one stage: Linear, optional Dropout(0.4), optional LeakyReLU(0.2)."""
        extras = []
        if drop_out:
            extras.append(nn.Dropout(0.4))
        if act_func:
            extras.append(nn.LeakyReLU(0.2, inplace=True))
        return [nn.Linear(size_in, size_out), *extras]

    def forward(self, image, labels):
        """Return the discriminator output of shape ``(batch, 1)`` for ``image`` given ``labels``."""
        flat_image = image.view(image.size(0), -1)
        label_vectors = self.label_embedding(labels)
        joint = torch.cat([flat_image, label_vectors], dim=-1)
        return self.model(joint)

    def loss(self, output, label):
        """Binary cross-entropy between the discriminator ``output`` and the target ``label``."""
        return self.adv_loss(output, label)

In [None]:
# Instantiate both networks, move them to the selected device and print their layers
netG = Generator(
    classes=classes, channels=channels, img_size=img_size, latent_dim=latent_dim
).to(device)
print(netG)
netD = Discriminator(
    classes=classes, channels=channels, img_size=img_size, latent_dim=latent_dim
).to(device)
print(netD)

# One Adam optimiser per network, both with the same learning rate and betas
optim_D = optim.Adam(params=netD.parameters(), lr=lr, betas=(0.5, 0.999))
optim_G = optim.Adam(params=netG.parameters(), lr=lr, betas=(0.5, 0.999))

In [None]:
# Train
# `Variable` is no longer used in this cell; the import is kept only so that any
# other cell still relying on the name keeps working.
from torch.autograd import Variable
img_list = []

netG.train()
netD.train()
viz_z = torch.zeros((batch_size, latent_dim), device=device)
# Fixed noise and labels, reused at every logging step so the grids are comparable.
viz_noise = torch.randn(batch_size, latent_dim, device=device)
nrows = batch_size // 8
# Labels must be valid rows of the generator's embedding table (0 .. classes-1).
# The original cycled through 0..7, which is out of range for `classes` = 4.
viz_label = torch.arange(batch_size, device=device) % classes

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(CGAN_dataloader):
        data, target = data.to(device), target.to(device)
        # Local name on purpose: do not overwrite the global `batch_size` setting.
        n_samples = data.size(0)
        real_label = torch.full((n_samples, 1), 1., device=device)
        fake_label = torch.full((n_samples, 1), 0., device=device)

        # Train G: make the discriminator label freshly generated images as real
        netG.zero_grad()
        z_noise = torch.randn(n_samples, latent_dim, device=device)
        x_fake_labels = torch.randint(0, classes, (n_samples,), device=device)
        x_fake = netG(z_noise, x_fake_labels)
        y_fake_g = netD(x_fake, x_fake_labels)
        g_loss = netD.loss(y_fake_g, real_label)
        # This also leaves gradients on netD's parameters; they are cleared by
        # netD.zero_grad() before the discriminator step below.
        g_loss.backward()
        optim_G.step()

        # Train D: real images -> 1, generated images (detached from G) -> 0
        netD.zero_grad()
        y_real = netD(data, target)
        d_real_loss = netD.loss(y_real, real_label)
        y_fake_d = netD(x_fake.detach(), x_fake_labels)
        d_fake_loss = netD.loss(y_fake_d, fake_label)
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optim_D.step()

        if batch_idx % log_interval == 0 and batch_idx > 0:
            print('Epoch {} [{}/{}] loss_D: {:.4f} loss_G: {:.4f}'.format(
                        epoch, batch_idx, len(CGAN_dataloader),
                        d_loss.item(),
                        g_loss.item()))

            with torch.no_grad():
                viz_sample = netG(viz_noise, viz_label)
                img_list.append(vutils.make_grid(viz_sample, normalize=True))

    # Prints an example image from each class
    # no_grad: nothing is back-propagated through these samples, so no graph is built.
    with torch.no_grad():
        z = torch.randn(classes, latent_dim, device=device)
        labels = torch.arange(classes, device=device)
        sample_images = netG(z, labels)
    plt.figure(figsize=[3, 3])
    # .cpu() so the conversion to numpy also works when training on the GPU.
    grid = vutils.make_grid(sample_images, nrow=classes, normalize=True).permute(1, 2, 0).cpu().numpy()
    plt.imshow(grid)
    plt.show()

In [None]:
%matplotlib inline 
# Show a grid of generated samples: one row per class, `cols` samples per row.
netG.eval()

p = 1  # 1-based index of the next subplot
figure = plt.figure(figsize=(5, 5))
cols, rows = 4, classes
# Inference only, so skip building the autograd graph.
with torch.no_grad():
    for label in range(rows):
        # Sizes come from the config cell instead of hard-coded 100 / 48.
        z = torch.randn(cols, latent_dim, device=device)
        labels = torch.full((cols,), label, dtype=torch.long, device=device)
        images = netG(z, labels)
        for image in images:
            figure.add_subplot(rows, cols, p)
            plt.imshow(image.cpu().squeeze().reshape(img_size, img_size).numpy(), cmap="gray")
            plt.axis('off')
            p += 1
plt.show()