In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

## Définition et entraînement d'un GAN de zéro

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's default RNGs (CPU and CUDA) so that weight initialisation, latent
# noise and the DataLoader's shuffling are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

latent_dim = 100            # size of the generator's input noise vector
image_size = 28 * 28        # length of a flattened MNIST image
batch_size = 128
n_epochs = 50
sample_interval = 500       # batches between loss logs / preview grids during training

# ToTensor yields [0, 1]; Normalize(mean=0.5, std=0.5) maps that to [-1, 1],
# the same range as the generator's final Tanh.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_ds = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dl = DataLoader(dataset=train_ds, shuffle=True, batch_size=batch_size)

In [None]:
def real_data_target(size):
    """Return a ``(size, 1)`` tensor of ones: BCE targets for "real" samples.

    The tensor is allocated directly on the module-level ``device`` instead of
    being built on the CPU and copied over on every call.
    """
    return torch.ones(size, 1, device=device)

def fake_data_target(size):
    """Return a ``(size, 1)`` tensor of zeros: BCE targets for "fake" samples.

    The tensor is allocated directly on the module-level ``device`` instead of
    being built on the CPU and copied over on every call.
    """
    return torch.zeros(size, 1, device=device)

def imgs_to_vec(imgs):
    """Flatten a batch of images into one row vector per sample.

    ``(N, *dims) -> (N, prod(dims))``. ``reshape`` is used instead of ``view`` so
    that non-contiguous batches (e.g. permuted or sliced tensors) also work; it
    still returns a view whenever one is possible.
    """
    return imgs.reshape(imgs.size(0), -1)

def vec_to_imgs(vec, img_size=28):
    """Reshape flat vectors back into single-channel square images.

    ``(N, img_size * img_size) -> (N, 1, img_size, img_size)``. The default of 28
    matches MNIST; pass ``img_size`` for other square resolutions.
    """
    return vec.reshape(vec.size(0), 1, img_size, img_size)

def noise(size, latent_dim=100):
    """Sample a ``(size, latent_dim)`` batch of standard-normal latent vectors.

    Sampled directly on the module-level ``device``, avoiding a CPU allocation
    and host-to-device copy on every call.
    """
    return torch.randn(size, latent_dim, device=device)

def display_images(imgs, n_cols=4, figsize=(8, 8), max_images=16):
    """Show up to ``max_images`` greyscale images in a grid with ``n_cols`` columns.

    Parameters
    ----------
    imgs : sequence of tensors
        Each element must hold 28*28 values (e.g. ``(1, 28, 28)`` or ``(784,)``).
    n_cols : int, default 4
        Number of grid columns; rows are added as needed.
    figsize : tuple, default (8, 8)
        Matplotlib figure size in inches.
    max_images : int, default 16
        Upper bound on how many images are drawn.

    Nothing is drawn for an empty ``imgs``.
    """
    n_shown = min(len(imgs), max_images)
    if n_shown == 0:
        return
    n_cols = max(1, n_cols)
    n_rows = -(-n_shown // n_cols)  # ceiling division
    plt.figure(figsize=figsize)
    for i in range(n_shown):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(imgs[i].detach().cpu().numpy().reshape(28, 28), cmap='gray')
        plt.axis('off')
    plt.show()


# Sanity-check the input pipeline by previewing one shuffled training batch.
print("Viz:")
image_batch = next(iter(dl))[0]
display_images(imgs=image_batch)


class DiscriminatorNet(nn.Module):
    """MLP that maps a flattened image to the probability that it is real.

    Widths: img_size**2 -> 1024 -> 512 -> 256 -> 1. Each hidden layer is
    followed by LeakyReLU(0.2) and Dropout(0.3); the output goes through a Sigmoid.
    """

    def __init__(self, img_size=28):
        super().__init__()
        widths = [img_size * img_size, 1024, 512, 256]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        # A single flat Sequential keeps parameter names (model.0, model.3, ...)
        # unchanged, so previously saved state_dicts still load.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score a ``(N, img_size * img_size)`` batch; returns ``(N, 1)`` probabilities."""
        return self.model(x)


class GeneratorNet(nn.Module):
    """MLP that maps a latent vector to a flattened image with values in [-1, 1].

    Widths: latent_dim -> 256 -> 512 -> 1024 -> img_size**2. Each hidden layer is
    followed by LeakyReLU(0.2); the output goes through a Tanh.
    """

    def __init__(self, latent_dim=100, img_size=28):
        super().__init__()
        widths = [latent_dim, 256, 512, 1024]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        layers += [nn.Linear(widths[-1], img_size * img_size), nn.Tanh()]
        # Flat Sequential keeps parameter names (model.0, model.2, ...) unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(N, latent_dim)`` batch to ``(N, img_size * img_size)`` pixels."""
        return self.model(x)


# Build both networks on the selected device. The discriminator is created
# first so that weight initialisation consumes the RNG in a fixed order.
discriminator = DiscriminatorNet().to(device)
generator = GeneratorNet().to(device)

# Both players use the same Adam hyper-parameters.
ADAM_KWARGS = dict(lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), **ADAM_KWARGS)
g_optimizer = optim.Adam(generator.parameters(), **ADAM_KWARGS)

