Unsupervised representation learning using DCGAN

In [53]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import os 
import torchvision.utils as vutils
from torchvision.datasets import ImageFolder

In [2]:
class Discriminator(nn.Module):
    """DCGAN discriminator mapping an image batch to per-image "real" probabilities.

    Args:
        channel_img: number of channels in the input images.
        features_d: base width; the convolution stack uses features_d,
            2x, 4x and 8x that many feature maps.

    For 64x64 inputs the four stride-2 convolutions shrink the spatial size
    64 -> 32 -> 16 -> 8 -> 4, and the final 4x4 convolution (stride 2, no
    padding) reduces it to 1x1, so ``forward`` returns shape (N, 1, 1, 1)
    with Sigmoid-activated values in [0, 1].
    """

    def __init__(self, channel_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # The first layer has no batch norm, only a LeakyReLU.
            nn.Conv2d(channel_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            # Collapse the remaining 4x4 map to a single score per image.
            nn.Conv2d(features_d * 8, 1, 4, 2, 0),
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
        """Return Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

        ``bias`` controls the convolution's bias term. It defaults to False
        because the following BatchNorm2d has its own learnable shift.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=bias,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Score a batch of images; see the class docstring for shapes."""
        return self.disc(x)

In [3]:
class Generator(nn.Module):
    """DCGAN generator mapping latent noise to images with values in [-1, 1].

    Args:
        channels_noise: number of channels of the latent input (the noise
            dimension); typically inputs have shape (N, channels_noise, 1, 1).
        img_dim: number of channels of the generated images.
        features_g: base width; the hidden layers use 16x, 8x, 4x and 2x
            this many feature maps.
        batch_norm: keyword-only. When True, insert a BatchNorm2d after each
            hidden transposed convolution. Defaults to False, which keeps the
            original architecture (and its state_dict keys) unchanged.

    A 1x1 spatial input is upsampled 1 -> 4 -> 8 -> 16 -> 32 -> 64, so a
    (N, channels_noise, 1, 1) input yields (N, img_dim, 64, 64) after the
    final Tanh.
    """

    def __init__(self, channels_noise, img_dim, features_g, *, batch_norm=False):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Project the 1x1 noise to a 4x4 map (stride 1, no padding).
            self._block(channels_noise, features_g * 16, 4, 1, 0, batch_norm=batch_norm),
            self._block(features_g * 16, features_g * 8, 4, 2, 1, batch_norm=batch_norm),
            self._block(features_g * 8, features_g * 4, 4, 2, 1, batch_norm=batch_norm),
            self._block(features_g * 4, features_g * 2, 4, 2, 1, batch_norm=batch_norm),
            nn.ConvTranspose2d(features_g * 2, img_dim, 4, 2, 1),
            # Tanh matches the [-1, 1] range of the normalized training images.
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding, batch_norm=False):
        """Return ConvTranspose2d (no bias) [-> BatchNorm2d] -> ReLU."""
        layers = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
        ]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Generate images from noise; see the class docstring for shapes."""
        return self.gen(x)

In [4]:
# Initialize conv / transposed-conv / batch-norm weights from N(0, 0.02^2).
def initialize_weights(model, mean=0.0, std=0.02):
    """Re-initialize weights of selected layers of *model* in place.

    Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d`` found
    via ``model.modules()`` gets its weight drawn from a normal distribution
    with the given ``mean`` and standard deviation ``std`` (defaults: 0.0 and
    0.02). Biases and all other module types are left untouched.
    Batch-norm layers built with ``affine=False`` have no weight and are
    skipped instead of raising.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)) and m.weight is not None:
            # nn.init functions already run under no_grad, so the Parameter can be passed directly.
            nn.init.normal_(m.weight, mean, std)

