In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load and preprocess images
class CustomDataset(Dataset):
    def __init__(self, folder_path, img_size=(64, 64)):
        self.folder_path = folder_path
        self.img_size = img_size
        self.image_files = [f for f in os.listdir(folder_path) if f.endswith(('png', 'jpg', 'jpeg'))]
        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Corrected for RGB
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.folder_path, self.image_files[idx])
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        return img

# Dataset and DataLoader
folder_path = r"dataset-path"
dataset = CustomDataset(folder_path)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, pin_memory=True)

# Define Generator with deeper architecture
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 512 * 4 * 4),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (512, 4, 4)),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

# Define Discriminator with deeper architecture
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)

# Weight initialization function
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Initialize models
latent_dim = 100
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Apply weight initialization
generator.apply(weights_init)
discriminator.apply(weights_init)

# Loss and Optimizers
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training the GAN with label smoothing
def train_gan(epochs):  # Removed unused batch_size parameter
    for epoch in range(epochs):
        for real_imgs in dataloader:
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.size(0)
            real_labels = torch.full((batch_size, 1), 0.9, device=device)  # Label smoothing
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Train Discriminator
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_imgs = generator(z)

            d_loss_real = criterion(discriminator(real_imgs), real_labels)
            d_loss_fake = criterion(discriminator(fake_imgs.detach()), fake_labels)
            d_loss = (d_loss_real + d_loss_fake) / 2

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # Train Generator
            g_loss = criterion(discriminator(fake_imgs), real_labels)
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        print(f"Epoch [{epoch+1}/{epochs}] - D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")

# Train the GAN
train_gan(epochs=100)

Epoch [1/100] - D loss: 0.4655, G loss: 1.6617
Epoch [2/100] - D loss: 0.4990, G loss: 5.0388
Epoch [3/100] - D loss: 0.5301, G loss: 1.0760
Epoch [4/100] - D loss: 0.3088, G loss: 2.6131
Epoch [5/100] - D loss: 0.2583, G loss: 3.5286
Epoch [6/100] - D loss: 0.2782, G loss: 2.5647
Epoch [7/100] - D loss: 0.2781, G loss: 2.9377
Epoch [8/100] - D loss: 0.2065, G loss: 3.7616
Epoch [9/100] - D loss: 0.4356, G loss: 1.3402
Epoch [10/100] - D loss: 0.2165, G loss: 3.7045
Epoch [11/100] - D loss: 0.2713, G loss: 3.8153
Epoch [12/100] - D loss: 0.2571, G loss: 5.5170
Epoch [13/100] - D loss: 0.1951, G loss: 4.7306
Epoch [14/100] - D loss: 0.3697, G loss: 5.2896
Epoch [15/100] - D loss: 0.2202, G loss: 3.5896
Epoch [16/100] - D loss: 0.3247, G loss: 6.3423
Epoch [17/100] - D loss: 0.2114, G loss: 4.2446
Epoch [18/100] - D loss: 0.3541, G loss: 4.8522
Epoch [19/100] - D loss: 0.2000, G loss: 4.7820
Epoch [20/100] - D loss: 0.2323, G loss: 2.7788
Epoch [21/100] - D loss: 0.2276, G loss: 3.7492
E

In [2]:
# Save the trained models
torch.save(generator.state_dict(), "generator1.pth")
torch.save(discriminator.state_dict(), "discriminator1.pth")
print("Models saved successfully.")

Models saved successfully.


In [5]:
import torchvision.utils as vutils
# Function to generate and display synthetic images
def visualize_generated_images(generator, latent_dim, examples=10):
    generator.eval()
    noise = torch.randn(examples, latent_dim, device=device)
    generated_images = generator(noise).cpu().detach()
    generated_images = (generated_images + 1) / 2.0  # Rescale to [0,1]
    vutils.save_image(generated_images, "generated.png", nrow=10)
    print("Generated images saved as generated.png")

In [6]:
# After training, visualize the outputs
visualize_generated_images(generator, latent_dim)

Generated images saved as generated.png
