# Import Necessary Libraries

In [7]:
import numpy as np 
import torch.nn as nn
import torch
from torch.utils.data.dataloader import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as T
from PIL import Image

# Prepare the Dataset

In [8]:
# PIL image -> float tensor in [0, 1] (ToTensor), then per-channel (x - 0.5) / 0.5
# so pixels end up in [-1, 1], the same range the generator's Tanh produces.
transforms = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

class DressDataset(torch.utils.data.Dataset):
    """View of a labelled dataset that exposes only the samples of one class.

    By default the kept class is label 3, which is "Dress" in FashionMNIST.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        target_label: Label value to keep (default ``3``).

    Items are returned exactly as the wrapped dataset returns them.
    """

    def __init__(self, dataset, target_label=3):
        self.dataset = dataset
        self.target_label = target_label
        self.indices = self._matching_indices(dataset, target_label)

    @staticmethod
    def _matching_indices(dataset, target_label):
        """Return the positions in ``dataset`` whose label equals ``target_label``."""
        targets = getattr(dataset, "targets", None)
        # Fast path: torchvision datasets such as FashionMNIST expose the raw labels
        # as ``.targets``. Reading them avoids decoding and transforming every image.
        # This is only valid when no target_transform would change the labels.
        if targets is not None and getattr(dataset, "target_transform", None) is None:
            labels = torch.as_tensor(targets).tolist()
            return [i for i, label in enumerate(labels) if label == target_label]
        # Generic fallback: load every sample to read its label.
        return [i for i, (_, label) in enumerate(dataset) if label == target_label]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        real_idx = self.indices[idx]
        image, label = self.dataset[real_idx]
        return image, label

# Fetch (if missing) the FashionMNIST training split into the working directory,
# keep only the dress images, and serve them in shuffled mini-batches.
batch_size = 32

data_mnist = datasets.FashionMNIST(
    root='.',
    train=True,
    download=True,
    transform=transforms,
)
dress_dataset = DressDataset(data_mnist)
train_loader = DataLoader(dress_dataset, batch_size=batch_size, shuffle=True)

# Generator Model

