# Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder


# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Using device:", device)

# Download data


In [None]:
# Placeholder: pass the Google Drive file id of the dataset archive to gdown.
# The unzip cell that follows expects the download at /content/images.zip.
!gdown #google drive file id

In [None]:
# Extract the archive into the parent of the current working directory.
# NOTE(review): data_folder is set to /content/images further down, so this
# presumably relies on the notebook's working directory and on the archive's
# internal folder layout — confirm the images really end up there.
!unzip "/content/images.zip" -d "../"

In [None]:
#!rm -rf /content/images

In [None]:
# Root directory handed to ImageFolder below (ImageFolder reads images from
# class sub-directories underneath this path).
data_folder = "/content/images"

# Prepare data


In [None]:
# Side length (pixels) of the square images used throughout this notebook: the
# generator emits 64x64 images and the discriminator only collapses a 64x64
# input to a single score per image.
image_size = 64

transform = transforms.Compose([
  # Force every image to image_size x image_size so batches can be stacked and
  # the discriminator output lines up with the label vector. This is a no-op
  # for data that is already 64x64.
  transforms.Resize(image_size),
  transforms.CenterCrop(image_size),
  transforms.ToTensor(),
  # Map each of the three colour channels from [0, 1] to [-1, 1], which matches
  # the range of the generator's final Tanh.
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [None]:
# Build the image dataset from data_folder and wrap it in a shuffling loader.
dataset = ImageFolder(data_folder, transform=transform)
dataloader = DataLoader(
  dataset,
  batch_size=128,
  shuffle=True,
  num_workers=2,
)

## Check data

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
  """Display a normalized image tensor with matplotlib.

  Args:
    img: Tensor laid out as (channels, height, width) whose values were
      normalized with mean 0.5 and std 0.5. It may live on any device and
      may be tracking gradients.
  """
  # Undo the (x - 0.5) / 0.5 normalization: x / 2 + 0.5 maps [-1, 1] to [0, 1].
  img = img / 2 + 0.5
  # detach()/cpu() make the NumPy conversion work for CUDA tensors and for
  # tensors that require grad; both are no-ops for a plain CPU tensor.
  npimg = img.detach().cpu().numpy()
  # matplotlib expects channels last: (H, W, C).
  plt.imshow(np.transpose(npimg, (1, 2, 0)))
  plt.axis("off")
  plt.show()

# Grab one shuffled batch and preview its first 16 images as a 4x4 grid.
dataiter = iter(dataloader)
images, _ = next(dataiter)
preview_grid = torchvision.utils.make_grid(images[:16], nrow=4)
imshow(preview_grid)

# DCGAN setup

In [None]:
import torch.nn as nn

In [None]:
def weights_init(m):
  """Re-initialise a layer in place, selecting by its class name.

  Layers whose class name contains 'Conv' get weights drawn from N(0, 0.02).
  Layers whose class name contains 'BatchNorm' get weights drawn from
  N(1, 0.02) and a zero bias. Every other module is left untouched.
  Intended to be passed to ``Module.apply``.
  """
  layer_kind = type(m).__name__
  if 'Conv' in layer_kind:
    nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    return
  if 'BatchNorm' in layer_kind:
    nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
    nn.init.constant_(m.bias.data, val=0)

## Generator

In [None]:
class Generator(nn.Module):
  """DCGAN generator: maps a (N, nz, 1, 1) latent batch to (N, nc, 64, 64) images.

  Args:
    nz: Number of channels of the latent input.
    ngf: Base width; the hidden layers use ngf*8, ngf*4, ngf*2 and ngf channels.
    nc: Number of channels of the generated image.
  """

  def __init__(self, nz, ngf, nc):
    super().__init__()
    # Channel counts from the latent input through the last hidden layer.
    widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
    layers = []
    for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
      # The first block expands 1x1 -> 4x4; each later block doubles the size.
      stride, pad = (1, 0) if idx == 0 else (2, 1)
      layers += [
        nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride, padding=pad, bias=False),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
      ]
    # Output block: final upsampling to nc channels, squashed into [-1, 1].
    layers += [
      nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False),
      nn.Tanh(),
    ]
    self.main = nn.Sequential(*layers)

  def forward(self, input):
    """Run the latent batch through the layer stack and return the images."""
    return self.main(input)

