## What Does This File Contain?

This file contains the core and helper functions for supervised learning and Generative Adversarial Networks. In more detail, it contains:
* Data Preparation:
    1. `download_dataset_zip(source, destination, if_exists_stop, delete_zip)`
    2. `download_image(source, destination, if_exists_stop)`
    3. `plot_before_and_after(image_path_list, transforms)`
    4. `split_dataset(dataset, split_prop, seed)`
    5. `show_images(images, denorm_fn, n_rows, fig_size)`

* Evaluating:
    1. `plot_loss_curve(model_res, fig_size, font_size)`
    2. `plot_acc_curve(model_res, fig_size, font_size)`
    3. `plot_image_predictions(model, class_names, image_path, transforms)`
    4. `create_writer(experiment_name, model_name, extras)`
    5. `accuracy_fn(model_logits, labels)`

* Training:
    * a) Supervised Learning:
        1. `training_step(model, train_dl, loss_fn, eval_fn, opt)`
        2. `validation_step(model, valid_dl, loss_fn, eval_fn)`
        3. `fit(model, epochs, train_dl, valid_dl, loss_fn, eval_fn, opt, writer)`

    * b) GAN:
        1. `train_discriminator(discriminator, generator, real_image_batch, loss_fn, opt_dis, opt_gen, latent_size)`
        2. `train_generator(discriminator, generator, batch_size, latent_size, loss_fn, opt_dis, opt_gen)`
        3. `fit(discriminator, generator, epochs, train_dl, latent_size, loss_fn, opt_dis, opt_gen, save_gen_images_fn, generated_image_path)`

* Utilities:
    1. `save_model(model, saved_model_path, if_exists_stop)`
    2. `save_gen_images(generator, path, denorm, batch_size, latent_size, index, images_name, nrows, print_info)`
    3. `plot_decision_boundary(model, X, y)`

### Data Preparation: 1

In [87]:
def download_dataset_zip(source: str, destination: str, if_exists_stop=False, delete_zip=True):
    """Download a `zip` archive and extract it into its own dataset directory.

    The dataset directory is `destination/<name>`, where `<name>` is the last
    path segment of `source` cut at its first dot (e.g. `.../pizza.zip` ->
    `destination/pizza`).

    Args:
        source: URL of the `zip` archive.
        destination: Parent directory in which the dataset directory is created.
        if_exists_stop: When the dataset directory already exists, return
            immediately instead of replacing it.
        delete_zip: Remove the downloaded `zip` file after extraction.

    Returns:
        The `Path` of the dataset directory, or `None` when the directory
        already existed and `if_exists_stop` is true.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        zipfile.BadZipFile: If the downloaded file is not a valid archive.
    """
    from pathlib import Path   # For creating the dataset directory
    from shutil import rmtree  # For deleting the dataset directory if it exists
    import requests            # For getting the dataset from the web
    import zipfile             # For extracting the dataset `zip` file

    # Deriving the dataset directory and the location of the `zip` file
    dataset_name_ext = source.split('/')[-1]
    dataset_path = Path(destination) / dataset_name_ext.split('.')[0]
    zip_path = dataset_path / dataset_name_ext

    already_exists = dataset_path.is_dir()
    if already_exists:
        print(f"[INFO] `{dataset_path}` already exists...")

        if if_exists_stop:
            return

    # Downloading BEFORE touching the disk: a failed request must neither
    # destroy an existing dataset nor leave an error page saved as a `zip`.
    print(f"[INFO] Downloading dataset: {source} to `{dataset_path}`...")
    req = requests.get(source)
    req.raise_for_status()

    if already_exists:
        print(f"[INFO] Deleting `{dataset_path}`..")
        rmtree(dataset_path)

    # Creating the dataset directory
    print(f"[INFO] Creating `{dataset_path}`...")
    dataset_path.mkdir(parents=True, exist_ok=True)

    # Storing the archive inside the dataset directory
    zip_path.write_bytes(req.content)

    # Extracting the content of the `zip` file to the dataset directory
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        print(f"[INFO] Unzipping dataset `{zip_path}` to `{dataset_path}`...")
        zip_ref.extractall(dataset_path)

    # Deleting the `zip` file
    if delete_zip:
        print(f"[INFO] Deleting `{zip_path}`...")
        zip_path.unlink()

    print(f"[INFO] Dataset succesfully downloaded to `{dataset_path}`..")

    return dataset_path

### Data Preparation: 2

