In [None]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

In [None]:
# Kaggle input directory holding the raw training images (file names start with MEN/WOMEN).
dataroot = os.path.join('/kaggle/input', 'deepfashion-1', 'datasets', 'train_images')


In [None]:
import os

# List the training directory once, sorted by file name so downstream cells see a stable order.
image_files = sorted(os.listdir(dataroot))

print(image_files[-5:])  # peek at the last few names


In [None]:
import cv2

# Preview a few raw training images. The listing is sorted so the same files are shown on
# every run (os.listdir returns entries in arbitrary order).
image_list = sorted(os.listdir(dataroot))

num_images = min(5, len(image_list))  # don't index past the end of a small directory

plt.figure(figsize=(15, 5))

for i in range(num_images):
    img_path = os.path.join(dataroot, image_list[i])  # Construct full image path
    img = cv2.imread(img_path)  # BGR array, or None (no exception) if the file can't be decoded
    if img is None:
        raise ValueError(f"cv2 could not read image file: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB for proper color display
    plt.subplot(1, num_images, i + 1)  # one column per image
    plt.imshow(img)
    plt.axis('off')
    plt.title(image_list[i])  # file name as the title

plt.tight_layout()
plt.show()

In [None]:
# Split file names by their prefix ('MEN...' / 'WOMEN...'); files matching neither are skipped.
# The two prefixes can never both match, so two independent filters equal the if/elif form.
dataroot_men = [name for name in image_files if name.startswith('MEN')]
dataroot_women = [name for name in image_files if name.startswith('WOMEN')]

print("MEN images:", dataroot_men[:5])
print("WOMEN images:", dataroot_women[:5])


In [None]:
# Sanity check: the MEN and WOMEN subsets together should cover every listed file.
x, y = len(dataroot_men), len(dataroot_women)
for count in (x + y, len(image_files)):
    print(count)

In [None]:
# Per-class image counts: MEN first, then WOMEN, one per line.
print(x, y, sep='\n')

In [None]:
import shutil

# Re-arrange the flat input folder into one sub-folder per class, the layout that
# torchvision's ImageFolder (used below) expects.
split_root = '/kaggle/working/data_split'
class_files = {'MEN': dataroot_men, 'WOMEN': dataroot_women}

for class_name in class_files:
    os.makedirs(os.path.join(split_root, class_name), exist_ok=True)

for class_name, names in class_files.items():
    for file_name in names:
        # Copy rather than move, leaving the original dataset folder untouched.
        shutil.copy(os.path.join(dataroot, file_name),
                    os.path.join(split_root, class_name, file_name))


In [None]:
# Data loading: DataLoader worker processes, images per batch, square image side length.
workers, batch_size, image_size = 4, 128, 64
# Model shape: image channels (RGB), latent vector length, generator / discriminator base widths.
nc, nz, ngf, ndf = 3, 100, 64, 64
# Optimisation: epochs, base learning rate, Adam beta1.
# NOTE(review): the first training cell overrides num_epochs with 35.
num_epochs, lr, beta1 = 5, 0.0002, 0.5
# GPUs to use; 0 forces CPU in the device selection below.
ngpu = 1

In [None]:
import torchvision.datasets as dset
import torchvision.transforms as transforms

# Root folder with one sub-directory per class (MEN/, WOMEN/), as ImageFolder expects.
data_split_root = '/kaggle/working/data_split'

# Resize the short side to image_size, center-crop to a square, then map pixels from
# [0, 1] to [-1, 1] (the generator's Tanh range). Mean/std are given once per channel
# instead of relying on a 1-element tuple being broadcast across the RGB channels.
dataset = dset.ImageFolder(root=data_split_root,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5,) * nc, (0.5,) * nc),
                           ]))

# Create DataLoader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Preview up to 64 training images. make_grid works on the CPU batch directly, so there is
# no need to copy it to the GPU and back just to plot it.
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(vutils.make_grid(real_batch[0][:64], padding=2, normalize=True).permute(1, 2, 0).numpy())
plt.show()