## Discriminator

In [None]:
class Discriminator(nn.Module):
  """DCGAN discriminator: scores (N, nc, 64, 64) images with a value in [0, 1].

  Args:
    nc: Number of channels of the input image.
    ndf: Base width; the hidden layers use ndf, ndf*2, ndf*4 and ndf*8 channels.
  """

  def __init__(self, nc, ndf):
    super().__init__()
    # Input block: no batch norm on the first layer.
    layers = [
      nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
    ]
    # Three downsampling blocks, each halving the size and doubling the width.
    width = ndf
    for _ in range(3):
      layers += [
        nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(width * 2),
        nn.LeakyReLU(0.2, inplace=True),
      ]
      width *= 2
    # Output block: collapse the remaining 4x4 map to a single probability.
    layers += [
      nn.Conv2d(width, 1, kernel_size=4, stride=1, padding=0, bias=False),
      nn.Sigmoid(),
    ]
    self.main = nn.Sequential(*layers)

  def forward(self, input):
    """Run the image batch through the layer stack and return the scores."""
    return self.main(input)

## Hyperparameters and model initialization

In [None]:
# Model dimensions.
nz = 100   # length of the latent vector fed to the generator
ngf = 64   # base number of generator feature maps
ndf = 64   # base number of discriminator feature maps
nc = 3     # number of image channels

# Each network is moved to the target device first and re-initialised there.
netG = Generator(nz, ngf, nc)
netG = netG.to(device)
netG.apply(weights_init)

netD = Discriminator(nc, ndf)
netD = netD.to(device)
netD.apply(weights_init)

# Training

## Training setup

In [None]:
# Binary cross-entropy on probabilities; the discriminator already ends in a
# Sigmoid, so its outputs can be fed to this loss directly.
loss_fn = nn.BCELoss()

In [None]:
# Optimisation settings.
lr = 0.0002
beta1 = 0.5
num_epochs = 30

# A fixed latent batch, reused after every epoch so samples stay comparable.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Each network gets its own Adam optimizer with identical settings.
adam_betas = (beta1, 0.999)
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

## Training loop

In [None]:
# Alternating GAN training: for every batch the discriminator is updated first,
# then the generator. After each epoch a sample grid is shown, and every fifth
# epoch both networks are checkpointed.
for epoch in range(num_epochs):
  for i, data in enumerate(dataloader):
    # ---- Discriminator update ----

    netD.zero_grad()

    # Real batch: target label 1 for every image.
    real_images = data[0].to(device)
    # Use the actual batch size; the last batch of an epoch may be smaller
    # than the loader's batch_size.
    batch_size = real_images.size(0)
    real_labels = torch.ones(batch_size, device=device)
    # Flatten the discriminator output so it lines up with the 1-D label vector.
    output = netD(real_images).view(-1)
    lossD_real = loss_fn(output, real_labels)

    # Fake batch: target label 0 for every generated image.
    noise = torch.randn(batch_size, nz, 1, 1, device=device)
    fake_images = netG(noise)
    fake_labels = torch.zeros(batch_size, device=device)
    # detach() keeps this backward pass from reaching the generator's parameters.
    output = netD(fake_images.detach()).view(-1)
    lossD_fake = loss_fn(output, fake_labels)

    # Discriminator loss is the sum of the real and fake terms.
    lossD = lossD_real + lossD_fake
    lossD.backward()
    optimizerD.step()

    # ---- Generator update ----

    netG.zero_grad()
    # Score the same fakes again, this time without detach() and with the
    # just-updated discriminator, so gradients flow back into the generator.
    output = netD(fake_images).view(-1)
    # The generator's target is label 1: it is rewarded when its fakes are
    # scored as real.
    lossG = loss_fn(output, real_labels)
    # This backward pass also accumulates gradients in netD; they are cleared
    # by netD.zero_grad() at the start of the next iteration.
    lossG.backward()
    optimizerG.step()

    # Log both losses every 10 batches.
    if i % 10 == 0:
      print(f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}] "
        f"Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}")

  # End of epoch: render the fixed latent batch to track progress visually.
  # No eval() call is made here, so netG runs in whatever mode it is currently in.
  with torch.no_grad():
    fake = netG(fixed_noise).detach().cpu()
  # normalize=True asks make_grid to rescale the values for display.
  img_grid = vutils.make_grid(fake, padding=2, normalize=True)
  plt.figure(figsize=(8,8))
  plt.axis("off")
  # Reorder (C, H, W) -> (H, W, C) for matplotlib.
  plt.imshow(np.transpose(img_grid, (1, 2, 0)))
  plt.show()

  # Checkpoint both networks every 5 epochs (state_dict only).
  if (epoch + 1) % 5 == 0:
    torch.save(netG.state_dict(), f"generator_epoch_{epoch+1}.pth")
    torch.save(netD.state_dict(), f"discriminator_epoch_{epoch+1}.pth")

