In [None]:
#neural network libraries
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.utils import make_grid, save_image

#Common libraries
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np 
import os 
import cv2

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

In [None]:
# Folder holding the OpenSlide binaries (placeholder value: edit before running).
OPENSLIDE_PATH = r'your/openslide/bin/repository'

# os.add_dll_directory only exists on Windows; there the OpenSlide DLL folder
# must be registered while the module is imported.
if hasattr(os, 'add_dll_directory'):
    # Windows
    with os.add_dll_directory(OPENSLIDE_PATH):
        import openslide
else:
    # Other platforms: a plain import is used.
    import openslide

In [None]:
def preprocess_folder(image, threshold: float = 0.1) -> bool:
    """
    Decide whether an image differs enough from a near-white background to be kept.

    Values are scaled by 1/255 and compared with a uniform reference of 250/255.
    The image is kept when the mean absolute difference, taken over every value
    of the array (all pixels and all channels), is greater than ``threshold``.

    Args:
        image: Anything ``np.array`` can convert, e.g. a NumPy array or a PIL image.
            Values are presumably in the 0-255 range (the scaling assumes it).
        threshold (float): Minimum mean absolute difference required to keep the image.

    Returns:
        bool: True when the image should be kept, False otherwise.
            An empty image is never kept.
    """
    pixels = np.array(image).astype(np.float32)
    if pixels.size == 0:
        # The mean of an empty array is NaN (with a RuntimeWarning): reject explicitly.
        return False

    pixels /= 255.
    # Same float32 reference value as a 250-filled image scaled by 1/255.
    background = np.float32(250.) / np.float32(255.)
    mean_deviation = np.mean(np.abs(pixels - background))

    # Cast so callers get a builtin bool, as the return annotation promises.
    return bool(mean_deviation > threshold)

## Create your Dataset with openslide.

Pyramidal images are a commonly used technique in histopathology to represent images at different resolutions. They are constructed by dividing a high-resolution image into a series of sub-images at decreasing resolutions, organized in a pyramid shape. Each level of the pyramid represents a version of the original image at a different resolution, ranging from very high resolution to very low resolution.

**Utility of Pyramidal Images in Histopathology**

Pyramidal images are extremely useful in histopathology for several reasons:

**Multi-scale Analysis:** Pathologists can examine fine details of tissues at high resolution while having an overview of the tissue at lower resolution. This allows for a comprehensive and detailed analysis of histological samples.

**Memory and Performance Optimization:** By storing images at different resolutions, computational resources can be optimized by loading only relevant parts of the image based on the required zoom level. This enables efficient handling of images even for large samples.

**Extraction of Image Patches:** Pyramidal images facilitate the extraction of image patches at different resolutions. This allows for localized analysis of specific regions of histological samples, which is crucial for many image processing tasks in histopathology.

**`tile_image` Function for Creating Image Patches**

The tile_image function allows for the creation of image patches from pyramidal images with openslide. By subdividing the pyramidal image into patches, this function facilitates data preparation for training machine learning models in histopathology. The resulting patches can be used to train deep neural networks to detect specific features in histological images, thereby contributing to the automation of histopathological analyses.