In [None]:
def weights_init(m):
    """Initialise one module in place following the DCGAN recipe (use via `net.apply(weights_init)`).

    * Any class whose name contains "Conv" (Conv2d, ConvTranspose2d, ...): weight ~ N(0, 0.02).
    * Any class whose name contains "BatchNorm": weight ~ N(1, 0.02), bias = 0.
    * Every other module is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [None]:
def build_generator(ngpu):
    """Return the DCGAN generator as an nn.Sequential.

    Maps a latent tensor of shape (N, nz, 1, 1) to an (N, nc, 64, 64) image in [-1, 1]:
    four ConvTranspose -> BatchNorm -> ReLU stages grow the spatial size 1 -> 4 -> 8 -> 16 -> 32,
    and a final ConvTranspose + Tanh produces the 64x64 output. Uses the module-level
    `nz`, `ngf` and `nc`. `ngpu` is accepted but not used here (multi-GPU wrapping is done
    by the caller).
    """
    # (in_channels, out_channels, stride, padding) for each upsampling stage; all use 4x4 kernels.
    stages = (
        (nz, ngf * 8, 1, 0),
        (ngf * 8, ngf * 4, 2, 1),
        (ngf * 4, ngf * 2, 2, 1),
        (ngf * 2, ngf, 2, 1),
    )
    modules = []
    for c_in, c_out, stride, pad in stages:
        modules += [
            nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        ]
    # Output layer: no BatchNorm, Tanh keeps pixels in [-1, 1] like the normalised training data.
    modules += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()]
    return nn.Sequential(*modules)

# Stand-alone instance, only displayed to inspect the architecture; training uses `netG` below.
generator = build_generator(ngpu=ngpu)
generator

In [None]:
# Trainable generator: build, move to the target device, optionally spread over several GPUs
# (nn.DataParallel prefixes state_dict keys with "module."), then apply the DCGAN init.
netG = build_generator(ngpu).to(device)

use_data_parallel = device.type == 'cuda' and ngpu > 1
if use_data_parallel:
    netG = nn.DataParallel(netG, device_ids=list(range(ngpu)))

netG.apply(weights_init)

netG

In [None]:
def build_discriminator(ngpu):
    """Return the DCGAN discriminator as an nn.Sequential.

    Maps an (N, nc, 64, 64) image batch to (N, 1, 1, 1) probabilities: four stride-2 4x4
    convolutions shrink 64 -> 32 -> 16 -> 8 -> 4 while widening channels up to ndf*8, then a
    4x4 valid convolution collapses 4x4 to 1x1 and a Sigmoid squashes the score. Uses the
    module-level `nc` and `ndf`. `ngpu` is accepted but not used here.
    """
    # Input block: no BatchNorm on the first layer.
    seq = [
        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    # Downsampling blocks Conv -> BatchNorm -> LeakyReLU, doubling the width each time.
    for mult in (1, 2, 4):
        seq.append(nn.Conv2d(ndf * mult, ndf * mult * 2, 4, 2, 1, bias=False))
        seq.append(nn.BatchNorm2d(ndf * mult * 2))
        seq.append(nn.LeakyReLU(0.2, inplace=True))
    # Output block: (ndf*8) x 4 x 4 -> 1 x 1 x 1 probability.
    seq.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
    seq.append(nn.Sigmoid())
    return nn.Sequential(*seq)

# Stand-alone instance, only displayed to inspect the architecture; training uses `netD` below.
discriminator = build_discriminator(ngpu=ngpu)
discriminator

In [None]:
# Trainable discriminator: build, move to the target device, optionally spread over several
# GPUs (nn.DataParallel prefixes state_dict keys with "module."), then apply the DCGAN init.
netD = build_discriminator(ngpu).to(device)

use_data_parallel = device.type == 'cuda' and ngpu > 1
if use_data_parallel:
    netD = nn.DataParallel(netD, device_ids=list(range(ngpu)))

netD.apply(weights_init)
netD

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid output.
criterion = nn.BCELoss()

# 64 fixed latent vectors, reused throughout training to visualise the generator's progress.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Float targets, since BCELoss compares against float tensors.
real_label = 1.0
fake_label = 0.0

# The discriminator now uses the configured `lr` (same value as before, 0.0002) instead of a
# duplicated literal. The generator keeps the original 0.001, but as a named setting.
# NOTE(review): the DCGAN paper uses 0.0002 for both networks; confirm the 5x higher
# generator rate is intentional.
lr_g = 0.001
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr_g, betas=(beta1, 0.999))

In [None]:
# import torchvision.datasets as dset
# import torchvision.transforms as transforms

# men_data_root = '/kaggle/working/data_split/MEN'

# men_dataset = dset.ImageFolder(root=men_data_root,
#                                transform=transforms.Compose([
#                                    transforms.Resize(image_size),
#                                    transforms.CenterCrop(image_size),
#                                    transforms.ToTensor(),
#                                    transforms.Normalize((0.5,), (0.5,)),
#                                ]))

# # Create DataLoader for MEN images
# men_dataloader = torch.utils.data.DataLoader(men_dataset, batch_size=batch_size,
#                                              shuffle=True, num_workers=workers)


In [None]:
# Main adversarial training loop (DCGAN recipe). For every batch:
#   1. one discriminator update on a real batch plus a freshly generated fake batch,
#   2. one generator update on the same fakes.
# Losses are appended to G_losses / D_losses every iteration. A grid generated from
# `fixed_noise` is stored in img_list every 500 iterations and after the very last batch.
img_list = []
G_losses = []
D_losses = []
iters = 0
num_epochs = 35  # NOTE(review): overrides num_epochs = 5 from the hyperparameter cell

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):

        # --- (1) Discriminator update, real batch: target real_label for every image ---
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device).float()  # Convert to Float
        # For 64x64 inputs netD outputs (b_size, 1, 1, 1); view(-1) flattens it to (b_size,)
        # so it lines up with `label`.
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()  # mean D(x) on real images

        # --- (1) Discriminator update, fake batch: target fake_label ---
        # detach() keeps this backward pass out of netG; D's gradients accumulate on top of
        # the real-batch gradients before the single optimizer step below.
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()  # mean D(G(z)) before the D step
        errD = errD_real + errD_fake  # for logging only; both terms were already back-propagated
        optimizerD.step()

        # --- (2) Generator update: score the same fakes against real_label ---
        # This backward pass also writes into netD's .grad buffers; netD.zero_grad() at the
        # start of the next iteration clears them.
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()  # mean D(G(z)) after the D step
        optimizerG.step()

        # Progress line every 50 batches.
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Fixed-noise snapshot every 500 iterations and at the final batch of the final epoch.
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

In [None]:
# Per-iteration generator and discriminator losses recorded by the training loops.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(G_losses, label="G")
ax.plot(D_losses, label="D")
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# Animate the fixed-noise snapshots collected during training (one frame per img_list entry).
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
animation_html = ani.to_jshtml()
# Close the figure after rendering the HTML; otherwise the notebook also shows its last
# frame as a separate static image under the animation.
plt.close(fig)

HTML(animation_html)

In [None]:
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Side-by-side comparison. The previous 80x80-inch figure rendered at roughly 8000x8000 px
# (default 100 dpi) and bloated the notebook; 15x15 inches comfortably fits two 8x8 grids.
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
# make_grid works on the CPU batch directly; no GPU round-trip is needed just to plot.
plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the most recent fixed-noise snapshot from training
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

In [None]:
# Persist everything needed to resume training: both networks, both optimizers, the loss
# history and the iteration counter. `epoch` is the last epoch index of the loop above.
training_state = {
    'epoch': epoch,
    'modelG_state_dict': netG.state_dict(),
    'modelD_state_dict': netD.state_dict(),
    'optimizerG_state_dict': optimizerG.state_dict(),
    'optimizerD_state_dict': optimizerD.state_dict(),
    'G_losses': G_losses,
    'D_losses': D_losses,
    'iters': iters,
}
torch.save(training_state, 'gan_checkpoint.pth')

print("Model and optimizer states saved!")


In [None]:
# Load the saved model and optimizer states
checkpoint = torch.load('gan_checkpoint.pth')
netG.load_state_dict(checkpoint['modelG_state_dict'])
netD.load_state_dict(checkpoint['modelD_state_dict'])
optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])

# Resume training settings
G_losses = checkpoint['G_losses']
D_losses = checkpoint['D_losses']
iters = checkpoint['iters']
start_epoch = checkpoint['epoch'] + 1  # Start from the next epoch

# Print to confirm loading
print("Model and optimizer states loaded!")
print(f"Resuming training from epoch {start_epoch}...")


In [None]:
# Continue training for `num_epochs_to_continue` more epochs, numbering them from
# `start_epoch` (set by the checkpoint-loading cell). Appends to G_losses, D_losses,
# img_list and iters. The per-batch logic matches the first training cell.
# NOTE(review): this loop is duplicated verbatim in several cells; a shared function would
# keep the copies from drifting apart.
num_epochs_to_continue = 20  # number of additional epochs to run in this cell

for epoch in range(start_epoch, start_epoch + num_epochs_to_continue):
    for i, data in enumerate(dataloader, 0):
        # --- Discriminator update, real batch (target real_label) ---
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device).float()  # Convert to Float
        output = netD(real_cpu).view(-1)  # for 64x64 inputs: (b_size, 1, 1, 1) -> (b_size,)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # --- Discriminator update, fake batch (target fake_label); detach keeps netG out of it ---
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake  # logging only; both parts were already back-propagated
        optimizerD.step()

        # --- Generator update: score the same fakes against real_label ---
        # (Also accumulates into netD's grads; netD.zero_grad() clears them next iteration.)
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Progress line every 50 batches; the second number is the exclusive end epoch.
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs_to_continue + start_epoch, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Fixed-noise snapshot every 500 iterations and at the last batch of this cell's final epoch.
        if (iters % 500 == 0) or ((epoch == start_epoch + num_epochs_to_continue - 1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1


In [None]:
import matplotlib.animation as animation

def visualize_training(img_list, dataloader, device):
    """Animate the fixed-noise snapshots, then show real images next to the latest fake grid.

    Args:
        img_list: list of C x H x W image grids (from vutils.make_grid) collected during
            training; may be empty, in which case the animation and fake panel are skipped.
        dataloader: DataLoader whose first batch supplies the real images.
        device: accepted for backward compatibility; plotting no longer copies the real
            batch to this device and back.
    """
    from IPython.display import HTML, display

    if img_list:
        anim_fig = plt.figure(figsize=(8, 8))
        plt.axis("off")
        frames = [[plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)] for grid in img_list]
        ani = animation.ArtistAnimation(anim_fig, frames, interval=1000, repeat_delay=1000, blit=True)
        animation_html = ani.to_jshtml()
        # Close after rendering so the notebook doesn't also show the last frame as a static image.
        plt.close(anim_fig)
        display(HTML(animation_html))

    real_batch = next(iter(dataloader))

    # 15x15 inches fits two 8x8 grids; the former 80x80 figure was ~8000x8000 px at 100 dpi.
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    real_grid = vutils.make_grid(real_batch[0][:64], padding=5, normalize=True).cpu()
    plt.imshow(np.transpose(real_grid, (1, 2, 0)))

    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    if img_list:
        plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))

    plt.show()

visualize_training(img_list, dataloader, device)

In [None]:
# Checkpoint the full training state, then immediately reload it so the next training cell
# resumes from exactly what is on disk.
resume_checkpoint_path = 'gan_checkpoint.pth'
torch.save(
    {
        'epoch': epoch,
        'modelG_state_dict': netG.state_dict(),
        'modelD_state_dict': netD.state_dict(),
        'optimizerG_state_dict': optimizerG.state_dict(),
        'optimizerD_state_dict': optimizerD.state_dict(),
        'G_losses': G_losses,
        'D_losses': D_losses,
        'iters': iters,
    },
    resume_checkpoint_path,
)
print("Model and optimizer states saved!")

checkpoint = torch.load(resume_checkpoint_path)
for target, key in ((netG, 'modelG_state_dict'),
                    (netD, 'modelD_state_dict'),
                    (optimizerG, 'optimizerG_state_dict'),
                    (optimizerD, 'optimizerD_state_dict')):
    target.load_state_dict(checkpoint[key])

# Training history and counters; the next epoch to run follows the saved one.
G_losses, D_losses, iters = checkpoint['G_losses'], checkpoint['D_losses'], checkpoint['iters']
start_epoch = checkpoint['epoch'] + 1

print("Model and optimizer states loaded!")
print(f"Resuming training from epoch {start_epoch}...")


In [None]:
# Continue training for `num_epochs_to_continue` more epochs, numbering them from
# `start_epoch` (set by the save/reload cell). Appends to G_losses, D_losses, img_list
# and iters. The per-batch logic matches the first training cell.
# NOTE(review): duplicated verbatim in several cells; consider a shared training function.
num_epochs_to_continue = 20  # number of additional epochs to run in this cell

for epoch in range(start_epoch, start_epoch + num_epochs_to_continue):
    for i, data in enumerate(dataloader, 0):
        # --- Discriminator update, real batch (target real_label) ---
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device).float()  # Convert to Float
        output = netD(real_cpu).view(-1)  # for 64x64 inputs: (b_size, 1, 1, 1) -> (b_size,)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # --- Discriminator update, fake batch (target fake_label); detach keeps netG out of it ---
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake  # logging only; both parts were already back-propagated
        optimizerD.step()

        # --- Generator update: score the same fakes against real_label ---
        # (Also accumulates into netD's grads; netD.zero_grad() clears them next iteration.)
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Progress line every 50 batches; the second number is the exclusive end epoch.
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs_to_continue + start_epoch, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Fixed-noise snapshot every 500 iterations and at the last batch of this cell's final epoch.
        if (iters % 500 == 0) or ((epoch == start_epoch + num_epochs_to_continue - 1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1


In [None]:
# Refresh the progress animation and the real-vs-fake comparison after this round of training.
visualize_training(img_list=img_list, dataloader=dataloader, device=device)

In [None]:
# Checkpoint the full training state, then immediately reload it so the next training cell
# resumes from exactly what is on disk.
resume_checkpoint_path = 'gan_checkpoint.pth'
torch.save(
    {
        'epoch': epoch,
        'modelG_state_dict': netG.state_dict(),
        'modelD_state_dict': netD.state_dict(),
        'optimizerG_state_dict': optimizerG.state_dict(),
        'optimizerD_state_dict': optimizerD.state_dict(),
        'G_losses': G_losses,
        'D_losses': D_losses,
        'iters': iters,
    },
    resume_checkpoint_path,
)
print("Model and optimizer states saved!")

checkpoint = torch.load(resume_checkpoint_path)
for target, key in ((netG, 'modelG_state_dict'),
                    (netD, 'modelD_state_dict'),
                    (optimizerG, 'optimizerG_state_dict'),
                    (optimizerD, 'optimizerD_state_dict')):
    target.load_state_dict(checkpoint[key])

# Training history and counters; the next epoch to run follows the saved one.
G_losses, D_losses, iters = checkpoint['G_losses'], checkpoint['D_losses'], checkpoint['iters']
start_epoch = checkpoint['epoch'] + 1

print("Model and optimizer states loaded!")
print(f"Resuming training from epoch {start_epoch}...")


In [None]:
# Continue training for `num_epochs_to_continue` more epochs, numbering them from
# `start_epoch` (set by the save/reload cell). Appends to G_losses, D_losses, img_list
# and iters. The per-batch logic matches the first training cell.
# NOTE(review): duplicated verbatim in several cells; consider a shared training function.
num_epochs_to_continue = 25  # number of additional epochs to run in this cell

for epoch in range(start_epoch, start_epoch + num_epochs_to_continue):
    for i, data in enumerate(dataloader, 0):
        # --- Discriminator update, real batch (target real_label) ---
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device).float()  # Convert to Float
        output = netD(real_cpu).view(-1)  # for 64x64 inputs: (b_size, 1, 1, 1) -> (b_size,)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # --- Discriminator update, fake batch (target fake_label); detach keeps netG out of it ---
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake  # logging only; both parts were already back-propagated
        optimizerD.step()

        # --- Generator update: score the same fakes against real_label ---
        # (Also accumulates into netD's grads; netD.zero_grad() clears them next iteration.)
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Progress line every 50 batches; the second number is the exclusive end epoch.
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs_to_continue + start_epoch, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Fixed-noise snapshot every 500 iterations and at the last batch of this cell's final epoch.
        if (iters % 500 == 0) or ((epoch == start_epoch + num_epochs_to_continue - 1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1


In [None]:
# Refresh the progress animation and the real-vs-fake comparison after this round of training.
visualize_training(img_list=img_list, dataloader=dataloader, device=device)

In [None]:
# Checkpoint the full training state, then immediately reload it so the next training cell
# resumes from exactly what is on disk.
resume_checkpoint_path = 'gan_checkpoint.pth'
torch.save(
    {
        'epoch': epoch,
        'modelG_state_dict': netG.state_dict(),
        'modelD_state_dict': netD.state_dict(),
        'optimizerG_state_dict': optimizerG.state_dict(),
        'optimizerD_state_dict': optimizerD.state_dict(),
        'G_losses': G_losses,
        'D_losses': D_losses,
        'iters': iters,
    },
    resume_checkpoint_path,
)
print("Model and optimizer states saved!")

checkpoint = torch.load(resume_checkpoint_path)
for target, key in ((netG, 'modelG_state_dict'),
                    (netD, 'modelD_state_dict'),
                    (optimizerG, 'optimizerG_state_dict'),
                    (optimizerD, 'optimizerD_state_dict')):
    target.load_state_dict(checkpoint[key])

# Training history and counters; the next epoch to run follows the saved one.
G_losses, D_losses, iters = checkpoint['G_losses'], checkpoint['D_losses'], checkpoint['iters']
start_epoch = checkpoint['epoch'] + 1

print("Model and optimizer states loaded!")
print(f"Resuming training from epoch {start_epoch}...")


In [None]:
# Continue training for `num_epochs_to_continue` more epochs, numbering them from
# `start_epoch` (set by the save/reload cell). Appends to G_losses, D_losses, img_list
# and iters. The per-batch logic matches the first training cell.
# NOTE(review): duplicated verbatim in several cells; consider a shared training function.
num_epochs_to_continue = 25  # number of additional epochs to run in this cell

for epoch in range(start_epoch, start_epoch + num_epochs_to_continue):
    for i, data in enumerate(dataloader, 0):
        # --- Discriminator update, real batch (target real_label) ---
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device).float()  # Convert to Float
        output = netD(real_cpu).view(-1)  # for 64x64 inputs: (b_size, 1, 1, 1) -> (b_size,)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # --- Discriminator update, fake batch (target fake_label); detach keeps netG out of it ---
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake  # logging only; both parts were already back-propagated
        optimizerD.step()

        # --- Generator update: score the same fakes against real_label ---
        # (Also accumulates into netD's grads; netD.zero_grad() clears them next iteration.)
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Progress line every 50 batches; the second number is the exclusive end epoch.
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs_to_continue + start_epoch, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Fixed-noise snapshot every 500 iterations and at the last batch of this cell's final epoch.
        if (iters % 500 == 0) or ((epoch == start_epoch + num_epochs_to_continue - 1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1


In [None]:
# Refresh the progress animation and the real-vs-fake comparison after this round of training.
visualize_training(img_list=img_list, dataloader=dataloader, device=device)

In [None]:
# Checkpoint the full training state, then immediately reload it so the next training cell
# resumes from exactly what is on disk.
resume_checkpoint_path = 'gan_checkpoint.pth'
torch.save(
    {
        'epoch': epoch,
        'modelG_state_dict': netG.state_dict(),
        'modelD_state_dict': netD.state_dict(),
        'optimizerG_state_dict': optimizerG.state_dict(),
        'optimizerD_state_dict': optimizerD.state_dict(),
        'G_losses': G_losses,
        'D_losses': D_losses,
        'iters': iters,
    },
    resume_checkpoint_path,
)
print("Model and optimizer states saved!")

checkpoint = torch.load(resume_checkpoint_path)
for target, key in ((netG, 'modelG_state_dict'),
                    (netD, 'modelD_state_dict'),
                    (optimizerG, 'optimizerG_state_dict'),
                    (optimizerD, 'optimizerD_state_dict')):
    target.load_state_dict(checkpoint[key])

# Training history and counters; the next epoch to run follows the saved one.
G_losses, D_losses, iters = checkpoint['G_losses'], checkpoint['D_losses'], checkpoint['iters']
start_epoch = checkpoint['epoch'] + 1

print("Model and optimizer states loaded!")
print(f"Resuming training from epoch {start_epoch}...")


In [None]:
# Define a save path
save_path = './gan_model_epoch_{}.pth'  # You can change the path and filename as needed

# At the end of each epoch
torch.save({
    'epoch': epoch,
    'generator_state_dict': netG.state_dict(),
    'discriminator_state_dict': netD.state_dict(),
    'optimizerG_state_dict': optimizerG.state_dict(),
    'optimizerD_state_dict': optimizerD.state_dict(),
    'G_losses': G_losses,
    'D_losses': D_losses
}, save_path.format(epoch))

print(f'Model saved at epoch {epoch} to {save_path.format(epoch)}')


_______________

In [None]:
import torch
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os

# Load your trained GAN model
# Load a trained generator for inference.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): must match a file written by the per-epoch save cell; the epoch number in the
# name depends on how many training cells were run.
generator_checkpoint_path = './gan_model_epoch_99.pth'
checkpoint = torch.load(generator_checkpoint_path, map_location=device)

# Build the generator (build_generator is defined earlier in this notebook) on the same device.
netG = build_generator(ngpu).to(device)

# A model saved while wrapped in nn.DataParallel has every key prefixed with "module.";
# strip it so the keys match the plain nn.Sequential. The load is then strict (the default),
# so any remaining mismatch raises instead of silently leaving layers randomly initialised
# as strict=False did.
generator_state = {
    (key[len('module.'):] if key.startswith('module.') else key): value
    for key, value in checkpoint['generator_state_dict'].items()
}
netG.load_state_dict(generator_state)
netG.eval()

# Define image size based on your model
image_size = 64

# Function to preprocess input clothing images
def preprocess_image(image_path):
    """Load an image file as a normalised tensor of shape (1, 3, image_size, image_size).

    Uses the same Resize(short side) + CenterCrop pipeline as the training DataLoader, so
    inputs match what the generator was trained on. The previous square Resize distorted the
    aspect ratio. Pixel values are mapped from [0, 1] to [-1, 1].

    Args:
        image_path: path to any image file PIL can open; it is converted to RGB.

    Returns:
        A float tensor with a leading batch dimension of 1.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Use a context manager so the file handle opened by Image.open is closed; convert()
    # returns an independent in-memory image that stays valid afterwards.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    return transform(rgb_image).unsqueeze(0)  # Add batch dimension