In [5]:
def test():
    """Smoke-test that Discriminator and Generator produce the expected shapes.

    Builds small (base width 8) models for 64x64 RGB inputs, checks that the
    discriminator yields one score per image and that the generator maps
    100-dimensional noise back to image shape, then prints "success".
    """
    batch, channels_in, height, width = 8, 3, 64, 64
    z_dim = 100

    images = torch.randn((batch, channels_in, height, width))
    critic = Discriminator(channels_in, 8)
    initialize_weights(critic)
    expected_scores = (batch, 1, 1, 1)
    assert critic(images).shape == expected_scores, "Disc Failed"

    generator = Generator(z_dim, channels_in, 8)
    latent = torch.randn((batch, z_dim, 1, 1))
    expected_images = (batch, channels_in, height, width)
    assert generator(latent).shape == expected_images, "gen Failed"

    print("success")

In [6]:
# Run the model shape smoke test (prints "success" when both asserts pass).
test()

success


In [None]:
def show_images(images, nmax=32, nrow=8):
    """Display up to *nmax* images from a normalized batch as a single grid.

    ``images`` is an image batch as produced by the DataLoader or the
    generator (normalized to [-1, 1]); it is passed through ``denorm`` before
    plotting. The tensor is detached and moved to the CPU first, so batches
    living on a GPU (e.g. generator output) can be plotted as well.
    ``nrow`` is the number of images per grid row.
    """
    _, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([])
    ax.set_yticks([])
    batch = images.detach().cpu()[:nmax]
    grid = make_grid(denorm(batch), nrow=nrow)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(grid.permute(1, 2, 0))

def show_batch(dl, nmax=32):
    """Plot the images of the first batch yielded by *dl*, ignoring labels.

    Does nothing when *dl* yields no batches.
    """
    _missing = object()
    first_batch = next(iter(dl), _missing)
    if first_batch is _missing:
        return
    images, _ = first_batch
    show_images(images, nmax)

# (per-channel means, per-channel stds) matching the Normalize step of the transform.
stats = ((0.5,) * 3, (0.5,) * 3)


def denorm(img_tensors):
    """Invert the normalization, x * std + mean, using the first channel's stats."""
    means, stds = stats
    scale = stds[0]
    shift = means[0]
    return img_tensors * scale + shift

In [96]:
# --- Hyperparameters --------------------------------------------------------
# Compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data: square images of side img_size with this many channels.
img_size = 64
channels = 3

# Model widths (base feature-map counts) and latent noise size.
features_gen = 64
features_disc = 64
noise_dim = 100

# Optimization.
batch_size = 128
epochs = 60
lr = 2e-4

In [97]:
# Preprocessing: resize so the shorter side is img_size, center-crop to an
# img_size x img_size square, convert to a [0, 1] tensor, then map every
# channel to [-1, 1] via (x - 0.5) / 0.5 to match the generator's Tanh range.
half_per_channel = [0.5] * channels
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(half_per_channel, list(half_per_channel)),
])

In [98]:
# Alternative dataset (disabled):
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transform, download=True
# )

# Image root in torchvision ImageFolder layout (images inside sub-directories).
# Set the CELEB_DATASET_DIR environment variable to use another location; the
# default is the original author's local path.
dataset_root = os.environ.get(
    "CELEB_DATASET_DIR",
    r"C:\Users\dhruv\Untitled Folder\Fake Face Image Generator\celeb_dataset",
)
dataset = ImageFolder(dataset_root, transform=transform)

dataloader = DataLoader(dataset, batch_size, shuffle=True)

# Build both networks on the selected device and apply the custom weight init.
gen = Generator(noise_dim, channels, features_gen).to(device)
disc = Discriminator(channels, features_disc).to(device)
initialize_weights(gen)
initialize_weights(disc)

In [None]:
# Preview the first (shuffled) training batch, denormalized for display.
show_batch(dataloader)

In [99]:
# One Adam optimizer per network, both with beta1 = 0.5 and the shared lr.
# BCELoss consumes the discriminator's Sigmoid probabilities directly.
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()

In [100]:
# Folder that receives the PNG sample grids saved during training; change the
# name to write them elsewhere. Creating it is a no-op when it already exists.
output_directory = "generated_images_celeb"
os.makedirs(output_directory, exist_ok=True)