In [None]:
# Persist the final weights of both networks (state_dict only, not the modules).
# Plain string literals: the file names contain no placeholders, so no f-string
# prefix is needed.
torch.save(netG.state_dict(), "generator.pth")
torch.save(netD.state_dict(), "discriminator.pth")

# Tests

In [None]:
# Rebuild the generator, restore the weights from generator.pth and switch to
# inference mode.
generator_state = torch.load("generator.pth", map_location=device)
netG = Generator(nz, ngf, nc).to(device)
netG.load_state_dict(generator_state)
netG.eval()

## Real vs Fake comparison

In [None]:
# Take up to 64 real images from a fresh shuffled batch for the comparison.
dataiter = iter(dataloader)
real_batch, _ = next(dataiter)
real_images = real_batch[:64].to(device)

In [None]:
# Sample 64 latent vectors and generate the matching fake images; the result is
# moved to the CPU for plotting.
noise = torch.randn((64, nz, 1, 1), device=device)
with torch.no_grad():
    fake_images = netG(noise).cpu().detach()

In [None]:
# Undo the [-1, 1] normalization so both batches are in [0, 1] for plotting.
# Note: fake_images is rebound to its own rescaled value, so re-running this
# cell applies the rescale a second time.
real_images_cpu = real_images.cpu().mul(0.5).add(0.5)
fake_images = fake_images.mul(0.5).add(0.5)

In [None]:
# Tile each batch into an 8-column image grid with 2 pixels of padding.
grid_real = vutils.make_grid(real_images_cpu, nrow=8, padding=2)
grid_fake = vutils.make_grid(fake_images, nrow=8, padding=2)

In [None]:
# Show the real and generated grids side by side.
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

panels = (
    (grid_real, "Real Images"),
    (grid_fake, "Fake Images"),
)
for ax, (grid, title) in zip(axes, panels):
    # (C, H, W) -> (H, W, C) for matplotlib.
    ax.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    ax.axis('off')
    ax.set_title(title)

plt.show()

## Latent cycle gif

In [None]:
import imageio

In [None]:
# Two independent Gaussian latent vectors that span the interpolation plane.
z1 = torch.randn(nz, device=device)
z2 = torch.randn(nz, device=device)
# Optional orthogonalisation step, currently disabled:
#z2 = z2 - (z2 @ z1_unit) * z1_unit

# NOTE: despite the "_unit" suffix these are plain aliases — nothing is normalized.
z1_unit, z2_unit = z1, z2

In [None]:
num_frames = 60
frames = []
# Walk one full circle in the plane spanned by the two latent vectors and
# render one generator output per step.
for i in range(num_frames):
    theta = 2 * np.pi * i / num_frames

    latent = z1_unit * np.cos(theta) + z2_unit * np.sin(theta)
    # Shape the 1-D vector as a (1, nz, 1, 1) batch for the generator.
    z_interp = latent.reshape(1, -1, 1, 1)
    with torch.no_grad():
        fake_img = netG(z_interp).cpu()[0]

    # [-1, 1] -> [0, 1], channels last, then 8-bit pixels for the GIF writer.
    img = fake_img.mul(0.5).add(0.5).permute(1, 2, 0).numpy()
    img_uint8 = (img * 255).astype(np.uint8)
    frames.append(img_uint8)

In [None]:
# Write the collected frames to an animated GIF.
gif_path = "latent_circle.gif"
imageio.mimsave(
    gif_path,
    frames,
    fps=10,
    loop=0,
)

In [None]:
from IPython.display import display, Image

# Render the saved animation inline in the notebook.
gif_preview = Image(filename="latent_circle.gif")
display(gif_preview)