In [None]:
def tile_image(image_path: str, tile_size: int, output_folder: str, preprocess: bool = False) -> None:
    """
    Tile a whole-slide image into square patches and save each patch as a TIFF file.

    The slide is read at pyramid level 0. Tiles on the right and bottom borders
    are smaller than ``tile_size`` when the slide dimensions are not multiples of it.
    Each file is named ``tile_{x}_{y}.tiff`` where (x, y) is the top-left corner
    of the tile in level-0 pixel coordinates.

    Args:
        image_path (str): Path to the slide, opened with OpenSlide.
        tile_size (int): Width and height of a full tile, in pixels.
        output_folder (str): Directory where the tiles are written (created if missing).
        preprocess (bool, optional): When True, only tiles accepted by
            ``preprocess_folder`` (with its default threshold) are saved. Defaults to False.

    Returns:
        None

    Raises:
        ValueError: If ``tile_size`` is not strictly positive.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be a positive integer, got {tile_size}")

    # Saving a tile into a folder that does not exist would fail.
    os.makedirs(output_folder, exist_ok=True)

    # The context manager closes the slide handle even if reading or saving fails.
    with openslide.OpenSlide(image_path) as slide:
        width, height = slide.dimensions

        # Walk the top-left corners column by column; the bar tracks the columns.
        for x in tqdm(range(0, width, tile_size)):
            for y in range(0, height, tile_size):
                # Clip the tile so it never extends past the slide border.
                w = min(tile_size, width - x)
                h = min(tile_size, height - y)

                tile = slide.read_region((x, y), 0, (w, h))

                # In preprocess mode, skip tiles rejected by the background filter.
                if preprocess and not preprocess_folder(tile):
                    continue

                tile.save(os.path.join(output_folder, f"tile_{x}_{y}.tiff"))


## Create your dataset with small images.

Some images can be read directly as NumPy arrays, and these functions can help you split them into patches.

In [None]:
def patching(image: np.array, row: int, col: int, verbose: int = 1)->list:
    """
    Cut an image into a ``row`` x ``col`` grid of equally sized patches.

    Args:
        image(np.array): image to cut.
        row(int): number of equal parts along axis 0 (passed to ``np.vsplit``).
        col(int): number of equal parts along axis 1 (passed to ``np.hsplit``).
        verbose(int): when equal to 1 the patches are plotted in a grid;
            any other value skips the plot.
    Return:
        patch: the patches stacked with ``np.array``, ordered band by band
            from top to bottom and, inside a band, from left to right.
    """
    # Horizontal bands first, then each band is cut into its columns.
    patch = [
        tile
        for band in np.vsplit(image, row)
        for tile in np.hsplit(band, col)
    ]

    if verbose == 1:
        # Subplot positions are 1-based.
        for position, tile in enumerate(patch, start=1):
            plt.subplot(row, col, position)
            plt.axis('off')
            plt.imshow(tile, cmap = 'gray')
        plt.show()
        print(("\n\n"))

    return np.array(patch)

def resize_to_nearest_multiple(image: np.array, target_size=224)-> np.array:
    """
    Resize an image so that both sides become multiples of ``target_size``.

    Each side is rounded up to the next multiple of ``target_size`` (a side that
    is already a multiple is left unchanged).

    Args:
        image (numpy.ndarray): Input image with shape (height, width) or
            (height, width, channels).
        target_size (int, optional): Value both sides must be a multiple of. Defaults to 224.

    Returns:
        numpy.ndarray: Resized image.
    """
    # Array images are indexed (rows, cols[, channels]), i.e. (height, width[, channels]).
    # Slicing the first two entries also accepts 2-D (grayscale) images.
    height, width = image.shape[:2]

    # Round each side up to a multiple of target_size.
    new_width = int(np.ceil(width / target_size) * target_size)
    new_height = int(np.ceil(height / target_size) * target_size)

    # cv2.resize takes the destination size as (width, height).
    resized_img = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

    return resized_img

def patch_numpy_images(img_path: str, output_folder: str, target_size:int = 224, select_tiles: float = 0.4)->None:
    """
    Cut an image file into square tiles of ``target_size`` pixels and save them.

    Tiles are written to ``{output_folder}/{image name without extension}_{target_size}``.
    When that folder already exists nothing is done. Each kept tile is saved as
    ``tile_{x}_{y}.tiff`` where x is the horizontal and y the vertical pixel offset
    of the tile inside the resized image (same convention as ``tile_image``).

    Args:
        img_path(str): path to the image.
        output_folder(str): path to the folder where the tile folder is created.
        target_size(int): side of the output tiles, in pixels.
        select_tiles(float): threshold forwarded to ``preprocess_folder``; tiles
                    too close to the white slide background are not saved.
    Return:
        None.
    Raises:
        FileNotFoundError: if the image cannot be read by OpenCV.
    """
    output_folder = f"{output_folder}/{os.path.splitext(os.path.basename(img_path))[0]}_{target_size}"

    if os.path.exists(output_folder):
        print("The folder already exists")
        return

    print("Folder doesn't exist, tiling in progress")

    slide = cv2.imread(img_path)
    if slide is None:
        # cv2.imread reports an unreadable or missing file by returning None.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    slide = resize_to_nearest_multiple(slide, target_size=target_size)
    rows, cols = slide.shape[0] // target_size, slide.shape[1] // target_size
    # NOTE(review): patching() plots every patch by default (verbose=1);
    # confirm whether the plot is wanted for large images.
    patches = patching(slide, rows, cols)
    os.makedirs(output_folder, exist_ok=True)

    for index, tile in enumerate(tqdm(patches)):
        if not preprocess_folder(tile, select_tiles):
            continue

        # Patches are ordered row by row, so the grid position follows from the
        # index. Deriving it here keeps names unique and tied to the real position
        # even when some tiles are filtered out.
        row, col = divmod(index, cols)
        x, y = col * target_size, row * target_size

        tile_path = os.path.join(output_folder, f"tile_{x}_{y}.tiff")
        cv2.imwrite(tile_path, tile)

## Uncomment for a use case
# img_path = r'Path/to/your/image'
# patch_numpy_images(img_path, 'tiles', 128)

This function is used to plot or save images from torch.Tensor.

In [None]:
def show_tensor_images(image_tensor: torch.Tensor, num_images: int = 16, nrow: int = 4, show: bool = True, output_path: str = None)-> None:
    """
    Draw a batch of images as a grid, and optionally display and/or save it.

    The tensor is rescaled with ``(x + 1) / 2``, which maps values from [-1, 1]
    to [0, 1] (presumably the range produced by the Tanh generator and the
    dataset normalization; values outside it are not clamped here).

    Args:
        image_tensor (torch.Tensor): The input tensor containing images.
        num_images (int, optional): Number of images to display. Defaults to 16.
        nrow (int, optional): Number of images per row in the grid. Defaults to 4.
        show (bool, optional): Whether to call ``plt.show()``. Defaults to True.
        output_path (str, optional): When given, the grid is also written to this
            file with ``save_image``. Defaults to None (nothing is saved).

    Returns:
        None
    """
    # Detach first so the rescaling below is not recorded by autograd.
    image_unflat = ((image_tensor.detach() + 1) / 2).cpu()

    # Create a grid from the first num_images images
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)

    # matplotlib expects channels last: (C, H, W) -> (H, W, C)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.axis('off')  # Hide axes
    if show:
        plt.show()
    if output_path is not None:
        save_image(image_grid, output_path)

## Model:

In [None]:
class Deconv(nn.Module):
    def __init__(self,
                 in_channels: int = 100, out_channels: int = 64,
                 act: nn.Module = None,
                 kernel_size: int = 3,
                 stride: int = 2,
                 dropout: float = 0.4,
                 padding: int = 1,
                 final_layer: bool = False):
        """
        Deconvolutional layer for a generator block in a DCGAN architecture.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            act (nn.Module, optional): Activation applied at the end of the block.
                When None (the default) no activation is applied.
            kernel_size (int): Size of the convolutional filter.
            stride (int): Stride of the convolution.
            dropout (float): Dropout probability.
            padding (int): Padding size.
            final_layer (bool): True if it is the final layer, False otherwise.
        """
        super().__init__()
        self.conv_block = self.block(in_channels, out_channels, act, kernel_size, stride, dropout, padding, final_layer)

    def block(self, ic: int, oc: int,
            act: nn.Module,
            kernel_size: int,
            stride: int,
            dropout: float,
            padding: int,
            final_layer: bool):
        '''
        Build the sequence of operations of one generator block of a DCGAN:
        a transposed convolution, then (except in the final layer) a batchnorm
        and a dropout, then the activation when one is given.

        Args:
            ic: how many channels the input feature representation has
            oc: how many channels the output feature representation should have
            act: activation module appended last, or None for no activation
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            dropout: dropout probability (unused in the final layer)
            padding (int): Padding size.
            final_layer: a boolean, true if it is the final layer and false otherwise
                      (the final layer has no batchnorm and no dropout)

        Return:
            block(nn.Sequential): the operations of the block, in execution order.
        '''
        layers = [nn.ConvTranspose2d(ic, oc, kernel_size, stride, padding)]

        if not final_layer:
            layers += [nn.BatchNorm2d(oc), nn.Dropout(dropout)]

        # A None entry inside nn.Sequential cannot be called during forward,
        # so the activation is only appended when one was provided.
        if act is not None:
            layers.append(act)

        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        return self.conv_block(x)

In [None]:
class Generator(nn.Module):
    def __init__(self, zdim: int = 100, output_dim: int = 256, num_upsample: int = 3, img_ch: int = 3, activation: nn = nn.ReLU(inplace = True)):

        """
        Generator of the DCGAN: a linear projection of the noise vector followed
        by a stack of transposed-convolution blocks.

        Args:
            zdim (int): Size of the input noise vector.
            output_dim (int): Number of channels of the first 8x8 feature map;
                it is halved after every upsampling block.
            num_upsample (int): Number of upsampling blocks. The feature map starts
                at 8x8 and each block (kernel 4, stride 2, padding 1) doubles its side.
            img_ch (int): Number of channels of the generated image.
            activation (nn.Module): Activation used by every block except the
                last one, which always uses Tanh.
        """

        super().__init__()
        self.output_dim = output_dim
        self.dim = 8
        self.modules_dict = nn.ModuleDict()

        # Noise vector -> flattened (output_dim, dim, dim) feature map.
        self.modules_dict["First_layer"] = nn.Linear(zdim, output_dim * self.dim * self.dim)

        channels = output_dim
        last_index = num_upsample - 1
        for index in range(num_upsample):
            if index == last_index:
                # Last block: map to image channels, no batchnorm/dropout, Tanh output.
                self.modules_dict["last_layer"] = Deconv(channels, img_ch, kernel_size = 4, act = nn.Tanh(), stride = 2, final_layer = True)
            else:
                self.modules_dict[f"Deconv_{index}"] = Deconv(channels, channels // 2, act = activation, kernel_size = 4, stride = 2)
            channels //= 2

    def forward(self, x):
        """Turn a batch of noise vectors into a batch of images."""
        # Project the noise and reshape it into the initial feature map.
        x = self.modules_dict["First_layer"](x)
        x = x.view(-1, self.output_dim, self.dim, self.dim)

        # Then run the upsampling blocks in the order they were registered.
        for name, layer in self.modules_dict.items():
            if name != 'First_layer':
                x = layer(x)
        return x



In [None]:
class Discriminator(nn.Module):
    """Critic of the DCGAN: three strided convolution blocks applied in sequence."""

    def __init__(self, output_channels: int = 64, img_channels: int = 1):
        """
        Args:
            output_channels (int): Channels produced by the first block; the
                second block doubles that number and the last block outputs 1 channel.
            img_channels (int): Number of channels of the input images.
        """
        super().__init__()

        # (name, input channels, output channels, is final block), in execution order.
        stages = (
            ("Block_0", img_channels, output_channels, False),
            ("Block_1", output_channels, output_channels * 2, False),
            ("Block_2", output_channels * 2, 1, True),
        )

        self.modules_dict = nn.ModuleDict()
        for name, in_ch, out_ch, is_final in stages:
            self.modules_dict[name] = self.block(in_ch, out_ch, final_layer=is_final)

    def block(self,
            input_channels: int,
            output_channels: int,
            kernel_size: int = 4,
            stride: int = 2,
            final_layer: bool = False):
        '''
        Build one discriminator block of the DCGAN: a convolution followed,
        except in the final layer, by a batchnorm and a LeakyReLU(0.2).

        Args:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise
                      (the final layer is the bare convolution)
        Return:
            block(nn.Sequential): the operations of the block, in execution order.
        '''
        layers = [nn.Conv2d(input_channels, output_channels, kernel_size, stride)]
        if not final_layer:
            layers.append(nn.BatchNorm2d(output_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, image):
        '''
        Forward pass of the discriminator: run the image batch through every
        block, then flatten the result to one row of scores per input image.

        Parameters:
            image: a batch of image tensors
        '''
        scores = image
        for stage in self.modules_dict.values():
            scores = stage(scores)
        return scores.view(len(scores), -1)

In [None]:
from dataclasses import dataclass

@dataclass()
class GANConfigs:
    """
    Configuration class for a Generative Adversarial Network (GAN).

    Besides the fields below, ``__post_init__`` creates ``gen_optim`` and
    ``disc_optim`` (Adam optimizers) and ``gen_lr_schedular`` /
    ``disc_lr_schedular`` (StepLR schedulers, or None when schedulers are disabled).

    Attributes:
        generator (nn.Module): Generator model.
        discriminator (nn.Module): Discriminator model.
        image_size (int): Size of the input images.
        batch_size (int): Batch size.
        num_epochs (int): Total number of training epochs.
        current_epoch (int): Current epoch number.
        display_metrics (int): Interval for displaying metrics during training.
        z_dim (int): Dimension of the input noise vector.
        gen_lr (float): Learning rate for the generator.
        disc_lr (float): Learning rate for the discriminator.
        with_lr_scheduler (bool): Whether to use learning rate schedulers.
        gen_lr_step_size (int): Step size for the generator's learning rate scheduler.
        disc_lr_step_size (int): Step size for the discriminator's learning rate scheduler.
        gen_lr_gamma (float): Gamma value for the generator's learning rate scheduler.
        disc_lr_gamma (float): Gamma value for the discriminator's learning rate scheduler.
        output_dir (str): Directory to save outputs; created on construction when given.
        curr_epoch (int): Epoch counter. Marked as deprecated in favour of
            current_epoch, although the training code in this file still reads
            and increments this one.
        logs (dict): Dictionary to store training logs.
        device (str): Device to perform computations ('cpu' or 'cuda').
    """

    generator: nn.Module
    discriminator: nn.Module

    image_size: int = 128
    batch_size: int = 32

    num_epochs: int = 100
    current_epoch: int = 0

    display_metrics: int = 10

    z_dim: int = 100

    gen_lr: float = 1e-4
    disc_lr: float = 1e-4

    with_lr_scheduler: bool = True

    gen_lr_step_size: int = 10
    disc_lr_step_size: int = 10

    gen_lr_gamma: float = 0.1
    disc_lr_gamma: float = 0.1

    output_dir: str = None
    curr_epoch: int = 0
    logs: dict = None
    device: str = 'cpu'

    def __post_init__(self):

        """
        Create the optimizers, the optional learning rate schedulers and the
        output directory from the provided configuration.
        """

        self.gen_optim = torch.optim.Adam(self.generator.parameters(), lr = self.gen_lr)
        self.disc_optim = torch.optim.Adam(self.discriminator.parameters(), lr = self.disc_lr)

        # Always define the scheduler attributes so that code reading them gets
        # None instead of an AttributeError when schedulers are disabled.
        self.gen_lr_schedular = None
        self.disc_lr_schedular = None
        if self.with_lr_scheduler:
            self.gen_lr_schedular = torch.optim.lr_scheduler.StepLR(self.gen_optim, step_size=self.gen_lr_step_size, gamma=self.gen_lr_gamma)
            self.disc_lr_schedular = torch.optim.lr_scheduler.StepLR(self.disc_optim, step_size=self.disc_lr_step_size, gamma=self.disc_lr_gamma)

        if self.output_dir is not None:
            os.makedirs(self.output_dir, exist_ok= True)


In [None]:
class WLoss(GANConfigs):
    """
    Wasserstein loss with gradient penalty, built on top of the GANConfigs settings.

    The class adds the loss functions and the update steps for the critic
    (``self.discriminator``) and the generator, plus a helper that plots the
    current state of the training.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to ``GANConfigs``."""
        super().__init__(*args, **kwargs)

    def get_gradient(self, real: torch.Tensor, fake: torch.Tensor, epsilon: torch.Tensor) -> torch.Tensor:
        '''
        Return the gradient of the critic's scores with respect to mixes of real and fake images.

        Args:
            real (torch.Tensor): A batch of real images.
            fake (torch.Tensor): A batch of fake images.
            epsilon (torch.Tensor): A vector of the uniformly random proportions of real/fake per mixed image.

        Returns:
            torch.Tensor: The gradient of the critic's scores, with respect to the mixed image.
        '''
        # Mix the images together
        mixed_images = real * epsilon + fake * (1 - epsilon)

        # Calculate the critic's scores on the mixed images
        mixed_scores = self.discriminator(mixed_images)

        # Take the gradient of the scores with respect to the images.
        # create_graph keeps this gradient differentiable, so the penalty
        # computed from it can itself be backpropagated into the critic.
        gradient = torch.autograd.grad(
            inputs=mixed_images,
            outputs=mixed_scores,
            grad_outputs=torch.ones_like(mixed_scores),
            create_graph=True,
            retain_graph=True,
            )[0]

        return gradient

    def gradient_penalty(self, gradient: torch.Tensor) -> torch.Tensor:
        '''
        Return the gradient penalty, given a gradient.

        Given a batch of image gradients, this function calculates the magnitude of each image's gradient
        and penalizes the mean quadratic distance of each magnitude to 1.

        Args:
            gradient (torch.Tensor): The gradient of the critic's scores, with respect to the mixed image.

        Returns:
            torch.Tensor: The gradient penalty.
        '''
        # Flatten the gradients so that each row captures one image
        gradient = gradient.view(len(gradient), -1)

        # Calculate the magnitude (L2 norm) of every row
        gradient_norm = gradient.norm(2, dim=1)

        # Penalize the mean squared distance of the gradient norms from 1
        penalty = torch.mean(torch.square(gradient_norm - 1))
        return penalty

    def w_loss_gen(self, crit_fake_pred: torch.Tensor) -> torch.Tensor:
        '''
        Return the loss of a generator given the critic's scores of the generator's fake images.

        Args:
            crit_fake_pred: the critic's scores of the fake images
        Returns:
            gen_loss: a scalar loss value for the current batch of the generator
        '''
        # The generator wants the critic to score its images as high as possible.
        gen_loss = -torch.mean(crit_fake_pred)
        return gen_loss

    def w_loss_crit(self, crit_fake_pred: torch.Tensor, crit_real_pred: torch.Tensor, gp: torch.Tensor, c_lambda: int)->torch.Tensor:
        '''
        Return the loss of a critic given the critic's scores for fake and real images,
        the gradient penalty, and gradient penalty weight.

        Parameters:
            crit_fake_pred(torch.Tensor): the critic's scores of the fake images
            crit_real_pred(torch.Tensor): the critic's scores of the real images
            gp(torch.Tensor): the unweighted gradient penalty
            c_lambda(int): the current weight of the gradient penalty
        Returns:
            crit_loss(torch.Tensor): a scalar for the critic's loss, accounting for the relevant factors
        '''
        wasserstein_estimate = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred)

        # Add the weighted gradient penalty term
        crit_loss = wasserstein_estimate + c_lambda * gp

        return crit_loss

    def update_critic(self,
                    real: torch.Tensor = None,
                    critic_repeats: int = 5,
                    c_lambda: int = 10,
                    ) -> list:
        """
        Update the critic (discriminator) network for a specified number of iterations.

        Args:
            real (torch.Tensor): The batch of real images for the critic.
            critic_repeats (int, optional): The number of iterations to update the critic. Defaults to 5.
            c_lambda (int, optional): The coefficient for the gradient penalty term. Defaults to 10.

        Returns:
            list: A one-element list holding the critic loss averaged over the iterations.
        """
        n_samples = len(real)
        mean_iteration_critic_loss = 0
        for _ in range(critic_repeats):

            self.disc_optim.zero_grad()
            fake_noise = torch.randn(n_samples, self.z_dim, device = self.device)
            # The critic step never backpropagates into the generator, so the
            # fakes are produced without recording an autograd graph.
            with torch.no_grad():
                fake = self.generator(fake_noise)

            crit_fake_pred = self.discriminator(fake)
            crit_real_pred = self.discriminator(real)

            epsilon = torch.rand(n_samples, 1, 1, 1, device=self.device, requires_grad=True)
            gradient = self.get_gradient(real, fake, epsilon)
            gp = self.gradient_penalty(gradient)
            crit_loss = self.w_loss_crit(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / critic_repeats
            # The graph is rebuilt at every iteration, so it does not need to
            # be retained after the backward pass.
            crit_loss.backward()
            # Update optimizer
            self.disc_optim.step()

        return [mean_iteration_critic_loss]

    def update_w_gen(self, n_samples: int = None) -> list:
        '''
        Update the generator network parameters based on the current critic's scores of fake images.

        Args:
            n_samples (int): Number of samples to generate and update with.

        Returns:
            list: A one-element list holding the generator loss for the current batch.
        '''
        self.gen_optim.zero_grad()
        fake_noise = torch.randn(n_samples, self.z_dim, device=self.device)
        fake = self.generator(fake_noise)
        crit_fake_pred = self.discriminator(fake)

        gen_loss = self.w_loss_gen(crit_fake_pred)
        gen_loss.backward()

        # Update the weights
        self.gen_optim.step()

        return [gen_loss.item()]

    def show_metrics(self, real: torch.Tensor):
        '''
        Show training metrics such as generator and discriminator losses and visualize generated images.

        Eight generated images are displayed (and saved as ``epoch_{curr_epoch}.jpg``
        in ``output_dir`` when one is configured), followed by eight real images
        and the per-epoch loss curves stored in ``self.logs``.

        Args:
            real (torch.Tensor): A batch of real images.
        '''

        plt.figure(figsize=(10, 6))

        # Preview images are only displayed: no autograd graph is needed.
        with torch.no_grad():
            fake_noise = torch.randn(8, self.z_dim, device=self.device)
            fake = self.generator(fake_noise)

        # Without an output directory the preview is shown but not saved.
        output_image = None
        if self.output_dir is not None:
            image_path = f'epoch_{self.curr_epoch}.jpg'
            output_image = os.path.join(self.output_dir, image_path)

        show_tensor_images(fake, num_images = 8, output_path = output_image)
        show_tensor_images(real, num_images = 8)

        plt.plot(self.logs["gen_loss_on_epoch"], label="Generator Loss", color="blue")
        plt.plot(self.logs["disc_loss_on_epoch"], label="Discriminator Loss", color="red")
        plt.xlabel(f"epoch: {self.curr_epoch}")
        plt.ylabel("Loss")
        plt.title(f"Losses")
        plt.legend()
        plt.show()

In [None]:
def train_WGAN(configs: "WLoss" = None, dataloader: "DataLoader" = None, c_lambda = 10)->None:
    """
    Train a WGAN with gradient penalty.

    For every batch the critic is updated first (``configs.update_critic``) and
    then the generator (``configs.update_w_gen``). At the end of each epoch the
    loss of the last batch is appended to the per-epoch curves stored in
    ``configs.logs``. Every ``configs.display_metrics`` epochs, and on the last
    epoch, the metrics are displayed and both models are saved to
    ``gen_model.pth`` / ``disc_model.pth`` in the current working directory.

    Args:
        configs: Object holding the models, optimizers, schedulers and settings,
            and providing ``update_critic``, ``update_w_gen`` and ``show_metrics``.
        dataloader: Iterable yielding ``(images, labels)`` batches; labels are ignored.
        c_lambda: Weight of the gradient penalty passed to the critic update.

    Raises:
        ValueError: If the dataloader yields no batch.
    """
    if len(dataloader) == 0:
        # Without any batch there is no loss to log and no image to display.
        raise ValueError("dataloader yields no batches")

    generator_losses = []
    discriminator_losses = []

    gen_loss_on_epoch = []
    disc_loss_on_epoch = []
    for epoch in range(configs.num_epochs):

        # The context manager closes the progress bar at the end of the epoch.
        with tqdm(total = len(dataloader)) as pbar:
            pbar.set_description(f"Epoch: {epoch}")

            # Dataloader returns the batches of real images
            for real, _ in dataloader:

                cur_batch_size = len(real)
                real = real.to(configs.device)

                # Each update returns a one-element list with the batch loss.
                discriminator_losses += configs.update_critic(real, c_lambda = c_lambda)
                generator_losses += configs.update_w_gen(n_samples = cur_batch_size)

                # Learning rates are read from the optimizers so that logging
                # also works when the schedulers are disabled.
                configs.logs = {
                        "gen_loss": generator_losses[-1],
                        "disc_loss": discriminator_losses[-1],
                        "gen_lr": configs.gen_optim.param_groups[0]["lr"],
                        "disc_lr": configs.disc_optim.param_groups[0]["lr"],
                        }
                pbar.update(1)
                pbar.set_postfix(**configs.logs)

        # The loss of the last batch represents the epoch in the curves.
        gen_loss_on_epoch.append(generator_losses[-1])
        disc_loss_on_epoch.append(discriminator_losses[-1])

        configs.logs["gen_loss_on_epoch"] = gen_loss_on_epoch
        configs.logs["disc_loss_on_epoch"] = disc_loss_on_epoch

        if (epoch % configs.display_metrics == 0) or (epoch == configs.num_epochs-1):
            configs.show_metrics(real)

            # Save models each display_metrics
            torch.save(configs.generator.state_dict(), "gen_model.pth")
            torch.save(configs.discriminator.state_dict(), "disc_model.pth")

        if configs.with_lr_scheduler:

            configs.gen_lr_schedular.step()
            configs.disc_lr_schedular.step()

        configs.curr_epoch += 1

In [None]:
# Image shape; index 0 is used below as the channel count and index 1 as the image size.
img_shape = (3, 64, 64)
# Size of the noise vector fed to the generator.
z_dim = 200

# Build both networks and move them to the selected device.
gen = Generator(z_dim, output_dim = 128, img_ch = img_shape[0], num_upsample = 3).to(device)
disc = Discriminator(output_channels= 128, img_channels = img_shape[0]).to(device)

# Training configuration: both schedulers halve the learning rate every 15 epochs,
# metrics are displayed after every epoch and generated images go to "gen_images".
configs = WLoss(gen, disc, z_dim= z_dim,
                gen_lr_gamma = 0.5,
                gen_lr_step_size = 15,

                disc_lr_gamma = 0.5,
                disc_lr_step_size=15,
                image_size = img_shape[1], 
                batch_size = 32, 
                num_epochs = 100, 
                display_metrics = 1, 
                device = device, 
                output_dir = "gen_images"
                )


In this cell, we divide our dataset into training and test sets.

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler

# Root folder of the image dataset (placeholder value: edit before running).
folder_path = r"your/dataset/path"
# Resize to the configured image size, convert to a tensor, then normalize each
# of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((configs.image_size, configs.image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

dataset = ImageFolder(folder_path, transform=transform)

# Compute the size of the desired subset (10% of the dataset)
subset_size = int(0.1 * len(dataset))

# Generate a random permutation of the dataset indices
indices = torch.randperm(len(dataset))

# Split the indices: the first subset_size go to train_set, the rest to test_set
train_set = indices[:subset_size]
test_set = indices[subset_size:]
# Create a sampler from the subset indices
# NOTE(review): both samplers are built from train_set, so test_set is never
# used and both loaders draw from the same 10% of the data. Confirm whether
# the second sampler should use test_set.
train_subset_sampler = SubsetRandomSampler(train_set)
test_subset_sampler = SubsetRandomSampler(train_set)

train_loader = DataLoader(dataset, batch_size= configs.batch_size, sampler = train_subset_sampler)
test_loader = DataLoader(dataset, batch_size= configs.batch_size, sampler = test_subset_sampler)


We will only pass the test set through the training because the training set is too big.

In [None]:
# Train on the batches served by test_loader, with a gradient-penalty weight of 10.
train_WGAN(configs, test_loader, c_lambda = 10)

In [None]:
# Save the weights of both networks to the current working directory.
torch.save(configs.generator.state_dict(), "gen_model.pth")
torch.save(configs.discriminator.state_dict(), "disc_model.pth")

In [None]:
import imageio

def create_gif_from_images(folder_path: str, output_gif_path: str, duration: float = 100.)->None:
    """
    Create a GIF from the ``.jpg`` images of a folder and save it to the specified output path.

    Frames are ordered by file name with numbers compared numerically, so
    ``epoch_2.jpg`` comes before ``epoch_10.jpg``.

    Args:
        folder_path (str): Path to the folder containing the images.
        output_gif_path (str): Path to save the output GIF.
        duration (float): Frame duration forwarded unchanged to ``imageio.mimsave``
            (default is 100.). NOTE(review): the unit presumably depends on the
            installed imageio version; confirm against its documentation.
    """
    import re

    def natural_key(name: str) -> list:
        # Split out the digit runs so that they are compared as integers.
        return [int(part) if part.isdigit() else part for part in re.split(r'(\d+)', name)]

    # os.listdir returns names in arbitrary order: sort them explicitly.
    frame_names = sorted(
        (file for file in os.listdir(folder_path) if file.endswith('.jpg')),
        key=natural_key,
    )

    # Load the frames in playback order
    images = [imageio.imread(os.path.join(folder_path, name)) for name in frame_names]

    # Write the GIF
    imageio.mimsave(output_gif_path, images, duration=duration)

# Assemble the images found in configs.output_dir into a single GIF.
create_gif_from_images(configs.output_dir, 'res_gif_64.gif')