In [101]:
# Fixed latent vectors, reused at every logging step so that snapshots of the
# generator's output are comparable over the course of training.
fixed_noise = torch.randn(32, noise_dim, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0  # TensorBoard global step; advanced once per logging event
gen.train()
disc.train()
try:
    for epoch in range(epochs):
        # ImageFolder class labels are ignored: GAN training is unsupervised.
        for batch_idx, (real, _) in enumerate(dataloader):
            real = real.to(device)
            # Draw one latent vector per real image. The last batch of an epoch
            # can be smaller than batch_size (the DataLoader does not drop it),
            # so the configured batch_size must not be used here.
            cur_batch_size = real.size(0)
            noise = torch.randn(cur_batch_size, noise_dim, 1, 1).to(device)
            fake = gen(noise)

            ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
            disc_real = disc(real).reshape(-1)
            loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
            # detach() keeps the discriminator update from backpropagating into gen.
            disc_fake = disc(fake.detach()).reshape(-1)
            loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            loss_disc = (loss_disc_real + loss_disc_fake) / 2
            disc.zero_grad()
            loss_disc.backward()
            opt_disc.step()

            ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
            # Non-saturating loss: fakes are labelled as real. This backward also
            # writes gradients into disc's parameters; disc.zero_grad() clears
            # them before the next discriminator step.
            output = disc(fake).reshape(-1)
            loss_gen = criterion(output, torch.ones_like(output))
            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            # Every 100 batches: print losses, save a PNG grid of samples from
            # fixed_noise, and log real/fake grids to TensorBoard.
            if batch_idx % 100 == 0:
                print(
                    f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(dataloader)} "
                    f"Loss D: {loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}"
                )

                with torch.no_grad():
                    fixed_fake = gen(fixed_noise)
                    fake_grid = vutils.make_grid(fixed_fake, normalize=True, padding=2, nrow=8)

                    image_filename = os.path.join(
                        output_directory, f"fake_images_epoch{epoch}_batch{batch_idx}.png"
                    )
                    vutils.save_image(fake_grid, image_filename)

                    # Display real and fake images in TensorBoard.
                    img_grid_real = vutils.make_grid(real[:32], normalize=True, padding=2, nrow=8)
                    writer_real.add_image("Real", img_grid_real, global_step=step)
                    writer_fake.add_image("Fake", fake_grid, global_step=step)

                step += 1
finally:
    # Flush pending events and release the log files even if training is interrupted.
    writer_real.close()
    writer_fake.close()


Epoch [0/60] Batch 0/1583                   Loss D: 0.7031, loss G: 0.7923
Epoch [0/60] Batch 100/1583                   Loss D: 0.1970, loss G: 2.5234
Epoch [0/60] Batch 200/1583                   Loss D: 0.7020, loss G: 2.0143
Epoch [0/60] Batch 300/1583                   Loss D: 0.4471, loss G: 1.8668
Epoch [0/60] Batch 400/1583                   Loss D: 0.5283, loss G: 1.9138
Epoch [0/60] Batch 500/1583                   Loss D: 0.4387, loss G: 2.2269
Epoch [0/60] Batch 600/1583                   Loss D: 0.4932, loss G: 1.8804
Epoch [0/60] Batch 700/1583                   Loss D: 0.5949, loss G: 2.7632
Epoch [0/60] Batch 800/1583                   Loss D: 0.4250, loss G: 2.7383
Epoch [0/60] Batch 900/1583                   Loss D: 0.4614, loss G: 1.8955
Epoch [0/60] Batch 1000/1583                   Loss D: 0.5312, loss G: 2.7696
Epoch [0/60] Batch 1100/1583                   Loss D: 0.5805, loss G: 1.6545
Epoch [0/60] Batch 1200/1583                   Loss D: 0.3902, loss G: 2.220

