In [1]:
pip install torch torchvision numpy matplotlib scipy




In [6]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from PIL import Image

# Pick the compute device: CUDA when it is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Folder that receives the generated sample grids (an existing folder is reused)
os.makedirs("generated_images", exist_ok=True)

# Hyperparameters
# -- optimisation
lr = 0.0002
batch_size = 64
num_epochs = 10
# -- data
image_size = 64
channels_img = 3  # RGB images
# -- model
z_dim = 100
features_gen = 64
features_disc = 64

# Dataset class for loading Yelp images
class YelpDataset(Dataset):
    """Dataset of the image files found directly inside one folder.

    At construction time every file with an image extension is opened and
    verified; files that fail are reported and left out. Items are returned
    as RGB images, passed through ``transform`` when one is given.

    Args:
        image_folder: directory that contains the image files.
        transform: optional callable applied to each loaded RGB image.
    """

    # Extensions treated as images; matching is case-insensitive so that
    # files such as "photo.JPG" are not silently ignored.
    VALID_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.transform = transform
        self.image_files = []

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping the same on every run.
        for f in sorted(os.listdir(image_folder)):
            if not f.lower().endswith(self.VALID_EXTENSIONS):
                continue
            try:
                img_path = os.path.join(image_folder, f)
                with Image.open(img_path) as img:
                    img.verify()  # Verify image integrity
                self.image_files.append(f)
            except Exception as e:
                # Best effort: report the bad file and carry on with the rest
                print(f"Skipping corrupted image {f}: {str(e)}")

        print(f"Found {len(self.image_files)} valid images")

    def __len__(self):
        """Return the number of files that passed verification."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the image at ``idx``, falling back to the following files on failure.

        Raises:
            IndexError: if the dataset is empty or ``idx`` is out of range.
            RuntimeError: if no file could be loaded after trying every entry once.
        """
        total = len(self.image_files)
        if total == 0:
            raise IndexError("YelpDataset is empty: no valid images were found")

        last_error = None
        # Try each file at most once so a folder of unreadable images cannot
        # trap the caller in an endless retry loop.
        for _ in range(total):
            filename = self.image_files[idx]
            try:
                img_path = os.path.join(self.image_folder, filename)
                # The context manager closes the underlying file; convert()
                # returns an independent, fully loaded image.
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
                if self.transform:
                    image = self.transform(image)
                return image
            except Exception as e:
                last_error = e
                print(f"Error loading image {filename}: {str(e)}")
                idx = (idx + 1) % total  # Move to the next image

        raise RuntimeError(
            f"Could not load any of the {total} images in {self.image_folder}"
        ) from last_error

# Preprocessing pipeline: resize, centre-crop to a square, convert to a tensor,
# then normalise every channel with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * channels_img, std=[0.5] * channels_img),
])

# Build the dataset from the photo folder
dataset = YelpDataset('./a2_photos', transform=transform)

# Shuffled mini-batch iterator over the dataset
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


Using device: cpu
Skipping corrupted image 1MOGQBWogR8oJr1WgERi9g.jpg: cannot identify image file './a2_photos\\1MOGQBWogR8oJr1WgERi9g.jpg'
Skipping corrupted image 5q-sAvIPl0yNeuAbNBPM1g.jpg: cannot identify image file './a2_photos\\5q-sAvIPl0yNeuAbNBPM1g.jpg'
Skipping corrupted image 74upe0h6XxwgzqpdnAh_7Q.jpg: cannot identify image file './a2_photos\\74upe0h6XxwgzqpdnAh_7Q.jpg'
Skipping corrupted image B7xR9CuhRpP52PoehQHVow.jpg: cannot identify image file './a2_photos\\B7xR9CuhRpP52PoehQHVow.jpg'
Skipping corrupted image C6n0nKVbgLbYmxSiQ_bFsg.jpg: cannot identify image file './a2_photos\\C6n0nKVbgLbYmxSiQ_bFsg.jpg'
Skipping corrupted image CA9z96gGA4y9QOes2Y9eGw.jpg: cannot identify image file './a2_photos\\CA9z96gGA4y9QOes2Y9eGw.jpg'
Skipping corrupted image CBxmBYD_5CXIL_F-2PDqmA.jpg: cannot identify image file './a2_photos\\CBxmBYD_5CXIL_F-2PDqmA.jpg'
Skipping corrupted image GPMWGVjuCsa6fadnZsEplw.jpg: cannot identify image file './a2_photos\\GPMWGVjuCsa6fadnZsEplw.jpg'
Skippi