# The discriminator ends in a Sigmoid and already emits probabilities, so plain
# binary cross-entropy is the matching criterion.
loss_fn = nn.BCELoss()


def train_mnist_gan(d, g, d_optim, g_optim, loss_fn, dl, n_epochs, device,
                    latent_dim=100, preview_every=None):
    """Train a GAN with alternating discriminator / generator updates.

    For every batch: one discriminator step on real images plus detached fakes,
    then one generator step whose fakes are scored against "real" targets
    (non-saturating generator loss). Every ``preview_every`` batches the current
    losses are printed and the generator's output for a fixed latent batch of 16
    is displayed.

    Parameters
    ----------
    d, g : nn.Module
        Discriminator and generator (``DiscriminatorNet`` / ``GeneratorNet`` in
        this notebook). ``d`` receives flattened images; ``g`` receives
        ``(N, latent_dim)`` noise.
    d_optim, g_optim : torch.optim.Optimizer
        Optimizers over ``d``'s and ``g``'s parameters respectively.
    loss_fn : callable
        ``loss_fn(predictions, targets)``; ``nn.BCELoss`` here.
    dl : sized iterable of (images, labels)
        Training batches; labels are ignored. ``len(dl)`` must be defined.
    n_epochs : int
        Number of passes over ``dl``.
    device : torch.device
        Device for batches, targets and latent noise (used for all of them).
    latent_dim : int, default 100
        Width of the generator's input noise.
    preview_every : int or None, default None
        Batch interval for logging and previews. ``None`` falls back to the
        notebook-level ``sample_interval``; ``0`` disables previews.

    Returns
    -------
    (list of float, list of float)
        Per-epoch mean batch loss for the discriminator and for the generator.
    """
    interval = sample_interval if preview_every is None else preview_every
    fixed_noise = torch.randn(16, latent_dim, device=device)
    n_batches = max(len(dl), 1)  # avoid dividing by zero on an empty loader
    d_losses, g_losses = [], []

    for epoch in tqdm(range(n_epochs), desc="Training Progress"):
        d.train()
        g.train()
        d_running_loss = 0.0
        g_running_loss = 0.0

        for batch_idx, (real_images, _) in enumerate(dl):
            real_images = imgs_to_vec(real_images.to(device))
            batch_size = real_images.size(0)
            real_targets = torch.ones(batch_size, 1, device=device)
            fake_targets = torch.zeros(batch_size, 1, device=device)

            # Discriminator step: fakes are detached so no gradient reaches g.
            d_optim.zero_grad()
            real_loss = loss_fn(d(real_images), real_targets)
            fake_images = g(torch.randn(batch_size, latent_dim, device=device)).detach()
            fake_loss = loss_fn(d(fake_images), fake_targets)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            d_optim.step()
            d_running_loss += d_loss.item()

            # Generator step: score fresh fakes against "real" targets.
            g_optim.zero_grad()
            fake_images = g(torch.randn(batch_size, latent_dim, device=device))
            g_loss = loss_fn(d(fake_images), real_targets)
            g_loss.backward()
            g_optim.step()
            g_running_loss += g_loss.item()

            if interval and batch_idx % interval == 0:
                # Implicit string concatenation: a backslash continuation inside
                # the literal would copy the source indentation into the log line.
                print(f"Epoch [{epoch + 1}/{n_epochs}] Batch {batch_idx}/{len(dl)} "
                      f"Loss D: {d_loss.item():.4f}, loss G: {g_loss.item():.4f}")
                g.eval()
                with torch.no_grad():
                    display_images(vec_to_imgs(g(fixed_noise)).cpu())
                g.train()

        d_losses.append(d_running_loss / n_batches)
        g_losses.append(g_running_loss / n_batches)

    return d_losses, g_losses