In [9]:
# Generator model
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    Args:
        latent_dim: Size of the noise vector passed to ``forward`` (default 100).
        num_classes: Number of distinct labels the embedding accepts (default 10).
        embed_dim: Size of the learned label embedding (default 10).
        img_shape: ``(channels, height, width)`` of the produced images
            (default ``(1, 28, 28)``).

    The defaults reproduce the original hard-coded 28x28 single-channel setup.
    The submodule names (``label_emb``, ``gen_model``) are unchanged, so previously
    saved ``state_dict`` checkpoints still load.
    """

    def __init__(self, latent_dim=100, num_classes=10, embed_dim=10, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = tuple(img_shape)
        out_features = torch.Size(self.img_shape).numel()
        self.label_emb = nn.Embedding(num_classes, embed_dim)
        self.gen_model = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, out_features),
            # Tanh bounds pixels to [-1, 1], matching the Normalize((0.5,), (0.5,)) input scaling.
            nn.Tanh()
        )

    def forward(self, x, labels):
        """Generate one image per (noise, label) pair.

        Args:
            x: Noise tensor of shape ``(batch, latent_dim)``.
            labels: Integer class labels of shape ``(batch,)``.

        Returns:
            Images of shape ``(batch, *img_shape)`` with values in ``[-1, 1]``.
        """
        batch = x.size(0)
        label_input = self.label_emb(labels)
        features = torch.cat((x, label_input), dim=1)
        out = self.gen_model(features)
        return out.view(batch, *self.img_shape)

# Discriminator Model

In [10]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator mapping (image, class label) to P(image is real).

    Args:
        num_classes: Number of distinct labels the embedding accepts (default 10).
        embed_dim: Size of the learned label embedding (default 10).
        img_shape: ``(channels, height, width)`` of the input images
            (default ``(1, 28, 28)``).
        dropout: Dropout probability after each hidden layer (default 0.3).

    The defaults reproduce the original hard-coded 28x28 single-channel setup.
    The submodule names (``label_emb``, ``flat``, ``dis_model``) are unchanged, so
    previously saved ``state_dict`` checkpoints still load.
    """

    def __init__(self, num_classes=10, embed_dim=10, img_shape=(1, 28, 28), dropout=0.3):
        super().__init__()
        in_pixels = torch.Size(tuple(img_shape)).numel()
        self.label_emb = nn.Embedding(num_classes, embed_dim)
        self.flat = nn.Flatten()
        self.dis_model = nn.Sequential(
            nn.Linear(in_pixels + embed_dim, 1024),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
            # Probability output, consumed by nn.BCELoss during training.
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        """Score a batch of images.

        Args:
            x: Batch of images. Everything after the batch dimension is flattened,
                so each sample must hold ``prod(img_shape)`` values.
            labels: Integer class labels of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        label_input = self.label_emb(labels)
        x = self.flat(x)
        x = torch.cat((x, label_input), dim=1)
        return self.dis_model(x)

# Initialize models and device
# Seed torch's global RNG before the networks are built. This makes weight
# initialisation, the training noise and DataLoader shuffling (which uses torch's
# default RNG when no generator is passed) repeatable across Restart & Run All.
# Some CUDA kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Set Training Hyperparameters

In [11]:
# Optimisation settings shared by both networks.
lr = 2e-4      # Adam step size, identical for generator and discriminator
epochs = 100   # full passes over the training loader

# The discriminator ends in a Sigmoid, so BCE is applied directly to probabilities.
loss_function = nn.BCELoss()

# One optimizer per network, so each step only updates that network's weights.
# NOTE(review): GAN recipes such as DCGAN commonly use betas=(0.5, 0.999);
# the Adam defaults are kept here.
optim_dis = torch.optim.Adam(params=discriminator.parameters(), lr=lr)
optim_gen = torch.optim.Adam(params=generator.parameters(), lr=lr)

# Training

In [13]:
# Training loop. For every batch:
#   1. update the discriminator on real images (target 1),
#   2. update it again on generated images (target 0),
#   3. update the generator so the discriminator scores its images as real.
# Put both networks in training mode explicitly. The sampling cell calls
# generator.eval(), so re-running this cell afterwards would otherwise train in eval mode.
generator.train()
discriminator.train()

for epoch in range(epochs):
    for n, (input_data, labels) in enumerate(train_loader):
        input_data = input_data.to(device)
        labels = labels.to(device)
        n_batch = input_data.size(0)

        # BCE targets: ones for real images, zeros for generated ones.
        real_labels = torch.ones((n_batch, 1), device=device)
        fake_labels = torch.zeros((n_batch, 1), device=device)

        # Train discriminator with real data
        optim_dis.zero_grad()
        discriminator_output_real = discriminator(input_data, labels)
        loss_discriminator_real = loss_function(discriminator_output_real, real_labels)
        loss_discriminator_real.backward()
        optim_dis.step()

        # Train discriminator with generated data.
        # The generator is conditioned on the same classes as the real batch instead of
        # a hard-coded label. With DressDataset every label is already the dress class,
        # and the loop now follows whichever class the dataset was filtered to.
        # The noise width (100) must equal the generator's latent size.
        noise = torch.randn((n_batch, 100), device=device)
        gen_labels = labels
        generated_data = generator(noise, gen_labels)

        optim_dis.zero_grad()
        # detach(): the discriminator update must not send gradients into the generator.
        discriminator_output_fake = discriminator(generated_data.detach(), gen_labels)
        loss_discriminator_fake = loss_function(discriminator_output_fake, fake_labels)
        loss_discriminator_fake.backward()
        optim_dis.step()

        # Train generator. Its weights have not changed since `generated_data` was
        # produced, so that tensor (and its autograd graph) is reused instead of
        # running a second identical forward pass.
        optim_gen.zero_grad()
        discriminator_output_gen = discriminator(generated_data, gen_labels)
        loss_generator = loss_function(discriminator_output_gen, real_labels)
        loss_generator.backward()
        optim_gen.step()

        # Report the losses of the last batch of each epoch.
        if n == len(train_loader) - 1:
            print(f'Epoch [{epoch+1}/{epochs}], ' +
                  f'Discriminator Loss Real: {loss_discriminator_real.item():.4f}, ' +
                  f'Discriminator Loss Fake: {loss_discriminator_fake.item():.4f}, ' +
                  f'Generator Loss: {loss_generator.item():.4f}')


Epoch [1/100], Discriminator Loss Real: 0.0381, Discriminator Loss Fake: 0.1619, Generator Loss: 8.1867
Epoch [2/100], Discriminator Loss Real: 0.0909, Discriminator Loss Fake: 0.0356, Generator Loss: 6.0008
Epoch [3/100], Discriminator Loss Real: 0.0158, Discriminator Loss Fake: 0.0027, Generator Loss: 6.7155
Epoch [4/100], Discriminator Loss Real: 0.0095, Discriminator Loss Fake: 0.0182, Generator Loss: 5.1138
Epoch [5/100], Discriminator Loss Real: 0.0394, Discriminator Loss Fake: 0.0098, Generator Loss: 4.9831
Epoch [6/100], Discriminator Loss Real: 0.0660, Discriminator Loss Fake: 0.0052, Generator Loss: 6.8698
Epoch [7/100], Discriminator Loss Real: 0.0022, Discriminator Loss Fake: 0.0384, Generator Loss: 5.1510
Epoch [8/100], Discriminator Loss Real: 0.0267, Discriminator Loss Fake: 0.0538, Generator Loss: 3.9255
Epoch [9/100], Discriminator Loss Real: 0.1496, Discriminator Loss Fake: 0.0646, Generator Loss: 3.8339
Epoch [10/100], Discriminator Loss Real: 0.2797, Discriminator L

# Save the model

In [14]:
# Persist only the learned weights (state_dict) of each network, not the module objects.
for net, checkpoint_path in ((generator, 'generator.pth'), (discriminator, 'discriminator.pth')):
    torch.save(net.state_dict(), checkpoint_path)

# Generate a Sample Image

In [16]:
# Sample one dress from the trained generator and write it to disk as a PNG.
generator.eval()
num_samples = 1  # separate name so the training `batch_size` is not clobbered
noise = torch.randn(num_samples, 100, device=device)
gen_labels = torch.full((num_samples,), 3, dtype=torch.long, device=device)  # 3 = Dress

# Inference only: don't record an autograd graph.
with torch.no_grad():
    # Generator output is (batch, 1, 28, 28); take the first image's only channel.
    sample = generator(noise, gen_labels)[0, 0].cpu().numpy()

# Undo Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 1] -> [0, 255]. Round rather than
# truncate, and clip so float error can never wrap around in the uint8 cast.
image_array = np.clip(np.rint((sample * 0.5 + 0.5) * 255), 0, 255).astype(np.uint8)
# A 2-D uint8 array is stored as an 8-bit grayscale ('L') image. The explicit
# `mode=` argument of Image.fromarray is deprecated in recent Pillow releases.
image = Image.fromarray(image_array)
image.save('generated_image.png')