In [None]:
#import block
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Shared configuration used by the cells below.
# ---------------------------------------------------------------------------

# Compute device. Switch to "cuda" on an NVIDIA GPU; AMD GPUs are not properly
# supported, so stay on the CPU there (training will simply take longer).
device = torch.device("cpu")

# Loss for both networks: binary cross-entropy on real/fake probabilities.
criterion = nn.BCELoss()

# Width of the one-hot class vector: two classes (cat = 0, dog = 1).
label_dim = 2

# Length of the random noise vector fed to the generator.
noise_dim = 100

# Latent size; kept as a separate knob from noise_dim. Change to the desired value.
latent_dim = 100

# Default number of training epochs (the training cell sets its own value).
epochs = 50

# Adam hyper-parameters: learning rate and first-moment decay.
lr = 0.0001
beta1 = 0.5

In [None]:
class RatioDataset(Dataset):
    """Cat/dog image dataset with a configurable number of images per class.

    Collects up to ``num_dogs`` files from ``<image_dir>/Dog`` (label 1) followed
    by up to ``num_cats`` files from ``<image_dir>/Cat`` (label 0), so the class
    ratio of the resulting dataset can be controlled by the caller.

    Args:
        image_dir: Directory containing the ``Dog`` and ``Cat`` sub-folders.
        num_dogs: Maximum number of dog images to include.
        num_cats: Maximum number of cat images to include.
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        OSError: Propagated from ``os.listdir`` when a class folder is missing
            or unreadable.
    """

    def __init__(self, image_dir, num_dogs, num_cats, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Dogs are appended first, then cats.
        for folder, label, limit in (("Dog", 1, num_dogs), ("Cat", 0, num_cats)):
            class_dir = os.path.join(image_dir, folder)
            # os.listdir returns entries in arbitrary order; sorting makes the
            # selected subset identical from run to run.
            file_names = sorted(os.listdir(class_dir))[:limit]
            self.image_paths.extend(os.path.join(class_dir, name) for name in file_names)
            self.labels.extend([label] * len(file_names))

    def __len__(self):
        """Return the total number of images across both classes."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one was supplied; ``label`` is the integer 1 (dog) or 0 (cat).
        """
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


In [None]:
# Preprocessing applied to every image before it reaches the networks:
#   1. resize to a fixed 128x128 resolution,
#   2. convert the PIL image to a tensor,
#   3. shift/scale each RGB channel with mean 0.5 and std 0.5, i.e. to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Root folder of the dataset (expects "Dog" and "Cat" sub-folders).
# Replace with your dataset's path.
image_dir = "archive/PetImages"

# Example: 100 dogs and 10 cats, i.e. a 10:1 class ratio.
dataset = RatioDataset(
    image_dir=image_dir,
    num_dogs=100,
    num_cats=10,
    transform=transform,
)

# Mini-batches of 64 samples, reshuffled every epoch.
data_loader = DataLoader(dataset, shuffle=True, batch_size=64)

In [None]:
class Discriminator(nn.Module):
    """Fully connected conditional-GAN discriminator.

    The image is flattened, concatenated with its class vector and pushed
    through an MLP that ends in a sigmoid, producing one real/fake probability
    per sample.

    Args:
        label_dim: Width of the class vector concatenated to each image.
        img_shape: ``(channels, height, width)`` of the input images. Defaults
            to ``(3, 128, 128)``, the size this notebook trains on.
    """

    def __init__(self, label_dim, img_shape=(3, 128, 128)):
        super().__init__()
        self.img_shape = tuple(img_shape)
        flat_dim = torch.Size(self.img_shape).numel()  # channels * height * width
        self.fc = nn.Sequential(
            nn.Linear(flat_dim + label_dim, 1024),  # Input is the concatenated image and label
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),  # One scalar per sample
            nn.Sigmoid(),  # Squash to a probability in [0, 1]
        )

    def forward(self, img, labels):
        """Score a batch of images.

        Args:
            img: Tensor of shape ``(batch, *img_shape)``.
            labels: Float tensor of shape ``(batch, label_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        # flatten() copies when necessary, so non-contiguous inputs (permuted
        # or channels-last tensors) work; view() would raise on them.
        flat_img = img.flatten(start_dim=1)
        x = torch.cat([flat_img, labels], dim=1)  # Concatenate image and label
        return self.fc(x)

In [None]:
class Generator(nn.Module):
    """Fully connected conditional-GAN generator.

    A noise vector concatenated with a class vector is mapped through an MLP
    ending in ``tanh`` and reshaped into an image with values in ``[-1, 1]``.

    Args:
        noise_dim: Length of the noise vector.
        label_dim: Width of the class vector concatenated to the noise.
        img_shape: ``(channels, height, width)`` of the generated images.
            Defaults to ``(3, 128, 128)``, the size this notebook trains on.
    """

    def __init__(self, noise_dim, label_dim, img_shape=(3, 128, 128)):
        super().__init__()
        self.img_shape = tuple(img_shape)
        flat_dim = torch.Size(self.img_shape).numel()  # channels * height * width
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + label_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, flat_dim),  # One output unit per pixel value
            nn.Tanh(),  # Bound the output to [-1, 1]
        )

    def forward(self, z, labels):
        """Generate a batch of images.

        Args:
            z: Noise tensor of shape ``(batch, noise_dim)``.
            labels: Float tensor of shape ``(batch, label_dim)``; the caller is
                expected to pass already one-hot-encoded labels.

        Returns:
            Tensor of shape ``(batch, *img_shape)`` with values in ``[-1, 1]``.
        """
        x = torch.cat([z, labels], dim=1)  # Concatenate noise vector with label vector
        # Use the explicit batch size instead of -1 so an empty batch reshapes
        # cleanly (-1 is ambiguous for a zero-element tensor).
        return self.fc(x).view(x.size(0), *self.img_shape)

In [None]:
import matplotlib.pyplot as plt
def generate_images(generator, label, num_images=6):
    """Sample images of one class from ``generator`` and plot them in a grid.

    Uses the module-level ``noise_dim``, ``label_dim`` and ``device`` settings.
    The generator is switched to eval mode for sampling and put back into
    whatever mode it was in afterwards, so calling this in the middle of
    training does not leave the model in eval mode.

    Args:
        generator: Model called as ``generator(noise, one_hot_labels)`` that
            returns a batch of ``(C, H, W)`` images with values in ``[-1, 1]``.
        label: Integer class index to condition on (0 = cat, 1 = dog here).
        num_images: Number of images to sample and show.
    """
    # Generate noise and labels
    noise = torch.randn(num_images, noise_dim, device=device)
    class_ids = torch.full((num_images,), label, dtype=torch.long, device=device)
    one_hot_label = nn.functional.one_hot(class_ids, num_classes=label_dim).float()

    was_training = generator.training
    generator.eval()
    try:
        # No autograd graph is needed for pure sampling.
        with torch.no_grad():
            fake_imgs = generator(noise, one_hot_label).cpu()
    finally:
        generator.train(was_training)

    # Three columns and at least two rows (the layout used for up to six
    # images); extra rows are added so any num_images fits on the grid.
    grid_cols = 3
    grid_rows = max(2, -(-num_images // grid_cols))  # ceil division

    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(10, 3.5 * grid_rows))
    axes = axes.flatten()  # Flatten the grid to make indexing easier

    # Hide the frame on every cell, used or not.
    for ax in axes:
        ax.axis("off")

    # Plot each image on the grid
    for ax, fake_img in zip(axes, fake_imgs):
        img = (fake_img + 1) / 2  # Rescale to [0, 1]
        ax.imshow(img.permute(1, 2, 0).numpy())  # (H, W, C) for plotting

    # Show the grid, then release the figure so repeated calls do not pile up
    # open figures in memory.
    plt.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:
# Number of epochs for this run (overrides the default from the config cell).
epochs = 150  # Adjust as needed

# Build both networks on the configured device so the cell also works when
# `device` is switched to "cuda"; the batches below are moved there as well.
generator = Generator(noise_dim, label_dim).to(device)
discriminator = Discriminator(label_dim).to(device)
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training loop
for epoch in range(epochs):
    # Make sure both models are in training mode, even after the preview
    # call at the end of the previous epoch.
    generator.train()
    discriminator.train()

    for real_imgs, labels in data_loader:
        batch_size = real_imgs.size(0)
        real_imgs, labels = real_imgs.to(device), labels.to(device)

        # One-hot encode labels (already on `device` because `labels` is)
        one_hot_labels = nn.functional.one_hot(labels, num_classes=label_dim).float()

        # Train Discriminator: real images should score 1, generated ones 0.
        noise = torch.randn(batch_size, noise_dim, device=device)
        fake_imgs = generator(noise, one_hot_labels)
        real_preds = discriminator(real_imgs, one_hot_labels)
        # detach() keeps the discriminator update from back-propagating into the generator.
        fake_preds = discriminator(fake_imgs.detach(), one_hot_labels)
        loss_d = criterion(real_preds, torch.ones_like(real_preds)) + \
                 criterion(fake_preds, torch.zeros_like(fake_preds))

        optimizer_d.zero_grad()
        loss_d.backward()
        optimizer_d.step()

        # Train Generator: it wants the updated discriminator to score its images as real.
        fake_preds = discriminator(fake_imgs, one_hot_labels)
        loss_g = criterion(fake_preds, torch.ones_like(fake_preds))

        optimizer_g.zero_grad()
        loss_g.backward()
        optimizer_g.step()

    # Losses reported are those of the last batch of the epoch.
    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")
    # Preview two cat samples (label 0) every 10th epoch.
    if epoch % 10 == 0:
        generate_images(generator, label=0, num_images=2)

In [None]:
# Show a grid of generated samples for each class: dogs (label=1) first,
# then cats (label=0), six images per class.
num_images = 6
for heading, class_label in (("Dog Pictures", 1), ("Cat Pictures", 0)):
    print(heading)
    generate_images(generator, label=class_label, num_images=num_images)