In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary


from random import randint

from IPython.display import Image
from IPython.core.display import Image, display

from vae import Flatten, UnFlatten, VAE

  from IPython.core.display import Image, display


In [77]:
import os

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Location of the VAE latent vectors used as the "real" data for the GAN.
# Override with the LATENT_VECTORS_PATH environment variable instead of editing
# this absolute, machine-specific default.
LATENT_VECTORS_PATH = os.environ.get(
    'LATENT_VECTORS_PATH',
    '/home/ashish/Projects/VAE-GAN/Notebooks/latent_vectors_final.pt',
)
# map_location lets a tensor that was saved from a GPU session load on a
# CPU-only machine. NOTE(review): torch.load unpickles the file -- only load
# trusted checkpoints.
latent_vectors = torch.load(LATENT_VECTORS_PATH, map_location=device).to(device)

In [78]:
def get_default_device():
    """Return ``torch.device('cuda')`` if CUDA is usable, else ``torch.device('cpu')``."""
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)
    
def to_device(data, device):
    """Move a tensor/module, or a (possibly nested) list/tuple of them, to `device`.

    Sequences are walked recursively and always come back as a *list*, even
    when a tuple was passed in. Any other object must provide
    ``.to(device, non_blocking=...)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return list(map(lambda item: to_device(item, device), data))

class DeviceDataLoader():
    """Wrap an iterable data loader so every batch it yields is moved to `device`."""

    def __init__(self, dl, device):
        # Underlying loader: anything iterable that also supports len().
        self.dl = dl
        # Destination passed to to_device() for each batch.
        self.device = device

    def __iter__(self):
        """Lazily yield each batch of ``self.dl`` after ``to_device`` has moved it."""
        target = self.device
        yield from (to_device(batch, target) for batch in self.dl)

    def __len__(self):
        """Return the number of batches reported by the wrapped loader."""
        return len(self.dl)

In [79]:
device = get_default_device()
device

device(type='cuda')

In [80]:
torch.cuda.is_available()

True

In [81]:
print(torch.backends.cudnn.enabled)

True


In [82]:
#train_dl = DeviceDataLoader(train_dl, device)

In [83]:
import torch.nn as nn

In [84]:
# Discriminator: scores (N, 256, 1, 1) latent vectors and returns (N, 1)
# probabilities via a final Sigmoid. With kernel_size=4, stride=2, padding=2 a
# 1x1 input maps to a 1x1 output ((1 + 4 - 4) // 2 + 1 = 1), so every stage
# keeps the spatial size at 1 x 1 and only changes the channel count.
discriminator = nn.Sequential(
    # Four identical Conv -> BatchNorm -> LeakyReLU stages:
    # 256 -> 512, then 512 -> 512 three times (all 1 x 1 spatially).
    *[
        layer
        for in_channels in (256, 512, 512, 512)
        for layer in (
            nn.Conv2d(in_channels, 512, kernel_size=4, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )
    ],
    # Collapse to a single logit per sample: 1 x 1 x 1.
    nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=2, bias=False),
    nn.Flatten(),
    nn.Sigmoid(),
)

In [85]:
discriminator = to_device(discriminator, device)

In [86]:
layer = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=2, bias=False)

In [87]:
layer_fin = nn.Sequential(layer,    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2, inplace=True))

In [88]:
input = torch.randn(20, 256, 1, 1)

In [89]:
output = layer(input)

In [90]:
output.shape

torch.Size([20, 512, 1, 1])

In [91]:
layer.weight.shape

torch.Size([512, 256, 4, 4])

In [92]:
#output = discriminator(input)

In [93]:
output.shape

torch.Size([20, 512, 1, 1])

In [94]:
latent_size = 128
layer1_input = latent_size
layer1_output = 512
kernel_size = 4
stride = 1
padding = 0

In [95]:
layer1 = nn.ConvTranspose2d(latent_size, layer1_output, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)


In [96]:
input1 = torch.randn(20, latent_size, 1, 1)


In [97]:
output1 = layer1(input1)

In [98]:
output1.shape

torch.Size([20, 512, 4, 4])

In [99]:
latent_size = 128
layer2_input = 512
layer2_output = 512
kernel_size = 4
stride = 1
padding = 0

In [100]:
layer2 = nn.ConvTranspose2d(layer2_input, layer2_output, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)


In [101]:
input2 = torch.randn(20, layer2_input, output1.shape[2], output1.shape[2])


In [102]:
output2 = layer2(input2)

In [103]:
output2.shape

torch.Size([20, 512, 7, 7])

In [104]:
latent_size = 128
layer3_input = 512
layer3_output = 512
kernel_size = 4
stride = 1
padding = 0

In [105]:
layer3 = nn.ConvTranspose2d(layer3_input, layer3_output, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)


In [106]:
input3 = torch.randn(20, layer3_input, output2.shape[2], output2.shape[2])


In [107]:
output3 = layer3(input3)

In [108]:
output3.shape

torch.Size([20, 512, 10, 10])

In [109]:
latent_size = 128
layer4_input = 512
layer4_output = 256
kernel_size = 4
stride = 2
padding = 0

In [110]:
layer4 = nn.Conv2d(layer4_input, layer4_output, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)


In [111]:
input4 = torch.randn(20, layer4_input, output3.shape[2], output3.shape[2])


In [112]:
output4 = layer4(input4)

In [113]:
output4.shape

torch.Size([20, 256, 4, 4])

In [114]:
latent_size = 128
layer5_input = 256
layer5_output = 256
kernel_size = 4
stride = 2
padding = 0

In [115]:
layer5 = nn.Conv2d(layer5_input, layer5_output, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)


In [116]:
input5 = torch.randn(20, layer5_input, output4.shape[2], output4.shape[2])


In [117]:
output5 = layer5(input5)

In [118]:
output5.shape

torch.Size([20, 256, 1, 1])

In [119]:
# Generator: maps noise of shape (N, latent_size, 1, 1) to fake latent vectors
# of shape (N, 256, 1, 1), the same shape as the reshaped VAE latents that
# serve as "real" data. Spatial size grows 1 -> 4 -> 7 -> 10 through the
# stride-1 transposed convs, then shrinks 10 -> 4 -> 1 through the stride-2 convs.
generator = nn.Sequential(
    # in: latent_size x 1 x 1

    nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
    # out: 512 x 4 x 4

    nn.ConvTranspose2d(512, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
    # out: 512 x 7 x 7

    nn.ConvTranspose2d(512, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
    # out: 512 x 10 x 10

    nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=0, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    # out: 256 x 4 x 4

    nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=0, bias=False),
    # NOTE(review): Tanh bounds every output to [-1, 1]. If the real VAE latents
    # are not confined to that range (not verifiable from this notebook), the
    # discriminator can tell real from fake by magnitude alone. Worth checking,
    # given how quickly real_score -> 1 and fake_score -> 0 during training.
    nn.Tanh()
    # out: 256 x 1 x 1
)

In [120]:
input1 = torch.randn(20, latent_size, 1, 1)

In [121]:
output_generator = generator(input1)


In [122]:
output_generator.shape

torch.Size([20, 256, 1, 1])

In [123]:
batch_size = 20
xb = torch.randn(batch_size, latent_size, 1, 1) # random latent tensors
fake_images = generator(xb)
print(fake_images.shape)
#show_images(fake_images)

torch.Size([20, 256, 1, 1])


In [124]:
generator = to_device(generator, device)

In [125]:
def train_discriminator(real_images, opt_d):
    """Run one discriminator update on a real batch plus an equal-sized fake batch.

    The discriminator is trained with binary cross-entropy to output 1 for
    `real_images` and 0 for samples drawn from the global `generator`.

    Args:
        real_images: batch of real samples, already on `device`. The batch size
            is taken from ``real_images.size(0)``.
        opt_d: optimizer over the discriminator's parameters.

    Returns:
        (loss, real_score, fake_score) as Python floats. `loss` is the sum of
        the real and fake BCE losses. `real_score` and `fake_score` are the mean
        discriminator outputs on the real and fake batches.
    """
    # Clear discriminator gradients
    opt_d.zero_grad()
    n_samples = real_images.size(0)

    # Pass real images through discriminator
    real_preds = discriminator(real_images)
    real_targets = torch.ones(n_samples, 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Generate as many fakes as there are real samples. The global batch_size
    # would disagree with n_samples on a short final batch.
    latent = torch.randn(n_samples, latent_size, 1, 1, device=device)
    # detach(): only the discriminator is updated here, so don't backpropagate
    # into (and accumulate gradients on) the generator's parameters.
    fake_images = generator(latent).detach()

    # Pass fake images through discriminator
    fake_targets = torch.zeros(n_samples, 1, device=device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    # Update discriminator weights
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score

In [126]:
def train_generator(opt_g):
    """Run one generator update against the current discriminator.

    Draws `batch_size` noise vectors of shape (batch_size, latent_size, 1, 1),
    maps them through the global `generator`, and scores them with the global
    `discriminator`. The loss is BCE against all-ones targets, so the generator
    is rewarded when its outputs are classified as real. Gradients also reach
    the discriminator's parameters here; they are cleared by `opt_d.zero_grad()`
    at the start of the next discriminator step.

    Args:
        opt_g: optimizer over the generator's parameters.

    Returns:
        The generator loss for this step as a Python float.
    """
    opt_g.zero_grad()

    noise = torch.randn(batch_size, latent_size, 1, 1, device=device)
    verdicts = discriminator(generator(noise))

    # Label every generated sample "real" so the gradient pushes the generator
    # toward outputs the discriminator accepts.
    wanted = torch.ones(batch_size, 1, device=device)
    g_loss = F.binary_cross_entropy(verdicts, wanted)

    g_loss.backward()
    opt_g.step()
    return g_loss.item()

In [127]:
from torchvision.utils import save_image

In [128]:
import os

'''DATA_DIR = './animefacedataset'
print(os.listdir(DATA_DIR))'''

"DATA_DIR = './animefacedataset'\nprint(os.listdir(DATA_DIR))"

In [129]:
sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

In [130]:
'''def save_samples(index, latent_tensors, show=True):
    fake_images = generator(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(fake_images, os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(fake_images.cpu().detach(), nrow=8).permute(1, 2, 0))'''

"def save_samples(index, latent_tensors, show=True):\n    fake_images = generator(latent_tensors)\n    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)\n    save_image(fake_images, os.path.join(sample_dir, fake_fname), nrow=8)\n    print('Saving', fake_fname)\n    if show:\n        fig, ax = plt.subplots(figsize=(8, 8))\n        ax.set_xticks([]); ax.set_yticks([])\n        ax.imshow(make_grid(fake_images.cpu().detach(), nrow=8).permute(1, 2, 0))"

In [131]:
def save_samples2(index, latent_tensors):
    """Run `latent_tensors` through the global `generator` and save the result.

    The output is written to ``sample_dir`` as ``generated-images-NNNN.png``
    (index zero-padded to 4 digits) with 8 samples per row.

    Sampling runs under ``torch.no_grad()`` because nothing is backpropagated
    from here, so building an autograd graph would only waste memory.

    NOTE(review): the current generator emits (N, 256, 1, 1) tensors.
    ``save_image`` presumably expects 1- or 3-channel images, so this probably
    fails for that output. That is likely why the call in ``fit`` is commented
    out. Confirm before re-enabling it.
    """
    with torch.no_grad():
        fake_images = generator(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(fake_images, os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
  

In [132]:
from torchvision.utils import save_image

In [133]:
sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

In [134]:
from tqdm.notebook import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

In [135]:
latent_vectors.shape


torch.Size([10000, 256])

In [136]:
latent_vectors = latent_vectors.view(latent_vectors.shape[0], latent_vectors.shape[1], 1, 1)

In [137]:
fixed_latent = torch.randn(batch_size, latent_size, 1, 1, device=device)

In [138]:
dataset = TensorDataset(latent_vectors)

In [139]:
# Batch the raw (N, 256, 1, 1) latent tensor directly rather than the
# TensorDataset built above. NOTE(review): presumably because a TensorDataset
# yields 1-tuples that would collate into 1-element lists, while
# train_discriminator passes each batch straight to the discriminator without
# unpacking it. Confirm before switching.
train_dl = DataLoader(latent_vectors, batch_size, shuffle=True, num_workers=0)

In [147]:
train_dl = DeviceDataLoader(train_dl, device)

In [148]:
train_dl

<__main__.DeviceDataLoader at 0x7fb740820af0>

In [142]:
'''train_dl = latent_vectors.reshape(int(latent_vectors.shape[0]/batch_size), batch_size, 256, 1,1)
print('train_dl.shape', train_dl.shape)'''

"train_dl = latent_vectors.reshape(int(latent_vectors.shape[0]/batch_size), batch_size, 256, 1,1)\nprint('train_dl.shape', train_dl.shape)"

In [149]:
def fit(epochs, lr, start_idx=1):
    """Train the global `generator` and `discriminator` adversarially on `train_dl`.

    Each batch runs one discriminator step (`train_discriminator`) followed by
    one generator step (`train_generator`). New Adam optimizers
    (betas=(0.5, 0.999)) are created on every call, so optimizer state is not
    carried over between calls.

    Args:
        epochs: number of passes over `train_dl`.
        lr: learning rate shared by both optimizers.
        start_idx: offset for sample-file numbering. Only used by the
            (currently disabled) `save_samples2` call.

    Returns:
        (losses_g, losses_d, real_scores, fake_scores): one entry per epoch,
        each taken from that epoch's LAST batch (not an epoch average).

    Raises:
        ValueError: if `train_dl` yields no batches during an epoch.
    """
    torch.cuda.empty_cache()

    # Both models use BatchNorm, so make sure they are in training mode even if
    # an earlier cell switched them to eval().
    generator.train()
    discriminator.train()

    # Losses & scores (last batch of each epoch)
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    # Create optimizers
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        n_batches = 0
        for real_images in tqdm(train_dl):
            # Train discriminator
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            # Train generator
            loss_g = train_generator(opt_g)
            n_batches += 1

        # Without this guard an empty loader would surface as an
        # UnboundLocalError on loss_g below.
        if n_batches == 0:
            raise ValueError("train_dl yielded no batches; nothing to train on")

        # Record losses & scores
        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        # Log losses & scores (last batch)
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs, loss_g, loss_d, real_score, fake_score))

        # Save generated images
        #save_samples2(epoch+start_idx, fixed_latent)

    return losses_g, losses_d, real_scores, fake_scores

In [150]:
lr = 0.0002
epochs = 25

In [151]:
for real_images in tqdm(train_dl):
  print(real_images.shape)
  break

  0%|          | 0/500 [00:00<?, ?it/s]

torch.Size([20, 256, 1, 1])


In [None]:
history = fit(epochs, lr)

  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [1/25], loss_g: 1.4707, loss_d: 0.9623, real_score: 0.6057, fake_score: 0.3349


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [2/25], loss_g: 2.4363, loss_d: 0.2424, real_score: 0.8908, fake_score: 0.1165


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [3/25], loss_g: 2.7463, loss_d: 0.1901, real_score: 0.9494, fake_score: 0.1234


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [4/25], loss_g: 3.3747, loss_d: 0.3420, real_score: 0.7731, fake_score: 0.0529


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [5/25], loss_g: 3.0394, loss_d: 0.0754, real_score: 0.9744, fake_score: 0.0479


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [6/25], loss_g: 3.8388, loss_d: 0.0382, real_score: 0.9839, fake_score: 0.0215


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [7/25], loss_g: 4.4339, loss_d: 0.0145, real_score: 0.9974, fake_score: 0.0119


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch [8/25], loss_g: 4.9456, loss_d: 0.0090, real_score: 0.9982, fake_score: 0.0071


  0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
losses_g, losses_d, real_scores, fake_scores = history

In [None]:
# Save the model checkpoints 
torch.save(generator.state_dict(), 'G.pth')
torch.save(discriminator.state_dict(), 'D.pth')

In [None]:
# `plt` is used here, but no cell in this notebook imports it.
import matplotlib.pyplot as plt

# Per-epoch losses as recorded by fit() (last batch of each epoch).
fig, ax = plt.subplots()
ax.plot(losses_d, '-', label='Discriminator')
ax.plot(losses_g, '-', label='Generator')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend()
ax.set_title('Losses');

In [None]:
# `plt` is used here, but no cell in this notebook imports it.
import matplotlib.pyplot as plt

# Mean discriminator output on real vs. generated samples, per epoch
# (last batch of each epoch, as recorded by fit()).
fig, ax = plt.subplots()
ax.plot(real_scores, '-', label='Real')
ax.plot(fake_scores, '-', label='Fake')
ax.set_xlabel('epoch')
ax.set_ylabel('score')
ax.legend()
ax.set_title('Scores');