# Function to generate outfits
def generate_outfit(user_inputs, num_steps=200, step_size=0.05, seed=0):
    """Generate an outfit image that resembles the average of the given clothing images.

    The generator maps a latent tensor of shape (1, nz, 1, 1) to an image and has no encoder.
    The previous version passed an averaged (1, 3, H, W) image straight into netG, which fails
    because netG's first layer expects `nz` input channels. Instead, this function "inverts"
    the generator: starting from a seeded random latent, it optimises z so that netG(z) matches
    the pixel-wise mean of the preprocessed inputs (MSE), then returns netG(z).

    Args:
        user_inputs: iterable of image file paths.
        num_steps: number of latent-optimisation steps.
        step_size: Adam learning rate for the latent vector.
        seed: seed for the initial latent, for reproducible outputs.

    Returns:
        Tensor of shape (1, nc, H, W) from netG's Tanh output (values in [-1, 1]).

    Raises:
        ValueError: if user_inputs is empty.
    """
    paths = list(user_inputs)
    if not paths:
        raise ValueError("generate_outfit needs at least one input image")

    gen_device = next(netG.parameters()).device
    # Average the inputs into one target image of shape (1, 3, H, W).
    target = torch.cat([preprocess_image(p) for p in paths]).mean(dim=0, keepdim=True).to(gen_device)

    rng = torch.Generator().manual_seed(seed)
    z = torch.randn(1, nz, 1, 1, generator=rng).to(gen_device).requires_grad_(True)
    latent_optimizer = torch.optim.Adam([z], lr=step_size)

    # Eval mode so BatchNorm uses running statistics and does not update them; restore afterwards.
    was_training = netG.training
    netG.eval()
    try:
        for _ in range(num_steps):
            latent_optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(netG(z), target)
            # Differentiate w.r.t. z only, leaving netG's parameter .grad buffers untouched.
            z.grad = torch.autograd.grad(loss, z)[0]
            latent_optimizer.step()

        with torch.no_grad():
            generated_outfit = netG(z)
    finally:
        netG.train(was_training)

    return generated_outfit