Epoch [6/60] Batch 1100/1583                   Loss D: 0.4517, loss G: 1.2365
Epoch [6/60] Batch 1200/1583                   Loss D: 0.4597, loss G: 0.9659
Epoch [6/60] Batch 1300/1583                   Loss D: 0.4894, loss G: 1.3598
Epoch [6/60] Batch 1400/1583                   Loss D: 0.5204, loss G: 0.9864
Epoch [6/60] Batch 1500/1583                   Loss D: 0.5250, loss G: 2.3587
Epoch [7/60] Batch 0/1583                   Loss D: 0.5777, loss G: 1.3001
Epoch [7/60] Batch 100/1583                   Loss D: 0.4179, loss G: 1.4337
Epoch [7/60] Batch 200/1583                   Loss D: 0.5126, loss G: 1.3247
Epoch [7/60] Batch 300/1583                   Loss D: 0.3993, loss G: 1.3537
Epoch [7/60] Batch 400/1583                   Loss D: 0.4899, loss G: 1.9841
Epoch [7/60] Batch 500/1583                   Loss D: 0.4040, loss G: 1.9066
Epoch [7/60] Batch 600/1583                   Loss D: 0.4227, loss G: 1.8564
Epoch [7/60] Batch 700/1583                   Loss D: 0.4925, loss G: 0.9

Epoch [13/60] Batch 500/1583                   Loss D: 0.2555, loss G: 2.5009
Epoch [13/60] Batch 600/1583                   Loss D: 0.2802, loss G: 2.8691
Epoch [13/60] Batch 700/1583                   Loss D: 0.3607, loss G: 4.4400
Epoch [13/60] Batch 800/1583                   Loss D: 0.2559, loss G: 2.0844
Epoch [13/60] Batch 900/1583                   Loss D: 0.3143, loss G: 4.2523
Epoch [13/60] Batch 1000/1583                   Loss D: 1.2595, loss G: 0.8554
Epoch [13/60] Batch 1100/1583                   Loss D: 0.4009, loss G: 2.1441
Epoch [13/60] Batch 1200/1583                   Loss D: 0.2580, loss G: 2.9516
Epoch [13/60] Batch 1300/1583                   Loss D: 0.2410, loss G: 3.4109
Epoch [13/60] Batch 1400/1583                   Loss D: 0.2307, loss G: 2.3325
Epoch [13/60] Batch 1500/1583                   Loss D: 0.1684, loss G: 1.8872
Epoch [14/60] Batch 0/1583                   Loss D: 0.1417, loss G: 2.7801
Epoch [14/60] Batch 100/1583                   Loss D: 0.254

Epoch [19/60] Batch 1400/1583                   Loss D: 0.1885, loss G: 4.1249
Epoch [19/60] Batch 1500/1583                   Loss D: 0.1616, loss G: 2.6096
Epoch [20/60] Batch 0/1583                   Loss D: 1.3702, loss G: 5.0757
Epoch [20/60] Batch 100/1583                   Loss D: 0.1859, loss G: 3.3347
Epoch [20/60] Batch 200/1583                   Loss D: 0.2741, loss G: 1.1927
Epoch [20/60] Batch 300/1583                   Loss D: 0.1149, loss G: 4.2487
Epoch [20/60] Batch 400/1583                   Loss D: 0.2022, loss G: 2.8287
Epoch [20/60] Batch 500/1583                   Loss D: 0.1015, loss G: 3.8814
Epoch [20/60] Batch 600/1583                   Loss D: 0.1125, loss G: 2.9151
Epoch [20/60] Batch 700/1583                   Loss D: 0.1263, loss G: 3.0539
Epoch [20/60] Batch 800/1583                   Loss D: 0.1323, loss G: 2.8643
Epoch [20/60] Batch 900/1583                   Loss D: 0.2894, loss G: 1.4736
Epoch [20/60] Batch 1000/1583                   Loss D: 0.2123, 