In [92]:
def download_image(source, destination, if_exists_stop=False):
    """Download a single image from the web and store it at `destination`.

    Args:
        source: URL of the image.
        destination: Target file path (string or `Path`); missing parent
            directories are created.
        if_exists_stop: When the target file already exists, return
            immediately instead of overwriting it.

    Returns:
        None.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    from pathlib import Path  # For handling the target path and its directory
    import requests           # For downloading the image from the web

    # `Path` splits directory and file name correctly, including for files
    # placed directly under the root directory.
    image_file = Path(destination)
    image_path, image_name = image_file.parent, image_file.name

    already_exists = image_file.is_file()
    if already_exists:
        print(f"[INFO] Image `{destination}` already exists...")

        if if_exists_stop:
            return

    # Downloading before touching the disk so a failed request neither
    # removes the existing image nor leaves an empty file behind.
    print(f"[INFO] Downloading {source} to `{destination}`...")
    req = requests.get(source)
    req.raise_for_status()

    if already_exists:
        print(f"[INFO] Deleting `{destination}`...")
        image_file.unlink()

    # Creating the image directory if it does not exist
    image_path.mkdir(parents=True, exist_ok=True)

    image_file.write_bytes(req.content)

    print(f"[INFO] `{image_name}` succesfully downloaded to `{destination}`...")

### Data Preparation: 3

In [62]:
def plot_before_and_after(image_path_list, transforms):
    """Plot each image next to its transformed version, one image per row.

    Args:
        image_path_list: Sequence of image paths opened with `PIL.Image.open`.
        transforms: Callable applied to the opened image. Its result must
            support `.permute(1, 2, 0)` and `.size(dim=...)` (presumably a
            channel-first tensor -- confirm against the transform pipeline).

    Note:
        The "Original Shape" title prints `img.size[0]`, `img.size[1]` and a
        hard-coded `3` as the third value.
    """
    from PIL import Image           # For opening the images and passing them to `pyplot`
    import matplotlib.pyplot as plt # For displaying the images

    n_images = len(image_path_list)

    # One row of height 3 per image, two columns (original | transformed)
    plt.figure(figsize=(5, 3 * n_images))

    for row, image_path in enumerate(image_path_list):
        original_cell = 2 * row + 1          # subplot indices are 1-based
        transformed_cell = original_cell + 1

        with Image.open(image_path) as img:
            # Left column: the untouched image
            plt.subplot(n_images, 2, original_cell)
            plt.imshow(img)
            plt.title(f"Original Shape:\n({img.size[0]}, {img.size[1]}, 3)")
            plt.axis(False)

            # Right column: the image after `transforms`
            transformed_img = transforms(img)
            plt.subplot(n_images, 2, transformed_cell)
            plt.imshow(transformed_img.permute(1, 2, 0))
            plt.title(f"Transformed Shape:\n({transformed_img.size(dim=0)}, {transformed_img.size(dim=1)}, {transformed_img.size(dim=2)})")
            plt.axis(False)

            # Side padding for a better display
            plt.subplots_adjust(left=0.2, right=0.9)

### Data Preparation: 4



In [67]:
def split_dataset(dataset, split_prop=0.2, seed=9):
    """Randomly split `dataset` into two subsets with a reproducible seed.

    Args:
        dataset: Any object accepted by `torch.utils.data.random_split`
            that supports `len()`.
        split_prop: Fraction (between 0 and 1) of the samples that goes into
            the first subset; the length is truncated to an integer.
        seed: Seed of the random generator used for the split.

    Returns:
        `(random_split_1, random_split_2)`: the subset holding
        `int(len(dataset) * split_prop)` samples and the subset with the rest.

    Raises:
        ValueError: If `split_prop` is outside `[0, 1]`.
    """
    from torch.utils.data import random_split # For splitting the dataset
    from torch import Generator               # For a seeded, local random generator

    if not 0 <= split_prop <= 1:
        raise ValueError(f"`split_prop` must be between 0 and 1, got {split_prop}")

    total_length = len(dataset)
    length_1 = int(total_length * split_prop) # desired length
    length_2 = total_length - length_1        # remaining length

    # `round` instead of `int`: truncation would report e.g. 0.29 as 28%
    print(
        f"[INFO] Splitting dataset of length {total_length} into splits of size: "
        f"{length_1} ({round(split_prop*100)}%), {length_2} ({round((1-split_prop)*100)}%)"
        )

    # A dedicated generator keeps the split reproducible without reseeding
    # the process-wide RNG (which `torch.manual_seed` would do).
    split_rng = Generator().manual_seed(seed)
    random_split_1, random_split_2 = random_split(dataset, [length_1, length_2], generator=split_rng)

    return random_split_1, random_split_2

### Data Preparation: 5

In [97]:
def show_images(images, denorm_fn, n_rows=3, fig_size=3):
    """Display the first `n_rows**2` images as a square grid.

    Args:
        images: Sliceable batch of images (presumably a tensor of shape
            (N, C, H, W) -- confirm against the caller).
        denorm_fn: Callable applied to the selected images before the grid
            is built; its result must support `.cpu()`.
        n_rows: Number of images per grid row (passed as `nrow`).
        fig_size: Width and height of the square figure.
    """
    from torchvision.utils import make_grid # For making the grid of images
    import matplotlib.pyplot as plt         # For plotting the grid

    _, axis = plt.subplots(figsize=(fig_size, fig_size))

    # Hiding the ticks: only the picture itself is of interest
    axis.set_xticks([])
    axis.set_yticks([])

    selected = images[:n_rows**2]
    grid = make_grid(denorm_fn(selected).cpu(), nrow=n_rows)

    # Reordering the grid's axes with permute(1, 2, 0) before handing it to `imshow`
    axis.imshow(grid.permute(1, 2, 0))

### Evaluating: 1

In [69]:
def plot_loss_curve(model_res, fig_size=(6, 4), font_size=11):
    """Plot the validation and training loss per epoch.

    Args:
        model_res: Dict providing `"model_epochs"`, `"model_valid_loss"`,
            `"model_train_loss"` and `"model_name"`.
        fig_size: Size of the figure passed to `plt.figure`.
        font_size: Font size of the legend.
    """
    import matplotlib.pyplot as plt # For displaying the curves

    plt.figure(figsize=fig_size)

    # (results key, line colour, legend label) -- validation is drawn first
    curves = (
        ("model_valid_loss", 'g', "Valid Loss"),
        ("model_train_loss", 'b', "Train Loss"),
    )
    for key, colour, label in curves:
        plt.plot(range(model_res["model_epochs"]), model_res[key], c=colour, label=label)

    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title(f"Loss Curves: {model_res['model_name']}", fontsize=14)
    plt.legend(fontsize=font_size)

### Evaluating: 2

In [70]:
def plot_acc_curve(model_res, fig_size=(6, 4), font_size=11):
    """Plot the validation and training accuracy per epoch.

    Args:
        model_res: Dict providing `"model_epochs"`, `"model_name"` and the
            per-epoch metric lists. The lists are read from
            `"model_valid_acc"` / `"model_train_acc"` when present, otherwise
            from `"model_valid_eval"` / `"model_train_eval"`, which are the
            keys the supervised `fit` in this file returns.
        fig_size: Size of the figure passed to `plt.figure`.
        font_size: Font size of the legend.

    Raises:
        KeyError: If neither key variant is present for a split.
    """
    import matplotlib.pyplot as plt # For displaying the curves

    def _metric_series(split):
        # Accept both naming schemes so dicts produced by `fit` work directly.
        candidates = (f"model_{split}_acc", f"model_{split}_eval")
        for key in candidates:
            if key in model_res:
                return model_res[key]
        raise KeyError(f"None of {candidates} found in `model_res`")

    plt.figure(figsize=fig_size)

    plt.plot(range(model_res["model_epochs"]), _metric_series("valid"), c='g', label="Valid Accuracy (%)")
    plt.plot(range(model_res["model_epochs"]), _metric_series("train"), c='b', label="Train Accuracy (%)")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy (%)")
    plt.title(f"Accuracy Curves: {model_res['model_name']} (%)", fontsize=14)
    plt.legend(fontsize=font_size)

### Evaluating: 3


In [94]:
def plot_image_predictions(model, class_names, image_path, transforms):
    """Predict the class of one image and plot it with label and probability.

    Args:
        model: Network returning one row of class logits for a batch of one
            image; it is switched to eval mode by this function.
        class_names: Sequence mapping a class index to its name.
        image_path: Path of the image, opened with `PIL.Image.open`.
        transforms: Callable turning the opened image into a tensor
            (presumably channel-first (C, H, W) -- it is permuted with
            (1, 2, 0) for plotting; confirm against the transform pipeline).
    """
    import torch
    import matplotlib.pyplot as plt
    from PIL import Image

    # The input is matched to the model's dtype and device
    reference_param = next(model.parameters())
    model_dtype, model_device = reference_param.dtype, reference_param.device

    # Converting the image to Tensor
    with Image.open(image_path, "r") as img:
        img_tensor = transforms(img)

    # Making Predictions
    model.eval()
    with torch.inference_mode():
        batch = img_tensor.unsqueeze(dim=0).to(device=model_device, dtype=model_dtype)
        logits = model(batch)

    # Softmax is computed once, in float32, and reused for label and probability
    probabilities = torch.softmax(logits.float(), dim=1)
    pred_index = probabilities.argmax(dim=1).item()  # plain int for indexing
    pred_label = class_names[pred_index]
    label_prob = probabilities.max().item() * 100

    plt.imshow(img_tensor.permute(1, 2, 0).cpu())
    plt.title(f"Label: {pred_label} | Probability: {label_prob: .2f}%")
    plt.axis(False)

### Evaluating: 4

In [96]:
def create_writer(experiment_name, model_name, extras=None):
    """Create a TensorBoard `SummaryWriter` with a date-based log directory.

    The log directory is `runs/<YYYY-MM-DD>/<experiment_name>/<model_name>`,
    with `extras` appended as a final component when it is truthy.

    Args:
        experiment_name: Name of the experiment (directory component).
        model_name: Name of the model (directory component).
        extras: Optional extra directory component.

    Returns:
        A `SummaryWriter` writing to the directory described above.
    """
    from torch.utils.tensorboard import SummaryWriter
    from datetime import datetime
    from os.path import join

    # Only the date is used, so runs of the same day share a parent directory
    today = datetime.now().strftime("%Y-%m-%d")

    # Assembling the `log_dir` components
    components = ["runs", today, experiment_name, model_name]
    if extras:
        components.append(extras)
    log_dir = join(*components)

    print(f"[INFO] SummaryWriter created, saving to: {log_dir}...")

    return SummaryWriter(log_dir=log_dir)

### Evaluating: 5

In [101]:
def accuracy_fn(model_logits, labels):
    """Return the classification accuracy of a batch as a percentage.

    Args:
        model_logits: Tensor of per-class scores; the class dimension is 1.
        labels: Tensor of true class indices, compared element-wise with the
            predicted indices.

    Returns:
        Percentage (0-100) of predictions equal to `labels`.
    """
    import torch

    # Logits are cast to float32 before softmax; the predicted class is the
    # index of the highest probability along dimension 1.
    probabilities = torch.softmax(model_logits.type(torch.float32), dim=1)
    predicted_labels = probabilities.argmax(dim=1)

    n_correct = torch.sum(predicted_labels == labels).item()
    return (n_correct / len(labels)) * 100

### Training: a1

In [98]:
def training_step(model, train_dl, loss_fn, eval_fn, opt):
    """Train `model` for one pass over `train_dl`.

    Args:
        model: Network to train; it is switched to train mode and its
            parameters are updated in place.
        train_dl: Iterable of `(inputs, targets)` batches supporting `len()`.
        loss_fn: Callable `(logits, targets)` returning a scalar loss tensor.
        eval_fn: Callable `(logits, targets)` returning a number (metric).
        opt: Optimizer holding the model's parameters.

    Returns:
        `(train_loss, train_eval)`: the per-batch loss and metric, each
        averaged over the number of batches (`len(train_dl)`).
    """
    from tqdm import tqdm # For the progress bar

    # Batches are moved to wherever the model's parameters live
    model_device = next(model.parameters()).device

    # Initialize training loss and accuracy
    train_loss, train_eval = 0, 0

    print("\tTraining Step: ", end="")

    model.train()
    for x_train, y_train in tqdm(train_dl):
        # Moving batches to device
        x_train, y_train = x_train.to(model_device, non_blocking=True), y_train.to(model_device, non_blocking=True)

        # Generating predictions
        model_logits = model(x_train)

        # Calculate loss and metric for this batch
        loss = loss_fn(model_logits, y_train)
        train_loss += loss.item()
        train_eval += eval_fn(model_logits, y_train)

        # Updating Model's parameters
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Averaging over the number of batches
    train_loss /= len(train_dl)
    train_eval /= len(train_dl)

    return train_loss, train_eval

### Training: a2

In [99]:
def validation_step(model, valid_dl, loss_fn, eval_fn):
    """Evaluate `model` for one pass over `valid_dl` without updating it.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        valid_dl: Iterable of `(inputs, targets)` batches supporting `len()`.
        loss_fn: Callable `(logits, targets)` returning a scalar loss tensor.
        eval_fn: Callable `(logits, targets)` returning a number (metric).

    Returns:
        `(valid_loss, valid_eval)`: the per-batch loss and metric, each
        averaged over the number of batches (`len(valid_dl)`).
    """
    import torch
    from tqdm import tqdm

    # Batches are moved to wherever the model's parameters live
    model_device = next(model.parameters()).device

    # Initialize validation loss and accuracy
    valid_loss, valid_eval = 0, 0

    print("\tValidation Step: ", end="")

    model.eval()
    with torch.inference_mode():
        for x_valid, y_valid in tqdm(valid_dl):
            # Moving batches to model's device
            x_valid, y_valid = x_valid.to(model_device, non_blocking=True), y_valid.to(model_device, non_blocking=True)

            # Generate Predictions
            model_logits = model(x_valid)

            valid_loss += loss_fn(model_logits, y_valid).item()
            valid_eval += eval_fn(model_logits, y_valid)

    # Averaging over the number of batches
    valid_loss /= len(valid_dl)
    valid_eval /= len(valid_dl)

    return valid_loss, valid_eval

### Training: a3

In [100]:
def fit(model, epochs, train_dl, valid_dl, loss_fn, eval_fn, opt, writer=None):
    """Train and validate `model` for `epochs` epochs (supervised learning).

    Each epoch runs `training_step` followed by `validation_step`, prints the
    results and, when `writer` is given, logs them. After the last epoch the
    model graph is logged once and the writer is closed.

    NOTE: this file later defines a second function named `fit` (for GANs).
    When both are defined in the same namespace, the later one replaces
    this one.

    Args:
        model: Network to train.
        epochs: Number of epochs.
        train_dl: Training batches; also used to obtain an example input
            shape for the graph when `writer` is given.
        valid_dl: Validation batches.
        loss_fn: Loss callable passed to both steps.
        eval_fn: Metric callable passed to both steps.
        opt: Optimizer passed to `training_step`.
        writer: Optional `SummaryWriter`-like object providing
            `add_scalars`, `add_graph`, `flush` and `close`.

    Returns:
        Dict with the per-epoch lists (`model_train_loss`, `model_train_eval`,
        `model_valid_loss`, `model_valid_eval`) and run metadata
        (`model_name`, `model_loss_fn`, `model_evaluating_m`,
        `model_optimizer`, `model_device`, `model_epochs`, `model_time`).
    """
    from timeit import default_timer as timer
    import torch


    # Starting the `timer` and initialize the evaluating Lists
    start_time = timer()
    train_losses, train_evals = [], []
    valid_losses, valid_evals = [], []

    print("Starting Process...\n")

    for epoch in range(1, epochs + 1):
        print(f"-> Epoch: {epoch}/{epochs}")

        # Training and Evaluating the Model
        train_loss, train_eval = training_step(model, train_dl, loss_fn, eval_fn, opt)
        valid_loss, valid_eval = validation_step(model, valid_dl, loss_fn, eval_fn)

        if (epoch == epochs):
            print()
        print()
        print(
            f"   Train Loss: {train_loss:.4f} | "
            f"Train Accuracy: {train_eval:.2f}% | "
            f"Valid Loss: {valid_loss:.4f} | "
            f"Valid Accuracy (%): {valid_eval:.2f}%")
        print("-" * 99, end="\n\n")

        train_losses.append(train_loss)
        train_evals.append(train_eval)
        valid_losses.append(valid_loss)
        valid_evals.append(valid_eval)

        # Tracking the experiment
        if writer:
            # Logging the Loss
            writer.add_scalars(main_tag="Loss",
                              tag_scalar_dict={"train_loss": train_loss, "valid_loss": valid_loss},
                              global_step=epoch)
            # Logging the Evaluation Metric
            writer.add_scalars(main_tag="Evaluation",
                              tag_scalar_dict={"train_eval": train_eval, "valid_eval": valid_eval},
                              global_step=epoch)
            # Making this epoch's values visible without closing the writer
            writer.flush()

    if writer:
        # The architecture does not change between epochs, so the graph is
        # logged once. The example input copies the shape of a real batch,
        # whatever its rank.
        sample_shape = tuple(next(iter(train_dl))[0].shape)
        model_device = next(model.parameters()).device
        writer.add_graph(model=model,
                         input_to_model=torch.randn(size=sample_shape).to(model_device))
        # Closing the `writer` object once all logging is done
        writer.close()

    print("Process Completed Successfully...")

    return {"model_train_loss": train_losses,
        "model_train_eval": train_evals,
        "model_valid_loss": valid_losses,
        "model_valid_eval": valid_evals,
        "model_name": model.__class__.__name__,
        "model_loss_fn": loss_fn.__class__.__name__,
        # Callable objects (e.g. `functools.partial`) have no `__name__`
        "model_evaluating_m": getattr(eval_fn, "__name__", eval_fn.__class__.__name__),
        "model_optimizer": opt.__class__.__name__,
        "model_device": next(model.parameters()).device.type,
        "model_epochs": epochs,
        "model_time": timer() - start_time}

### Training: b1

In [None]:
def train_discriminator(discriminator, generator, real_image_batch, loss_fn, opt_dis, opt_gen, latent_size):
    """Run one discriminator update on a real batch and a generated batch.

    Args:
        discriminator: Network scoring images; its output is compared with
            labels of shape `(batch_size, 1)`.
        generator: Network producing images from a latent tensor of shape
            `(batch_size, latent_size, 1, 1)`.
        real_image_batch: Batch of real images; its first dimension is the
            batch size.
        loss_fn: Callable `(predictions, labels)` returning a scalar loss.
        opt_dis: Optimizer of the discriminator (stepped here).
        opt_gen: Optimizer of the generator (only its gradients are cleared).
        latent_size: Size of the latent dimension.

    Returns:
        `(loss, real_score, fake_score)`, each rounded to 4 decimals: the
        summed real+fake loss and the mean prediction on real / fake images.
    """
    import torch


    # Setting Devices, Batch Size and Latent
    dis_device = next(discriminator.parameters()).device
    gen_device = next(generator.parameters()).device

    batch_size = real_image_batch.shape[0]
    # NOTE(review): this latent is 4-D, whereas `train_generator` in this file
    # samples a 2-D `(batch_size, latent_size)` latent -- confirm which shape
    # the generator expects.
    latent = torch.randn(size=(batch_size, latent_size, 1, 1))

    # Setting Real and Fake Labels
    real_labels = torch.ones(batch_size, 1, device=dis_device)
    fake_labels = torch.zeros(batch_size, 1, device=dis_device)

    # Calculating Loss and Score for Real Images
    real_preds = discriminator(real_image_batch.to(dis_device))

    real_loss = loss_fn(real_preds, real_labels)
    real_score = torch.mean(real_preds).item()

    # Generating the fake images without recording a graph: only the
    # discriminator is updated here, so back-propagating through the
    # generator would be wasted work and memory.
    with torch.no_grad():
        fake_images = generator(latent.to(gen_device))

    fake_preds = discriminator(fake_images.to(dis_device))

    fake_loss = loss_fn(fake_preds, fake_labels)
    fake_score = torch.mean(fake_preds).item()

    # Updating Discriminator Parameters
    loss = real_loss + fake_loss

    opt_dis.zero_grad()
    opt_gen.zero_grad()
    loss.backward()
    opt_dis.step()

    return round(loss.item(), 4), round(real_score, 4), round(fake_score, 4)

### Training: b2

In [None]:
def train_generator(discriminator, generator, batch_size, latent_size, loss_fn, opt_dis, opt_gen):
    """Run one generator update that tries to make the discriminator answer 1.

    Args:
        discriminator: Network scoring the generated images; its output is
            compared with labels of shape `(batch_size, 1)`.
        generator: Network producing images from a latent tensor of shape
            `(batch_size, latent_size)`.
        batch_size: Number of images to generate.
        latent_size: Size of the latent dimension.
        loss_fn: Callable `(predictions, labels)` returning a scalar loss.
        opt_dis: Optimizer of the discriminator (only its gradients are cleared).
        opt_gen: Optimizer of the generator (stepped here).

    Returns:
        `(loss, fake_image)`: the loss rounded to 4 decimals and the
        generated batch as returned by the generator.
    """
    import torch

    # Each network may live on its own device
    gen_device = next(generator.parameters()).device
    dis_device = next(discriminator.parameters()).device

    # NOTE(review): this latent is 2-D, whereas `train_discriminator` in this
    # file samples `(batch_size, latent_size, 1, 1)` -- confirm which shape
    # the generator expects.
    noise = torch.randn(size=(batch_size, latent_size))

    # All-ones targets: the generator is rewarded when fakes are scored as real
    targets = torch.ones(batch_size, 1, device=dis_device)

    # Forward pass: noise -> fake images -> discriminator verdict -> loss
    fake_image = generator(noise.to(gen_device))
    verdict = discriminator(fake_image.to(dis_device))
    gen_loss = loss_fn(verdict, targets)

    # Clearing both optimizers' gradients, then stepping only the generator
    for optimizer in (opt_dis, opt_gen):
        optimizer.zero_grad()
    gen_loss.backward()
    opt_gen.step()

    return round(gen_loss.item(), 4), fake_image

### Training: b3

In [None]:
def fit(discriminator, generator, epochs, train_dl, latent_size, loss_fn, opt_dis, opt_gen, save_gen_images_fn=None, generated_image_path=None):
    """Train a GAN by alternating discriminator and generator updates.

    For every batch of `train_dl`, `train_discriminator` and then
    `train_generator` are called. The values recorded for an epoch are those
    of its LAST batch.

    Args:
        discriminator: Discriminator network.
        generator: Generator network.
        epochs: Number of epochs.
        train_dl: Loader yielding `(real_image_batch, _)` pairs and exposing
            `batch_size`.
        latent_size: Size of the latent dimension.
        loss_fn: Loss callable passed to both training functions.
        opt_dis: Optimizer of the discriminator.
        opt_gen: Optimizer of the generator.
        save_gen_images_fn: Optional callback invoked after every epoch as
            `save_gen_images_fn(generator, generated_image_path, batch_size,
            latent_size, epoch)`. NOTE: `save_gen_images` in this file takes
            `denorm` as its third positional parameter, so it needs a wrapper
            (e.g. a lambda) to match this call.
        generated_image_path: Target passed to the callback; the callback is
            only invoked when both it and this path are truthy.

    Returns:
        `(dis_losses, gen_losses, real_scores, fake_scores, elapsed_seconds)`
        with one list entry per epoch.
    """
    from timeit import default_timer as timer
    from tqdm import tqdm


    start_time = timer()

    # Setting Evaluating Lists and Batch Size
    dis_losses, gen_losses = [], []
    real_scores, fake_scores = [], []

    batch_size = train_dl.batch_size

    print("Starting Process...\n")

    for epoch in range(1, epochs + 1):
        # A sample-saving callback may leave the generator in eval mode
        # (this file's `save_gen_images` calls `generator.eval()`), so both
        # networks are put back into training mode at the start of each epoch.
        discriminator.train()
        generator.train()

        print(f"Epoch: {epoch} | Training: ", end="")
        for real_image_batch, _ in tqdm(train_dl):

            # Training the 2 models
            dis_loss, real_score, fake_score = train_discriminator(discriminator,
                                                                   generator,
                                                                   real_image_batch,
                                                                   loss_fn,
                                                                   opt_dis,
                                                                   opt_gen,
                                                                   latent_size)
            # The generated batch is not needed here
            gen_loss, _generated = train_generator(discriminator,
                                                   generator,
                                                   batch_size,
                                                   latent_size,
                                                   loss_fn,
                                                   opt_dis,
                                                   opt_gen)

        # Updating Evaluating Lists (values of the epoch's last batch)
        dis_losses.append(dis_loss)
        gen_losses.append(gen_loss)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        # Saving the Generated Image Batch; the epoch number is the image index
        if save_gen_images_fn and generated_image_path:
            save_gen_images_fn(generator, generated_image_path, batch_size, latent_size, epoch)

        # Printing Results
        print(
            f"\tDiscr_Loss: {dis_loss} | "
            f"Gen_Loss: {gen_loss} | "
            f"Real_Score: {real_score} | "
            f"Fake_Score: {fake_score}")
        print('-' * 93, end="\n\n")

    print("Process Completed Succesfully...")

    return dis_losses, gen_losses, real_scores, fake_scores, (timer() - start_time)

### Utilities: 1

In [95]:
def save_model(model, saved_model_path: str, if_exists_stop=False):
    """Save the `state_dict` of `model` to `saved_model_path`.

    Args:
        model: Object providing `state_dict()`.
        saved_model_path: Target file (string or `Path`) ending in `.pt` or
            `.pth`; missing parent directories are created.
        if_exists_stop: When the target file already exists, return
            immediately instead of replacing it.

    Returns:
        None.

    Raises:
        AssertionError: If the file name does not end with `.pt` or `.pth`.
    """
    from pathlib import Path
    from torch import save

    # `Path` splits directory and file name correctly, including for files
    # placed directly under the root directory.
    model_file = Path(saved_model_path)
    target_path, model_name = model_file.parent, model_file.name

    # `model_name` should end with '.pt' or '.pth'. Raised explicitly (not
    # via `assert`) so the check survives `python -O`, while keeping the
    # exception type callers may already handle.
    if not model_name.endswith((".pth", ".pt")):
        raise AssertionError("Wrong extension: Expecting `.pt` or `.pth`...")

    # Creating the directory that the model is going to be saved in, if needed
    target_path.mkdir(parents=True, exist_ok=True)

    if model_file.is_file():
        print(f"[INFO] Model `{model_name}` already exists on `{target_path}`...")

        if if_exists_stop:
            return

        print(f"[INFO] Deleting `{model_file}`...")
        model_file.unlink()

    # Saving the Model to path
    print(f"[INFO] Saving Model `{model_name}` to `{target_path}`...")
    save(obj=model.state_dict(), f=model_file)

    print(f"[INFO] Model Successfully Saved to {model_file}")

### Utilities: 2

In [None]:
# Creating a Function to Save a Grid of Images
def save_gen_images(generator, path, denorm, batch_size, latent_size, index=0, images_name="generated_images", nrows=10, print_info=False, image_shape=(1, 28, 28)):
    """Generate a batch of images and save the first `nrows**2` as a grid.

    The file is written to `path/<images_name>_<index, 4 digits>.jpg`.

    Args:
        generator: Network producing images from a latent tensor of shape
            `(batch_size, latent_size)`.
        path: Target directory (string or `Path`); created if missing.
        denorm: Callable applied to the generated images before saving.
        batch_size: Number of images to generate.
        latent_size: Size of the latent dimension.
        index: Number appended (zero-padded to 4 digits) to the file name.
        images_name: File name prefix.
        nrows: Images per grid row; at most `nrows**2` images are saved.
        print_info: Print progress messages.
        image_shape: Per-image shape the generator output is reshaped to
            (default `(1, 28, 28)`); pass `None` to keep the output as is.
    """
    import torch
    from pathlib import Path
    from torchvision.utils import save_image

    # Accepting plain strings as well as `Path` objects
    target_dir = Path(path)
    target_dir.mkdir(parents=True, exist_ok=True)

    generated_image_grid_name = f"{images_name}_{index:0=4d}.jpg"

    # Remembering the current mode so an ongoing training run is not left
    # in eval mode after the images are saved.
    was_training = generator.training
    generator.eval()
    try:
        with torch.inference_mode():
            # Creating latent and generating the images
            latent = torch.randn(size=(batch_size, latent_size), device=next(generator.parameters()).device)
            fake_images = generator(latent)
            if image_shape is not None:
                fake_images = fake_images.reshape(batch_size, *image_shape)
    finally:
        generator.train(was_training)

    if print_info:
        print(f"\t[INFO] Saving {generated_image_grid_name} to {path}/")

    save_image(denorm(fake_images)[:nrows**2], target_dir / generated_image_grid_name, nrow=nrows)

    if print_info:
        print(f"\t[INFO] {generated_image_grid_name} saved succesfully to {path}/")

### Utilities: 3

In [None]:
def plot_decision_boundary(model, X, y):
    """
    Plots decision boundaries of model predicting on X in comparison to y.
    Source - https://madewithml.com/courses/foundations/neural-networks/ (with modifications)

    Args:
        model: Network called on a float tensor with two feature columns.
        X: Feature tensor; columns 0 and 1 span the plotted plane.
        y: Label tensor; more than two unique values selects the
            multi-class (softmax) branch, otherwise the binary (sigmoid) one.

    Side effects:
        `model` is moved to the CPU and switched to eval mode; neither is
        restored afterwards.
    """
    import torch
    import numpy as np
    import matplotlib.pyplot as plt

    # Put everything to CPU (works better with NumPy + Matplotlib)
    model.to("cpu")
    X, y = X.to("cpu"), y.to("cpu")

    # Setup prediction boundaries and grid (0.1 margin around the data, 101 x 101 points)
    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 101), np.linspace(y_min, y_max, 101))

    # Make features: one row per grid point
    X_to_pred_on = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()

    # Make predictions
    model.eval()
    with torch.inference_mode():
        y_logits = model(X_to_pred_on)

    # Test for multi-class or binary and adjust logits to prediction labels
    if len(torch.unique(y)) > 2:
        y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)  # multi-class
    else:
        y_pred = torch.round(torch.sigmoid(y_logits))  # binary

    # Reshape preds and plot
    y_pred = y_pred.reshape(xx.shape).detach().numpy()
    plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
    plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())