# Function to visualize and save the generated outfit
def visualize_and_save_outfit(outfit_tensor, save_path='generated_outfit.png'):
    """Save a generated outfit to disk and display it.

    Args:
        outfit_tensor: generator output, either (C, H, W) or (N, C, H, W); for a batch only
            the first image is displayed, while save_image writes the whole batch as a grid.
        save_path: destination file for the saved image.
    """
    # Save the generated image (normalize=True min-max scales it to [0, 1]).
    save_image(outfit_tensor, save_path, normalize=True)

    # Display with the same min-max scaling save_image applies, so the plot matches the file.
    # Previously the raw Tanh output in [-1, 1] went to imshow, which clips floats to [0, 1]
    # and discarded every negative value. squeeze() also broke for batches larger than one.
    img = outfit_tensor.detach().cpu()
    if img.dim() == 4:
        img = img[0]
    low, high = float(img.min()), float(img.max())
    img = (img - low) / max(high - low, 1e-5)

    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Generated Outfit")
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.show()

# Main function to run the code
def main():
    """Let the user upload clothing images (Google Colab only) and show a generated outfit.

    `files.upload()` is Colab's upload widget, but `files` was never imported, so the original
    raised NameError on every platform. Because this notebook runs with __name__ == "__main__",
    that crash happened as soon as the cell ran. The import is now done lazily. Elsewhere
    (e.g. Kaggle) the function prints a hint and returns instead of crashing.
    """
    try:
        from google.colab import files  # only importable inside Google Colab
    except ImportError:
        print("File upload is only available in Google Colab; "
              "use the next cell to enter image paths instead.")
        return

    # upload() returns a dict keyed by file name. As in the original, the names are passed on
    # as paths, i.e. this relies on the uploads being saved to the working directory.
    uploaded = files.upload()
    user_clothes = list(uploaded.keys())

    if user_clothes:
        outfit_image = generate_outfit(user_clothes)
        visualize_and_save_outfit(outfit_image)

if __name__ == "__main__":
    main()


In [None]:
# Interactive alternative to main(): type the paths of clothing images, then generate,
# display and save an outfit built from them.
user_inputs = []
num_items = int(input("How many clothing items do you want to input? "))
if num_items < 1:
    raise ValueError("Enter at least one clothing item.")

for _ in range(num_items):
    # Strip surrounding whitespace and quotes that often come along with pasted paths.
    img_path = input("Enter the path of the clothing item image: ").strip().strip('"\'')
    user_inputs.append(img_path)

generated_outfit = generate_outfit(user_inputs)
visualize_and_save_outfit(generated_outfit)