Epoch [26/60] Batch 700/1583                   Loss D: 0.1031, loss G: 4.3485
Epoch [26/60] Batch 800/1583                   Loss D: 0.1591, loss G: 4.1837
Epoch [26/60] Batch 900/1583                   Loss D: 0.1233, loss G: 5.2720
Epoch [26/60] Batch 1000/1583                   Loss D: 0.0762, loss G: 4.7056
Epoch [26/60] Batch 1100/1583                   Loss D: 0.0627, loss G: 3.4586
Epoch [26/60] Batch 1200/1583                   Loss D: 0.0708, loss G: 3.8858
Epoch [26/60] Batch 1300/1583                   Loss D: 0.1294, loss G: 3.0261
Epoch [26/60] Batch 1400/1583                   Loss D: 0.0609, loss G: 4.2486
Epoch [26/60] Batch 1500/1583                   Loss D: 0.3488, loss G: 6.3006
Epoch [27/60] Batch 0/1583                   Loss D: 0.2771, loss G: 2.7149
Epoch [27/60] Batch 100/1583                   Loss D: 0.0674, loss G: 3.7253
Epoch [27/60] Batch 200/1583                   Loss D: 0.1135, loss G: 4.0328
Epoch [27/60] Batch 300/1583                   Loss D: 0.097

Epoch [33/60] Batch 0/1583                   Loss D: 0.3394, loss G: 1.6158
Epoch [33/60] Batch 100/1583                   Loss D: 0.3819, loss G: 0.9508
Epoch [33/60] Batch 200/1583                   Loss D: 0.3004, loss G: 9.0620
Epoch [33/60] Batch 300/1583                   Loss D: 0.0224, loss G: 4.4509
Epoch [33/60] Batch 400/1583                   Loss D: 0.1456, loss G: 3.2879
Epoch [33/60] Batch 500/1583                   Loss D: 0.0321, loss G: 5.0528
Epoch [33/60] Batch 600/1583                   Loss D: 0.0660, loss G: 4.3117
Epoch [33/60] Batch 700/1583                   Loss D: 0.0606, loss G: 3.4931
Epoch [33/60] Batch 800/1583                   Loss D: 0.0526, loss G: 3.4857
Epoch [33/60] Batch 900/1583                   Loss D: 0.0888, loss G: 5.8551
Epoch [33/60] Batch 1000/1583                   Loss D: 0.0477, loss G: 3.1534
Epoch [33/60] Batch 1100/1583                   Loss D: 0.1393, loss G: 2.7919
Epoch [33/60] Batch 1200/1583                   Loss D: 0.0702, 

Epoch [39/60] Batch 900/1583                   Loss D: 0.0615, loss G: 4.7838
Epoch [39/60] Batch 1000/1583                   Loss D: 0.0463, loss G: 4.7987
Epoch [39/60] Batch 1100/1583                   Loss D: 0.1040, loss G: 5.8268
Epoch [39/60] Batch 1200/1583                   Loss D: 0.0694, loss G: 4.9897
Epoch [39/60] Batch 1300/1583                   Loss D: 0.0414, loss G: 4.8090
Epoch [39/60] Batch 1400/1583                   Loss D: 0.0638, loss G: 4.9757
Epoch [39/60] Batch 1500/1583                   Loss D: 0.3727, loss G: 1.7110
Epoch [40/60] Batch 0/1583                   Loss D: 0.1674, loss G: 5.3018
Epoch [40/60] Batch 100/1583                   Loss D: 0.1229, loss G: 6.2356
Epoch [40/60] Batch 200/1583                   Loss D: 0.0597, loss G: 4.1685
Epoch [40/60] Batch 300/1583                   Loss D: 0.0450, loss G: 4.8095
Epoch [40/60] Batch 400/1583                   Loss D: 0.1307, loss G: 4.4967
Epoch [40/60] Batch 500/1583                   Loss D: 0.080

