In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from model import Generator, Discriminator, gradient_penalty, wasserstein_loss, gradient_penalty
from torchvision.utils import save_image
from tqdm import tqdm
import os



In [None]:
# Run configuration: input/output paths, image and batch sizes, DataLoader workers,
# compute device, and the random seed.
dataroot = "data/"  # root folder read by ImageFolder (one subfolder per class)
save_path = "checkpoints"
image_save_path = os.path.join(save_path, 'images')
weights_save_path = os.path.join(save_path, 'weights')
os.makedirs(image_save_path, exist_ok=True)
os.makedirs(weights_save_path, exist_ok=True)

image_size = 96  # side length in pixels of the square training images
batch_size = 256
# Number of DataLoader worker processes. It is capped at the CPU count so smaller
# machines are not oversubscribed.
workers = min(8, os.cpu_count() or 1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch (which also seeds CUDA) and NumPy. This makes weight init, noise
# sampling and DataLoader shuffling repeatable across runs. cuDNN kernels may
# still be nondeterministic on GPU.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Preprocessing pipeline:
# 1. Resize to `image_size`.
# 2. Crop the central `image_size` x `image_size` square.
# 3. Convert to a float tensor.
# 4. Normalize each RGB channel with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

In [None]:
# Image dataset read from `dataroot`. The class labels ImageFolder produces are
# discarded by the cells that consume the batches.
dataset = ImageFolder(root=dataroot, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,         # reshuffle the sample order every epoch
    num_workers=workers,
    drop_last=True,       # drop the trailing partial batch so every batch has `batch_size` images
)

In [None]:
def imshow(img):
    """Display one channels-first image tensor whose values lie in [-1, 1].

    The [-1, 1] range is what the ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))``
    transform produces. Pass grids built with ``make_grid(..., normalize=False)``.
    ``normalize=True`` already rescales to [0, 1], and un-normalizing again here
    would squash the image into [0.5, 1].

    Args:
        img: tensor of shape (C, H, W), on any device.
    """
    # Undo Normalize(mean=0.5, std=0.5): x / 2 + 0.5 maps [-1, 1] back to [0, 1].
    # Clamp so stray out-of-range values don't trigger matplotlib's clipping warning.
    img = (img.detach().cpu() / 2 + 0.5).clamp(0.0, 1.0)
    # matplotlib expects (H, W, C); torch images are (C, H, W).
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.show()

dataiter = iter(dataloader)
images, _ = next(dataiter)
# No normalize=True here: imshow already maps [-1, 1] to [0, 1]; a second
# min/max rescale by make_grid made the preview look washed out.
grid_img = vutils.make_grid(images[:16], nrow=4, padding=2)  # adjust nrow to suit your display
imshow(grid_img)

In [None]:
# Build the generator and the critic (discriminator) and move both to `device`.
G, D = Generator().to(device), Discriminator().to(device)

In [None]:
epochs = 200  # number of full passes over the dataset

# One Adam configuration shared by the generator and the critic.
lr = 1e-4
beta1, beta2 = 0.5, 0.9
G_opt, D_opt = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, beta2)) for net in (G, D)
)

In [None]:
# WGAN-GP training: for each batch, do one critic (D) update followed by one
# generator (G) update.
# - Per-epoch mean losses are collected in G_losses / D_losses and plotted at the end.
# - Every 2nd epoch, a sample grid and both networks' weights are written to disk.
#
# NOTE(review): wasserstein_loss / gradient_penalty come from model.py and are not
# visible here. The loss signs below assume wasserstein_loss(pred, target) is
# linear in `target` (e.g. mean(target * pred) or its negation). TODO confirm.
if len(dataloader) == 0:
    raise ValueError(
        f"dataloader yields no batches: the dataset at {dataroot!r} has fewer than "
        f"batch_size={batch_size} images and drop_last=True"
    )

latent_dim = 100  # length of the generator's input noise vector
# Fixed latent vectors, so the saved sample grids are comparable from epoch to epoch.
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)

epoch_count = 1
G_losses = []
D_losses = []

for epoch in range(epochs):
    G_loss_accum = 0.0
    D_loss_accum = 0.0
    W_dist_accum = 0.0  # critic's Wasserstein estimate (D_loss without the gradient penalty)
    num_batches = 0

    for img, _ in tqdm(dataloader, total=len(dataloader)):
        img = img.to(device)
        # Size the noise from the real batch instead of assuming `batch_size`.
        noise = torch.randn(img.size(0), latent_dim, 1, 1, device=device)

        # ---- Critic (discriminator) step ----
        D_opt.zero_grad()
        true = D(img)
        # The critic update must not backpropagate into G, so generate the fakes without a graph.
        with torch.no_grad():
            fake_img = G(noise)
        fake = D(fake_img)

        # Wasserstein losses: real samples get target +1, fakes get target -1.
        D_loss_t = wasserstein_loss(true, torch.ones_like(true))
        D_loss_f = wasserstein_loss(fake, -torch.ones_like(fake))

        # Gradient penalty on the critic.
        gp = gradient_penalty(D, img, fake_img, device)
        D_loss = D_loss_t + D_loss_f + gp

        D_loss.backward()
        D_opt.step()

        D_loss_accum += D_loss.item()
        W_dist_accum += (D_loss_t + D_loss_f).item()

        # ---- Generator step ----
        G_opt.zero_grad()
        fake_img = G(noise)
        fake = D(fake_img)

        # The generator wants its samples scored like real images, so it uses the
        # real target (+1) with no extra negation. For a target-linear loss,
        # `-wasserstein_loss(fake, ones)` equals the critic's D_loss_f term. That
        # would make G optimize the same objective as the critic.
        G_loss = wasserstein_loss(fake, torch.ones_like(fake))

        G_loss.backward()
        G_opt.step()

        G_loss_accum += G_loss.item()
        num_batches += 1

    if epoch_count % 2 == 0:
        with torch.no_grad():
            fake_img = G(fixed_noise).cpu()
        img_grid = vutils.make_grid(fake_img, nrow=8, padding=2, normalize=True)
        # Sample grids go to the images/ directory created in the config cell.
        vutils.save_image(img_grid, f'{image_save_path}/fake_epoch_{epoch_count}.jpg')

        torch.save(G.state_dict(), f'{weights_save_path}/G_epoch_{epoch_count}.pth')
        torch.save(D.state_dict(), f'{weights_save_path}/D_epoch_{epoch_count}.pth')

    G_loss_mean = G_loss_accum / num_batches
    D_loss_mean = D_loss_accum / num_batches
    # abs(): the estimate's sign depends on wasserstein_loss's sign convention.
    W_dist_mean = abs(W_dist_accum) / num_batches
    print(f"Epoch {epoch_count}: G_loss = {G_loss_mean:.4f}, D_loss = {D_loss_mean:.4f}, Wasserstein Distance = {W_dist_mean:.4f}")

    G_losses.append(G_loss_mean)
    D_losses.append(D_loss_mean)
    epoch_count += 1

# Plot the per-epoch mean losses (x axis numbered like epoch_count, starting at 1).
epoch_axis = range(1, len(G_losses) + 1)
plt.plot(epoch_axis, G_losses, label='Generator loss')
plt.plot(epoch_axis, D_losses, label='Discriminator loss')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Mean loss per batch')
plt.title('WGAN-GP training losses')
plt.show()