d_losses, g_losses = train_mnist_gan(discriminator, generator, d_optimizer, g_optimizer, loss_fn, dl, n_epochs, device)

# Per-epoch mean batch losses. The discriminator's value is the sum of its
# real-image and fake-image BCE terms, so its scale differs from the generator's.
plt.figure()
plt.plot(d_losses, label='Discriminator')
plt.plot(g_losses, label='Generator')
plt.title("GAN training losses")
plt.xlabel("Epoch")
plt.ylabel("Mean loss per batch")
plt.legend()
plt.show()

# Only the generator's weights are saved; sampling needs nothing else.
torch.save(generator.state_dict(), "gan_generator_pytorch.pth")


def generate_and_plot_image(generator, device, latent_dim=100):
    """Sample one latent vector, run it through ``generator`` and show the image.

    The generator's output (Tanh range [-1, 1] for ``GeneratorNet``) is mapped to
    [0, 1] and drawn with a fixed ``vmin=0, vmax=1``. Without those limits
    ``imshow`` auto-stretches the contrast, and the rescaling has no visible
    effect. Sampling runs in eval mode, and the generator's previous
    train/eval mode is restored afterwards.

    Parameters
    ----------
    generator : nn.Module
        Maps ``(1, latent_dim)`` noise to 28*28 pixel values.
    device : torch.device
        Device on which to sample the noise (should match ``generator``).
    latent_dim : int, default 100
        Width of the noise vector.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(1, latent_dim, device=device)
            gen_img = generator(z).cpu().view(28, 28).numpy()
    finally:
        generator.train(was_training)
    gen_img = 0.5 * gen_img + 0.5
    plt.imshow(gen_img, cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
    plt.show()


generate_and_plot_image(generator=generator, device=device)  # one fresh sample from the trained generator


## Inspection des résultats

In [None]:
import cv2
import numpy as np

def analyze_strokes(image):
    # Convert the image to grayscale (if not already)
    if len(image.shape) == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Convert to 8-bit pixel values
    image = np.uint8(image * 255)

    # Apply a Gaussian blur to reduce noise
    blurred = cv2.GaussianBlur(image, (5, 5), 0)

    # Use Canny edge detection to find edges in the image
    edges = cv2.Canny(blurred, 50, 150)

    # Find contours in the edged image
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Analyze the contours to determine stroke features
    stroke_features = []
    for contour in contours:
        # Calculate the bounding box of the contour
        x, y, w, h = cv2.boundingRect(contour)
        aspect_ratio = w / float(h)
        extent = cv2.contourArea(contour) / float(w * h)
        stroke_features.append((aspect_ratio, extent))

    return stroke_features

# Function to display and analyze generated images
def inspect_generated_images(generator, device, latent_dim=100, num_images=5):
    """Generate ``num_images`` samples, plot them in one row and print stroke features.

    The generator's output (Tanh range [-1, 1] for ``GeneratorNet``) is rescaled
    to [0, 1] before plotting and before ``analyze_strokes``, which multiplies
    pixel values by 255 and so expects that range. Sampling runs in eval mode,
    and the generator's previous train/eval mode is restored afterwards.
    Nothing happens when ``num_images`` is not positive.

    Parameters
    ----------
    generator : nn.Module
        Maps ``(num_images, latent_dim)`` noise to 28*28 pixel values per sample.
    device : torch.device
        Device on which to sample the noise (should match ``generator``).
    latent_dim : int, default 100
        Width of the noise vector.
    num_images : int, default 5
        Number of samples to draw and analyse.
    """
    if num_images <= 0:
        return
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_images, latent_dim, device=device)
            gen_imgs = generator(z).cpu().view(-1, 28, 28).numpy()
    finally:
        generator.train(was_training)
    gen_imgs = 0.5 * gen_imgs + 0.5

    plt.figure(figsize=(10, 5))
    for i, img in enumerate(gen_imgs):
        plt.subplot(1, num_images, i + 1)
        plt.imshow(img, cmap='gray', vmin=0, vmax=1)
        plt.axis('off')
        features = analyze_strokes(img)
        print(f"Image {i+1} Stroke Features: {features}")
    plt.show()

# Visual check plus stroke-feature printout for a handful of fresh samples.
inspect_generated_images(generator=generator, device=device)