Epoch [46/60] Batch 200/1583                   Loss D: 0.1096, loss G: 3.7298
Epoch [46/60] Batch 300/1583                   Loss D: 0.0395, loss G: 5.0388
Epoch [46/60] Batch 400/1583                   Loss D: 0.0312, loss G: 4.6860
Epoch [46/60] Batch 500/1583                   Loss D: 0.0606, loss G: 4.3444
Epoch [46/60] Batch 600/1583                   Loss D: 0.1120, loss G: 1.9985
Epoch [46/60] Batch 700/1583                   Loss D: 0.0256, loss G: 4.7711
Epoch [46/60] Batch 800/1583                   Loss D: 0.0881, loss G: 4.3782
Epoch [46/60] Batch 900/1583                   Loss D: 0.0856, loss G: 5.8023
Epoch [46/60] Batch 1000/1583                   Loss D: 0.0210, loss G: 5.6663
Epoch [46/60] Batch 1100/1583                   Loss D: 0.0500, loss G: 4.0793
Epoch [46/60] Batch 1200/1583                   Loss D: 0.0310, loss G: 4.7894
Epoch [46/60] Batch 1300/1583                   Loss D: 0.0744, loss G: 6.0457
Epoch [46/60] Batch 1400/1583                   Loss D: 0.02

Epoch [52/60] Batch 1100/1583                   Loss D: 0.0174, loss G: 5.2123
Epoch [52/60] Batch 1200/1583                   Loss D: 0.0648, loss G: 4.4279
Epoch [52/60] Batch 1300/1583                   Loss D: 0.0309, loss G: 5.2249
Epoch [52/60] Batch 1400/1583                   Loss D: 0.0518, loss G: 4.2508
Epoch [52/60] Batch 1500/1583                   Loss D: 0.0397, loss G: 5.7068
Epoch [53/60] Batch 0/1583                   Loss D: 0.0239, loss G: 5.1803
Epoch [53/60] Batch 100/1583                   Loss D: 0.0166, loss G: 5.8034
Epoch [53/60] Batch 200/1583                   Loss D: 0.0186, loss G: 5.4656
Epoch [53/60] Batch 300/1583                   Loss D: 0.0482, loss G: 7.4452
Epoch [53/60] Batch 400/1583                   Loss D: 0.0372, loss G: 4.9104
Epoch [53/60] Batch 500/1583                   Loss D: 0.3731, loss G: 2.5949
Epoch [53/60] Batch 600/1583                   Loss D: 0.0427, loss G: 4.8354
Epoch [53/60] Batch 700/1583                   Loss D: 0.0378

Epoch [59/60] Batch 400/1583                   Loss D: 0.0863, loss G: 5.5390
Epoch [59/60] Batch 500/1583                   Loss D: 0.0556, loss G: 4.0625
Epoch [59/60] Batch 600/1583                   Loss D: 0.0435, loss G: 4.9962
Epoch [59/60] Batch 700/1583                   Loss D: 0.0428, loss G: 5.2083
Epoch [59/60] Batch 800/1583                   Loss D: 0.0376, loss G: 4.8994
Epoch [59/60] Batch 900/1583                   Loss D: 0.0134, loss G: 5.5977
Epoch [59/60] Batch 1000/1583                   Loss D: 0.0684, loss G: 4.6603
Epoch [59/60] Batch 1100/1583                   Loss D: 0.3845, loss G: 3.3492
Epoch [59/60] Batch 1200/1583                   Loss D: 1.5435, loss G: 0.1609
Epoch [59/60] Batch 1300/1583                   Loss D: 0.1666, loss G: 8.7913
Epoch [59/60] Batch 1400/1583                   Loss D: 0.0621, loss G: 4.9030
Epoch [59/60] Batch 1500/1583                   Loss D: 0.0749, loss G: 4.5327


In [103]:
# Persist both trained networks as state_dicts (parameters and buffers only).
torch.save(disc.state_dict(), "Discriminator_celeb2")
# The generator file must hold the generator's own weights, not the discriminator's.
torch.save(gen.state_dict(), "Generator_celeb2")