In [14]:
# Number of image files that passed the integrity check when the dataset was built
len(dataset)

52760

In [7]:
import torch
import torch.nn as nn

# Generator class definition
class Generator(nn.Module):
    """Transposed-convolution generator that turns latent tensors into images.

    Args:
        z_dim: number of channels of the latent input.
        channels_img: number of channels of the produced image.
        features_g: base width; the hidden stages use 16x, 8x, 4x and 2x this value.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()

        # Channel widths of the four hidden stages, widest first
        widths = [features_g * mult for mult in (16, 8, 4, 2)]

        # Entry stage: kernel 4, stride 1, no padding (4x4 for a 1x1 latent input)
        stages = [self._block(z_dim, widths[0], 4, 1, 0)]

        # Each following stage doubles the resolution: 8x8, 16x16, 32x32
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(self._block(c_in, c_out, 4, 2, 1))

        # Output stage (64x64); Tanh keeps the values inside [-1, 1]
        stages.append(
            nn.ConvTranspose2d(widths[-1], channels_img, kernel_size=4, stride=2, padding=1)
        )
        stages.append(nn.Tanh())

        self.gen = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one stage: bias-free ConvTranspose2d -> BatchNorm2d -> in-place ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        return nn.Sequential(upsample, norm, activation)

    def forward(self, x):
        """Run ``x`` through the stacked stages and return the result."""
        return self.gen(x)


In [8]:
import torch
import torch.nn as nn

# Discriminator class definition
class Discriminator(nn.Module):
    """Strided-convolution network that reduces an image to one sigmoid score.

    Args:
        channels_img: number of channels of the input image.
        features_d: base width; successive stages use 1x, 2x, 4x and 8x this value.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()

        # Channel widths of the four downsampling stages, narrowest first
        widths = [features_d * mult for mult in (1, 2, 4, 8)]

        # Entry stage has no batch norm (32x32 for a 64x64 input)
        layers = [
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # Each following stage halves the resolution: 16x16, 8x8, 4x4
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(self._block(c_in, c_out, 4, 2, 1))

        # Collapse to a single value per image (1x1) and squash it with a sigmoid
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())

        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one stage: bias-free Conv2d -> BatchNorm2d -> LeakyReLU(0.2)."""
        downsample = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        return nn.Sequential(downsample, norm, activation)

    def forward(self, x):
        """Run ``x`` through the stacked stages and return the score tensor."""
        return self.disc(x)


In [10]:
# Build both networks (generator first) and move each one to the selected device
generator = Generator(z_dim=z_dim, channels_img=channels_img, features_g=features_gen)
generator = generator.to(device)
discriminator = Discriminator(channels_img=channels_img, features_d=features_disc)
discriminator = discriminator.to(device)

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image

# Optimizers for the generator and discriminator
optimizer_g = optim.RMSprop(generator.parameters(), lr=lr)
optimizer_d = optim.RMSprop(discriminator.parameters(), lr=lr)

# Training loop (Wasserstein-style objective with weight clipping)
num_epochs = 10
n_critic = 5  # Number of training steps for the discriminator
clip_value = 0.01  # Discriminator weights are kept inside [-clip_value, clip_value]

# One latent batch reused for every snapshot, so the per-epoch grids show the
# same samples evolving instead of whatever noise the last batch happened to use.
fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

# NOTE(review): the losses below are differences of mean scores, yet the
# Discriminator ends in nn.Sigmoid, which bounds every score to (0, 1).
# Confirm whether a bounded critic is intended.

for epoch in range(num_epochs):
    for real_images in loader:
        real_images = real_images.to(device)
        cur_batch = real_images.size(0)

        # Train the discriminator
        for _ in range(n_critic):
            optimizer_d.zero_grad()

            # Mean score on real images
            d_loss_real = torch.mean(discriminator(real_images))

            # Fake images: the generator is not updated in this step, so its
            # forward pass runs without building an autograd graph.
            noise = torch.randn(cur_batch, z_dim, 1, 1, device=device)
            with torch.no_grad():
                fake_images = generator(noise)
            d_loss_fake = torch.mean(discriminator(fake_images))

            # Total loss: minimising it raises real scores and lowers fake scores
            d_loss = d_loss_fake - d_loss_real
            d_loss.backward()
            optimizer_d.step()

            # Weight clipping is applied AFTER the update, so the weights used
            # in the next forward pass always lie inside the clipping range.
            with torch.no_grad():
                for p in discriminator.parameters():
                    p.clamp_(-clip_value, clip_value)

        # Train the generator
        optimizer_g.zero_grad()

        noise = torch.randn(cur_batch, z_dim, 1, 1, device=device)
        fake_images = generator(noise)
        g_loss = -torch.mean(discriminator(fake_images))

        g_loss.backward()
        optimizer_g.step()

    # Save generated images (the condition is true for every epoch)
    if epoch % 1 == 0:
        with torch.no_grad():
            generated_images = generator(fixed_noise).cpu()
        save_image(generated_images, f"generated_images/generated_epoch_{epoch}.png", nrow=16, normalize=True)

    print(f"Epoch [{epoch}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

KeyboardInterrupt: 

In [None]:
# Rebuild the loader so batches are prepared by 4 worker processes.
# NOTE(review): the logged image paths contain backslashes, so this presumably
# runs on Windows; confirm that worker processes can import a Dataset class that
# is defined inside the notebook, otherwise fall back to num_workers=0.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Training loop (Wasserstein-style objective with weight clipping).
# Reuses optimizer_g / optimizer_d created in the earlier training cell.
num_epochs = 10
n_critic = 5  # Number of training steps for the discriminator
clip_value = 0.01  # Discriminator weights are kept inside [-clip_value, clip_value]

# One latent batch reused for every snapshot, so the per-epoch grids show the
# same samples evolving instead of whatever noise the last batch happened to use.
fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

# NOTE(review): the losses below are differences of mean scores, yet the
# Discriminator ends in nn.Sigmoid, which bounds every score to (0, 1).
# Confirm whether a bounded critic is intended.

for epoch in range(num_epochs):
    for real_images in loader:
        real_images = real_images.to(device)
        cur_batch = real_images.size(0)

        # Train the discriminator
        for _ in range(n_critic):
            optimizer_d.zero_grad()

            # Mean score on real images
            d_loss_real = torch.mean(discriminator(real_images))

            # Fake images: the generator is not updated in this step, so its
            # forward pass runs without building an autograd graph.
            noise = torch.randn(cur_batch, z_dim, 1, 1, device=device)
            with torch.no_grad():
                fake_images = generator(noise)
            d_loss_fake = torch.mean(discriminator(fake_images))

            # Total loss: minimising it raises real scores and lowers fake scores
            d_loss = d_loss_fake - d_loss_real
            d_loss.backward()
            optimizer_d.step()

            # Weight clipping is applied AFTER the update, so the weights used
            # in the next forward pass always lie inside the clipping range.
            with torch.no_grad():
                for p in discriminator.parameters():
                    p.clamp_(-clip_value, clip_value)

        # Train the generator
        optimizer_g.zero_grad()

        noise = torch.randn(cur_batch, z_dim, 1, 1, device=device)
        fake_images = generator(noise)
        g_loss = -torch.mean(discriminator(fake_images))

        g_loss.backward()
        optimizer_g.step()

    # Save generated images periodically (the condition is true for every epoch)
    if epoch % 1 == 0:
        with torch.no_grad():
            generated_images = generator(fixed_noise).cpu()
        save_image(generated_images, f"generated_images/generated_epoch_{epoch}.png", nrow=16, normalize=True)

    print(f"Epoch [{epoch}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")