Submitted by:

* Michal Marom- 207479940
* Noy Cohen- 206713307
* David Sonego- 201416757


In [1]:
# Mount Google Drive at /content/gdrive; the project save path below lives under it.
from google.colab import drive
import numpy as np
drive.mount('/content/gdrive')

Mounted at /content/gdrive


## Hyperparameters

In [2]:
# Training hyperparameters.
BATCH_SIZE = 16
EPOCHS = 100
LEARNING_RATE = 1e-4
LR = LEARNING_RATE  # short alias used as the ModelHandler default learning rate

# Root folder on Google Drive for checkpoints, metric arrays and images.
# note_book_save_path = '/content/gdrive/MyDrive/Colab Notebooks/ImageColoringProject'
note_book_save_path = '/content/gdrive/MyDrive/ImageColoringProject'

## Imports

In [None]:
import pickle
from torch.utils.data import Subset
import random
import shutil
from torchvision.datasets import ImageFolder
import cv2
import os
import torch
import glob
import time
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
import torchvision as tv
from torchvision.datasets import Flowers102
from PIL import Image

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE(review): use_colab is not read anywhere in the visible notebook.
use_colab = None

## Data Loader

In [None]:
def data_loader(color_mode='gray', batch_size=None):
    """Build the train / validation / test DataLoaders for Flowers102.

    Every image is resized to 256x256, converted to a tensor and normalised
    per channel with mean 0.5 and std 0.5, i.e. mapped from [0, 1] to [-1, 1].

    Args:
        color_mode: Unused; kept for backward compatibility. Greyscale inputs
            are derived from the RGB batches during training.
        batch_size: Samples per batch. ``None`` (the default) uses the
            notebook-level ``BATCH_SIZE`` constant. Previously this argument
            was silently ignored and ``BATCH_SIZE`` was always used.

    Returns:
        Tuple ``(train_loader, eval_loader, test_loader)``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    # Resize to a fixed size, convert to tensors, then (x - 0.5) / 0.5 per channel.
    rgb_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # NOTE(review): the official 'test' split is used for training and 'train'
    # for testing -- presumably because Flowers102's test split is the larger
    # one; confirm this swap is intentional.
    rgb_train = Flowers102(root='./data', download=True, split='test', transform=rgb_transform)
    rgb_test = Flowers102(root='./data', split='train', transform=rgb_transform)
    rgb_val = Flowers102(root='./data', split="val", transform=rgb_transform)

    # Only the training loader is shuffled.
    train_loader_rgb = torch.utils.data.DataLoader(rgb_train, batch_size=batch_size, shuffle=True, pin_memory=True)
    eval_loader_rgb = torch.utils.data.DataLoader(rgb_val, batch_size=batch_size, shuffle=False, pin_memory=True)
    test_loader_rgb = torch.utils.data.DataLoader(rgb_test, batch_size=batch_size, shuffle=False, pin_memory=True)

    return train_loader_rgb, eval_loader_rgb, test_loader_rgb

## U-Net Generator and WGAN Critic

In [None]:
class conv_block(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU.

    A padding of 1 keeps the spatial size unchanged; only the channel count
    changes, from ``in_c`` to ``out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # These attribute names are the state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        out = inputs
        # conv -> batch-norm -> ReLU, applied twice.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        return out


class encoder_block(nn.Module):
    """U-Net encoder stage: a conv_block followed by 2x2 max pooling.

    ``forward`` returns ``(features, pooled)``. The pre-pool features serve
    as the skip connection; the pooled tensor feeds the next stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        skip = self.conv(inputs)
        return skip, self.pool(skip)


class decoder_block(nn.Module):
    """U-Net decoder stage: 2x upsampling, skip concatenation, then a conv_block.

    The transposed convolution doubles H and W and maps ``in_c`` to ``out_c``
    channels. The result is concatenated with ``skip`` along the channel axis,
    so the conv_block expects ``2 * out_c`` input channels. This means ``skip``
    must have ``out_c`` channels and the same H, W as the upsampled tensor.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(2 * out_c, out_c)

    def forward(self, inputs, skip):
        merged = torch.cat((self.up(inputs), skip), dim=1)
        return self.conv(merged)


# Define Generator (U-Net) architecture with VGG blocks
class UNetGenerator(nn.Module):
    """U-Net mapping a 1-channel image to a 3-channel image.

    The network has four encoder stages (64 -> 512 channels) and a
    1024-channel bottleneck. Four decoder stages consume the encoder skip
    connections in reverse order, and a final 1x1 convolution produces the
    output. No activation is applied to the output. Because of the four 2x
    poolings, input height and width should be divisible by 16 so each
    upsampled tensor lines up with its skip connection.
    """

    def __init__(self):
        super().__init__()
        # Encoder (attribute names are checkpoint state_dict keys).
        self.e1 = encoder_block(1, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)
        # Bottleneck
        self.b = conv_block(512, 1024)
        # Decoder
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)
        # 1x1 projection to three output channels
        self.outputs = nn.Conv2d(64, 3, kernel_size=1, padding=0)

    def forward(self, inputs):
        skips = []
        x = inputs
        for stage in (self.e1, self.e2, self.e3, self.e4):
            skip, x = stage(x)
            skips.append(skip)

        x = self.b(x)

        # Deepest skip first: d1 pairs with e4's features, d4 with e1's.
        for stage, skip in zip((self.d1, self.d2, self.d3, self.d4), reversed(skips)):
            x = stage(x, skip)
        return self.outputs(x)

class Critic(nn.Module):
    """Convolutional WGAN critic producing a 1-channel map of unbounded scores.

    Four 4x4 stride-2 convolutions halve H and W each time (64 -> 512
    channels, LeakyReLU 0.2, InstanceNorm on all but the first). A final 4x4
    stride-1 convolution emits one score per spatial position. There is no
    sigmoid.
    """

    def __init__(self):
        super().__init__()

        def down(in_ch, out_ch, normalize):
            layers = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        # Attribute names and Sequential indices form the checkpoint state_dict
        # keys, so this layout must stay as it is.
        self.conv1 = down(3, 64, normalize=False)
        self.conv2 = down(64, 128, normalize=True)
        self.conv3 = down(128, 256, normalize=True)
        self.conv4 = down(256, 512, normalize=True)
        self.conv5 = nn.Sequential(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))

    def forward(self, x):
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = layer(x)
        return x

## Plots

In [None]:
def plot_graph(loss, title, x_label='Batch', y_label='Loss'):
    """Plot a 1-D series against its index and save it under ``results/``.

    The figure is written to ``<note_book_save_path>/results/graph_<title>``.
    No extension is given, so matplotlib's default save format is used.

    Args:
        loss: Sequence of values to plot.
        title: Plot title, legend label and file-name suffix.
        x_label: Prefix of the x-axis label.
        y_label: Prefix of the y-axis label.
    """
    results_dir = f'{note_book_save_path}/results'
    # savefig does not create missing directories; make sure it exists.
    os.makedirs(results_dir, exist_ok=True)

    plt.cla()
    plt.plot(range(len(loss)), loss, label=title)
    plt.xlabel(f'{x_label} Steps - axis')
    plt.ylabel(f'{y_label} Value - axis')
    plt.title(title)

    plt.legend()
    plt.savefig(f'{results_dir}/graph_{title}')
    plt.close()

## Image Helpers and PSNR Metric

In [None]:
def prepare_to_save_image(image):
    """Convert a (C, H, W) tensor into an (H, W, C) numpy array scaled to [0, 1].

    The array is min-max normalised per image so it can be passed to
    ``imshow``. A constant image (max == min) is returned as all zeros
    instead of dividing by zero and producing NaNs.
    """
    image_np = image.permute(1, 2, 0).detach().cpu().numpy()
    lo = image_np.min()
    value_range = image_np.max() - lo
    if value_range == 0:
        return np.zeros_like(image_np)
    return (image_np - lo) / value_range


def make_subplot(rbg_image, grey_image, gen_image):
    """Draw the original, greyscale and generated images side by side.

    Creates a new 1x3 figure (15x5 inches) with all axes hidden. Returns the
    ``matplotlib.pyplot`` module itself so callers can ``savefig``/``close``.

    Args:
        rbg_image: RGB image for the left panel (presumably (H, W, 3)).
        grey_image: Greyscale image; squeezed and drawn with the 'gray' colormap.
        gen_image: Generated RGB image for the right panel.
    """
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (rbg_image, {}),
        (grey_image.squeeze(), {'cmap': 'gray'}),
        (gen_image, {}),
    )
    for ax, (picture, options) in zip(axes, panels):
        ax.imshow(picture, **options)
        ax.axis('off')

    plt.tight_layout()
    return plt


# Computes the PSNR between the input and target images.
def psnr(input_image, target_image, max_val=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    PSNR = 10 * log10(max_val**2 / MSE), with the MSE averaged over all
    elements. ``max_val`` defaults to 1.0, the original behaviour. For images
    normalised to [-1, 1] the peak-to-peak range is 2.0; values computed with
    the default are therefore 10*log10(4) ~= 6.02 dB below the conventional
    figure. Identical inputs give ``inf``.

    Args:
        input_image: Tensor to evaluate (CPU or CUDA).
        target_image: Reference tensor, broadcast-compatible with ``input_image``.
        max_val: Peak signal value used in the numerator.

    Returns:
        A 0-d numpy array holding the PSNR.
    """
    mse = torch.mean((input_image - target_image) ** 2)
    psnr_val = 10 * torch.log10(max_val ** 2 / mse)
    # .cpu() lets CUDA tensors be converted to numpy as well.
    return psnr_val.detach().cpu().numpy()


def compute_psnr(loader_gray, loader_rgb, model_handler):
    """Average PSNR of the generator's outputs over paired grey / RGB loaders.

    Batches are taken pairwise from the two loaders; iteration stops at the
    shorter one. Each loader yields ``(images, labels)`` and labels are
    ignored. This assumes both loaders visit the same images in the same
    order (no shuffling) -- confirm at the call site.

    The generator runs in eval mode on the device holding its parameters.
    Its previous train/eval mode is restored afterwards.

    Raises:
        ValueError: if the loaders yield no images.
    """
    generator = model_handler.generator
    # Run wherever the generator's weights live instead of assuming CUDA.
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = generator.training
    # Eval mode: BatchNorm uses (and does not update) its running statistics.
    generator.eval()
    psnr_values = []
    try:
        with torch.no_grad():
            for (gray_images, _), (rgb_images, _) in zip(loader_gray, loader_rgb):
                gen_images = generator(gray_images.to(device)).cpu()
                psnr_values.extend(
                    psnr(gen_img, rgb_img)
                    for gen_img, rgb_img in zip(gen_images, rgb_images.cpu())
                )
    finally:
        generator.train(was_training)

    if not psnr_values:
        raise ValueError("compute_psnr: the loaders yielded no images")
    return sum(psnr_values) / len(psnr_values)

In [None]:
def convert_to_greyscale_batch(image_batch, grey_dir=None):
    """Collapse a [B, C, H, W] batch to [B, 1, H, W] by averaging its channels.

    The average is an unweighted mean over channels, not a luminance-weighted
    conversion. Non-tensor inputs are first converted with ``torch.tensor``.
    ``grey_dir`` is accepted but unused.

    Raises:
        ValueError: if the input is not 4-dimensional.
    """
    batch = image_batch if torch.is_tensor(image_batch) else torch.tensor(image_batch)
    if batch.dim() != 4:
        raise ValueError("Input image_batch should be in [B, C, H, W] format.")
    return batch.mean(dim=1, keepdim=True)

def average_every_n_epochs(data, n=5, drop_last=True):
    """Average consecutive, non-overlapping chunks of ``n`` values.

    Args:
        data: Sequence of per-epoch values (anything sliceable that ``np.mean`` accepts).
        n: Chunk length; must be at least 1.
        drop_last: When True (default, the original behaviour), a trailing
            chunk shorter than ``n`` is discarded. When False it is averaged too.

    Returns:
        List with one mean per chunk.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    stop = len(data) - len(data) % n if drop_last else len(data)
    return [np.mean(data[start:start + n]) for start in range(0, stop, n)]

## Model Handler

In [None]:
class ModelHandler:
    """Owns the U-Net generator and WGAN-GP critic and drives their training.

    Responsibilities:
      * MSE-only pretraining of the generator (``pretrain_generator``);
      * adversarial training with a 0.3 * Wasserstein + 0.7 * MSE generator
        loss and a gradient-penalised critic (``train``);
      * validation / test evaluation and result visualisation;
      * checkpointing model weights and per-epoch metric arrays under
        ``<note_book_save_path>/saved_models`` so interrupted runs resume.

    Greyscale inputs are produced on the fly from each RGB batch with
    ``convert_to_greyscale_batch``. "Accuracy" throughout means mean PSNR (dB).
    """

    def __init__(self, train_loader_rgb, eval_loader_rgb, test_loader_rgb,
                 batch_size=BATCH_SIZE, num_epochs=EPOCHS, lr_G=LR, lr_C=LR, num_epochs_pre=EPOCHS):
        """Create both networks on the available device, plus their Adam optimizers.

        Args:
            train_loader_rgb, eval_loader_rgb, test_loader_rgb: Loaders yielding
                ``(rgb_images, labels)`` batches.
            batch_size: Stored for reference only; batching is done by the loaders.
            num_epochs: Total adversarial-training epochs (``train``).
            lr_G, lr_C: Adam learning rates for the generator and the critic.
            num_epochs_pre: Total generator-pretraining epochs.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_epochs = num_epochs
        self.num_epochs_pre = num_epochs_pre
        self.lr_G = lr_G
        self.lr_C = lr_C
        self.batch_size = batch_size
        self.train_loader_rgb = train_loader_rgb
        self.eval_loader_rgb = eval_loader_rgb
        self.test_loader_rgb = test_loader_rgb
        self.MSEcriterion = nn.MSELoss()
        # WGAN
        self.generator = UNetGenerator().to(self.device)
        self.Critic = Critic().to(self.device)
        # betas=(0, 0.9): the Adam settings used in the WGAN-GP paper.
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=self.lr_G, betas=(0, 0.9))
        self.optimizer_C = optim.Adam(self.Critic.parameters(), lr=self.lr_C, betas=(0, 0.9))
        # self.note_book_save_path = '/content/gdrive/MyDrive/Colab Notebooks/ImageColoringProject'
        self.note_book_save_path = note_book_save_path

    # ------------------------------------------------------------------ #
    # Pretraining
    # ------------------------------------------------------------------ #
    def pretrain_generator(self):
        """Pretrain the generator with a plain MSE loss, resuming if possible.

        Loads ``pretrained_model.pth`` when present. Training resumes from the
        number of epochs already recorded in the saved pretraining arrays.
        Weights and arrays are checkpointed after every epoch.

        Returns:
            ``(g_loss_per_epoch, accuracy, avg_psnr_per_epoch)``. Both
            ``accuracy`` and ``avg_psnr_per_epoch`` hold the mean training PSNR.
        """
        # Imported here so this method does not depend on a later notebook cell.
        import gc

        pretrained_model_path = f'{self.note_book_save_path}/saved_models/pretrained_model.pth'
        if os.path.exists(pretrained_model_path):
            # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
            self.generator.load_state_dict(torch.load(pretrained_model_path, map_location=self.device))
            print("Finished loading the previous pretrained model")
        else:
            print("Starting to pretrain the generator!")

        accuracy, g_loss_per_epoch, avg_psnr_per_epoch = self.load_pretrained_arrays()

        print("Starts to pretrain!")
        self.generator.train()
        for epoch in range(len(accuracy), self.num_epochs_pre):
            psnr_values = []
            g_loss_per_batch = []  # losses sampled every 20 batches, as floats
            for batch_idx, (rgb_images, _) in enumerate(self.train_loader_rgb):
                rgb_images = rgb_images.to(self.device)
                gray_images = convert_to_greyscale_batch(rgb_images).to(self.device)

                gen_images = self.generator(gray_images)
                loss = self.MSEcriterion(gen_images, rgb_images)

                self.optimizer_G.zero_grad()
                loss.backward()
                self.optimizer_G.step()

                if batch_idx % 20 == 0:
                    self.save_pretrained_images(gen_images, gray_images, rgb_images, epoch, batch_idx)
                    # Store a Python float rather than keeping the tensor alive.
                    g_loss_per_batch.append(loss.item())

                psnr_values.extend(
                    psnr(gen_img, rgb_img)
                    for gen_img, rgb_img in zip(gen_images.detach().cpu(), rgb_images.cpu()))
                print("[Epoch %d/%d] [Batch %d/%d] [G loss: %f]" % (
                    epoch, self.num_epochs_pre, batch_idx, len(self.train_loader_rgb), loss.item()))

            # Average PSNR over every image seen this epoch.
            avg_psnr = sum(psnr_values) / len(psnr_values)
            epoch_g_loss = np.average(g_loss_per_batch)
            avg_psnr_per_epoch.append(avg_psnr)
            accuracy.append(avg_psnr)
            g_loss_per_epoch.append(epoch_g_loss)

            print("[Epoch: %d/%d] [g_loss_train: %f] [PSNR: %.2f dB]" % (
                epoch, self.num_epochs_pre, epoch_g_loss, avg_psnr))
            self.save_pretrained_model(pretrained_model_path)
            self.save_pretrained_arrays(accuracy, g_loss_per_epoch, avg_psnr_per_epoch)

            # Release cached GPU blocks and unreachable objects between epochs.
            torch.cuda.empty_cache()
            gc.collect()

        return g_loss_per_epoch, accuracy, avg_psnr_per_epoch

    def load_pretrained_arrays(self):
        """Load saved pretraining metrics, creating ``saved_models`` if needed.

        Returns:
            ``(accuracy, g_loss_per_epoch, avg_psnr_per_epoch)`` as lists;
            missing files yield empty lists.
        """
        save_dir = f'{self.note_book_save_path}/saved_models'
        os.makedirs(save_dir, exist_ok=True)

        def load_array(filename):
            return list(np.load(filename)) if os.path.exists(filename) else []

        return (
            load_array(f'{save_dir}/pretrained_accuracy.npy'),
            load_array(f'{save_dir}/pretrained_g_loss_per_epoch.npy'),
            load_array(f'{save_dir}/pretrained_avg_psnr_per_epoch.npy')
        )

    def save_pretrained_images(self, gen_images, gray_images, rgb_images, epoch, batch_idx):
        """Save an RGB / grey / generated comparison of the first image in the batch."""
        os.makedirs(f'{self.note_book_save_path}/pre_trained_images', exist_ok=True)
        first_image_gen = prepare_to_save_image(gen_images[0])
        first_image_grey = prepare_to_save_image(gray_images[0])
        first_image_rbg = prepare_to_save_image(rgb_images[0])
        pyplot_handle = make_subplot(first_image_rbg, first_image_grey, first_image_gen)
        pyplot_handle.savefig(f'{self.note_book_save_path}/pre_trained_images/pre_trained_image_{epoch}_{batch_idx}.jpg')
        pyplot_handle.close()

    def save_pretrained_model(self, pretrained_model_path):
        """Write the generator's state_dict to ``pretrained_model_path``."""
        torch.save(self.generator.state_dict(), pretrained_model_path)

    def save_pretrained_arrays(self, accuracy, g_loss_per_epoch, avg_psnr_per_epoch):
        """Persist the pretraining metric arrays as ``.npy`` files."""
        save_dir = f'{self.note_book_save_path}/saved_models'
        os.makedirs(save_dir, exist_ok=True)
        np.save(f'{save_dir}/pretrained_accuracy.npy', np.array(accuracy))
        np.save(f'{save_dir}/pretrained_g_loss_per_epoch.npy', np.array(g_loss_per_epoch))
        np.save(f'{save_dir}/pretrained_avg_psnr_per_epoch.npy', np.array(avg_psnr_per_epoch))

    # ------------------------------------------------------------------ #
    # Evaluation
    # ------------------------------------------------------------------ #
    def _enter_eval_mode(self):
        """Put both networks in eval mode and return their previous training flags."""
        previous = (self.generator.training, self.Critic.training)
        self.generator.eval()
        self.Critic.eval()
        return previous

    def _restore_modes(self, previous):
        """Restore the training flags captured by ``_enter_eval_mode``."""
        generator_training, critic_training = previous
        self.generator.train(generator_training)
        self.Critic.train(critic_training)

    def test_model(self, loader_rgb):
        """Return the per-image PSNR of the generator on ``loader_rgb``.

        Runs in eval mode without gradient tracking. The networks' previous
        train/eval modes are restored afterwards.
        """
        previous_modes = self._enter_eval_mode()
        psnr_values_per_batch = []
        try:
            with torch.no_grad():
                for rgb_images, _ in loader_rgb:
                    rgb_images = rgb_images.to(self.device)
                    gray_images = convert_to_greyscale_batch(rgb_images).to(self.device)
                    gen_images = self.generator(gray_images)
                    psnr_values_per_batch.extend(
                        psnr(gen_img, rgb_img)
                        for gen_img, rgb_img in zip(gen_images.cpu(), rgb_images.cpu()))
        finally:
            self._restore_modes(previous_modes)
        return psnr_values_per_batch

    def val_model(self, loader_rgb):
        """Average critic, generator and MSE losses over ``loader_rgb``.

        The generator loss mirrors ``train`` (0.3 * Wasserstein + 0.7 * MSE).
        The critic loss omits the gradient penalty, which needs input
        gradients. Runs in eval mode under ``torch.no_grad()``; previous
        modes are restored afterwards.

        Returns:
            ``(c_loss_avr, g_loss_avr, mse_loss_avr)``.
        """
        previous_modes = self._enter_eval_mode()
        c_loss_per_batch = []
        g_loss_per_batch = []
        mse_loss_per_batch = []
        try:
            with torch.no_grad():
                for rgb_images, _ in loader_rgb:
                    rgb_images = rgb_images.to(self.device)
                    gray_images = convert_to_greyscale_batch(rgb_images).to(self.device)

                    # Generator loss
                    gen_images = self.generator(gray_images)
                    wgan_loss = -torch.mean(self.Critic(gen_images))
                    mse_loss = self.MSEcriterion(rgb_images, gen_images)
                    g_loss = wgan_loss * 0.3 + mse_loss * 0.7

                    # Critic loss (Wasserstein estimate, no gradient penalty)
                    c_loss = -torch.mean(self.Critic(rgb_images)) + torch.mean(self.Critic(gen_images))

                    c_loss_per_batch.append(c_loss.item())
                    g_loss_per_batch.append(g_loss.item())
                    mse_loss_per_batch.append(mse_loss.item())
        finally:
            self._restore_modes(previous_modes)

        c_loss_avr = sum(c_loss_per_batch) / len(c_loss_per_batch)
        g_loss_avr = sum(g_loss_per_batch) / len(g_loss_per_batch)
        mse_loss_avr = sum(mse_loss_per_batch) / len(mse_loss_per_batch)
        return c_loss_avr, g_loss_avr, mse_loss_avr

    def results_visualization(self):
        """Save an RGB / grey / generated comparison for every test image.

        Files go to ``<note_book_save_path>/results/image_<n>.jpg`` with a
        running counter ``n``. The generator runs in eval mode without
        gradient tracking; previous modes are restored afterwards.
        """
        os.makedirs(f'{self.note_book_save_path}/results', exist_ok=True)
        counter = 0
        previous_modes = self._enter_eval_mode()
        try:
            with torch.no_grad():
                for rgb_images, _ in self.test_loader_rgb:
                    rgb_images = rgb_images.to(self.device)
                    gray_images = convert_to_greyscale_batch(rgb_images).to(self.device)

                    # Generate RGB images from grayscale
                    gen_images = self.generator(gray_images)

                    for gray_image, rgb_image, gen_image in zip(gray_images, rgb_images, gen_images):
                        pyplot_handle = make_subplot(prepare_to_save_image(rgb_image),
                                                     prepare_to_save_image(gray_image),
                                                     prepare_to_save_image(gen_image))
                        pyplot_handle.savefig(f'{self.note_book_save_path}/results/image_%d.jpg' % counter)
                        pyplot_handle.close()
                        counter += 1
        finally:
            self._restore_modes(previous_modes)

    # ------------------------------------------------------------------ #
    # Adversarial training
    # ------------------------------------------------------------------ #
    def gradient_penalty(self, real_images, fake_images):
        """WGAN-GP gradient penalty: mean over samples of (||dC(x_hat)/dx_hat||_2 - 1)^2.

        ``x_hat`` is a random per-sample interpolation between real and fake
        images. ``create_graph=True`` keeps the penalty differentiable so it
        can be back-propagated into the critic's weights.
        """
        device = real_images.device
        # One mixing coefficient per sample; broadcasting spreads it over C, H, W.
        alpha = torch.rand(real_images.size(0), 1, 1, 1, device=device)
        interpolated = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)
        pred = self.Critic(interpolated)
        gradients = torch.autograd.grad(outputs=pred, inputs=interpolated,
                                        grad_outputs=torch.ones_like(pred),
                                        create_graph=True, retain_graph=True)[0]
        gradients = gradients.reshape(gradients.size(0), -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def save_generated_images(self, gen_images, gray_images, rgb_images, epoch, batch_idx):
        """Save an RGB / grey / generated comparison of the first image in the batch."""
        os.makedirs(f'{self.note_book_save_path}/generated_images', exist_ok=True)
        first_image_gen = prepare_to_save_image(gen_images[0])
        first_image_grey = prepare_to_save_image(gray_images[0])
        first_image_rbg = prepare_to_save_image(rgb_images[0])
        pyplot_handle = make_subplot(first_image_rbg, first_image_grey, first_image_gen)
        pyplot_handle.savefig(f'{self.note_book_save_path}/generated_images/image_epoch_{epoch}batch{batch_idx}.jpg')
        pyplot_handle.close()

    def load_models(self):
        """Restore weights: the full GAN checkpoint if present, else the pretrained generator."""
        save_dir = f'{self.note_book_save_path}/saved_models'
        generator_path = f'{save_dir}/generator_model.pth'
        critic_path = f'{save_dir}/Critic_model.pth'
        pretrained_path = f'{save_dir}/pretrained_model.pth'
        # map_location lets GPU-saved checkpoints load on a CPU-only runtime.
        if os.path.exists(generator_path) and os.path.exists(critic_path):
            self.generator.load_state_dict(torch.load(generator_path, map_location=self.device))
            self.Critic.load_state_dict(torch.load(critic_path, map_location=self.device))
            print("Finished loading the previous trained models!")
        elif os.path.exists(pretrained_path):
            self.generator.load_state_dict(torch.load(pretrained_path, map_location=self.device))
            print("Finished loading the pretrained generator!")
        else:
            print("Starting to train without pretrained model!")

    def save_arrays(self, c_loss_per_epoch, g_loss_per_epoch, accuracy, val_losses_g, val_losses_c, mse_losses_per_epoch, wgan_losses_per_epoch, val_losses_mse):
        """Checkpoint the per-epoch metric arrays and both models' weights.

        Optimizer state is not saved, so resumed training restarts Adam's
        moment estimates.
        """
        save_dir = f'{self.note_book_save_path}/saved_models'
        os.makedirs(save_dir, exist_ok=True)
        arrays = {
            'c_loss_per_epoch': c_loss_per_epoch,
            'g_loss_per_epoch': g_loss_per_epoch,
            'accuracy': accuracy,
            'val_losses_g': val_losses_g,
            'val_losses_c': val_losses_c,
            'val_losses_mse': val_losses_mse,
            'wgan_losses_per_epoch': wgan_losses_per_epoch,
            'mse_losses_per_epoch': mse_losses_per_epoch,
        }
        for name, values in arrays.items():
            np.save(f'{save_dir}/{name}.npy', np.array(values))
        torch.save(self.generator.state_dict(), f'{save_dir}/generator_model.pth')
        torch.save(self.Critic.state_dict(), f'{save_dir}/Critic_model.pth')

    def load_arrays(self):
        """Load saved per-epoch training metrics; missing files yield empty lists.

        Returns:
            ``(c_loss_per_epoch, g_loss_per_epoch, accuracy, val_losses_g,
            val_losses_c, val_losses_mse, wgan_losses_per_epoch,
            mse_losses_per_epoch)``. Note this order differs from the
            parameter order of ``save_arrays``.
        """
        save_dir = f'{self.note_book_save_path}/saved_models'

        def load_array(name):
            filename = f'{save_dir}/{name}.npy'
            return list(np.load(filename)) if os.path.exists(filename) else []

        names = ('c_loss_per_epoch', 'g_loss_per_epoch', 'accuracy', 'val_losses_g',
                 'val_losses_c', 'val_losses_mse', 'wgan_losses_per_epoch', 'mse_losses_per_epoch')
        return tuple(load_array(name) for name in names)

    def train(self):
        """Adversarially train generator and critic (WGAN-GP), resuming from checkpoints.

        On every 4th batch (excluding batch 0) the critic takes one step on
        ``-E[C(real)] + E[C(fake)] + 10 * gradient_penalty``. The generator
        steps on every batch with ``0.3 * -E[C(G(x))] + 0.7 * MSE``. Training
        metrics (losses, PSNR) are sampled on the critic batches only.

        After each epoch, validation losses are computed and models plus
        metric arrays are saved. Finally the per-image test PSNR is computed
        and saved to ``test_accuracy.npy``. If no epoch was trained in this
        call and that file exists, it is loaded instead.

        Returns:
            ``(c_loss_per_epoch, g_loss_per_epoch, accuracy, val_losses_g,
            val_losses_c, mse_losses_per_epoch, wgan_losses_per_epoch,
            test_accuracy, val_losses_mse)``.
        """
        save_dir = f'{self.note_book_save_path}/saved_models'

        # Load previously trained models and metric history.
        self.load_models()
        (c_loss_per_epoch, g_loss_per_epoch, accuracy, val_losses_g, val_losses_c,
         val_losses_mse, wgan_losses_per_epoch, mse_losses_per_epoch) = self.load_arrays()

        # Resume after the epochs already recorded.
        start_epoch = len(c_loss_per_epoch)

        self.generator.train()
        self.Critic.train()

        for epoch in range(start_epoch, self.num_epochs):
            psnr_values = []
            g_loss_per_batch = []
            c_loss_per_batch = []
            mse_losses_per_batch = []
            wgan_losses_per_batch = []
            for batch_idx, (rgb_images, _) in enumerate(self.train_loader_rgb):
                rgb_images = rgb_images.to(self.device)
                gray_images = convert_to_greyscale_batch(rgb_images).to(self.device)

                # The critic is updated on every 4th batch (never on batch 0).
                critic_step = batch_idx % 4 == 0 and batch_idx != 0
                if critic_step:
                    # The fake batch only feeds the critic, so no generator graph is needed.
                    with torch.no_grad():
                        fake_images = self.generator(gray_images)
                    self.optimizer_C.zero_grad()
                    loss_c = -torch.mean(self.Critic(rgb_images)) + torch.mean(self.Critic(fake_images))
                    loss_c = loss_c + 10 * self.gradient_penalty(rgb_images, fake_images)
                    loss_c.backward()
                    self.optimizer_C.step()

                # Generator step; freeze the critic so its parameters get no gradients.
                for param in self.Critic.parameters():
                    param.requires_grad = False

                self.optimizer_G.zero_grad()
                gen_images = self.generator(gray_images)
                wgan_loss = -torch.mean(self.Critic(gen_images))
                mse_loss = self.MSEcriterion(rgb_images, gen_images)
                loss_g = wgan_loss * 0.3 + mse_loss * 0.7
                loss_g.backward()
                self.optimizer_G.step()

                for param in self.Critic.parameters():
                    param.requires_grad = True

                if critic_step:
                    # Store floats rather than tensors so no graph/GPU memory is retained.
                    c_loss_per_batch.append(loss_c.item())
                    g_loss_per_batch.append(loss_g.item())
                    mse_losses_per_batch.append(mse_loss.item())
                    wgan_losses_per_batch.append(wgan_loss.item())

                    psnr_values.extend(
                        psnr(gen_img, rgb_img)
                        for gen_img, rgb_img in zip(gen_images.detach().cpu(), rgb_images.cpu()))

                    print("[Epoch %d/%d] [Batch %d/%d] [Critic loss: %f] [G loss: %f] [PSNR accuracy: %f] "
                          % (epoch, self.num_epochs, batch_idx, len(self.train_loader_rgb), loss_c.item(),
                             loss_g.item(), sum(psnr_values) / len(psnr_values)))

                    self.save_generated_images(gen_images, gray_images, rgb_images, epoch, batch_idx)

            # Validation losses
            c_loss_val, g_loss_val, mse_loss_val = self.val_model(self.eval_loader_rgb)
            val_losses_c.append(c_loss_val)
            val_losses_g.append(g_loss_val)
            val_losses_mse.append(mse_loss_val)

            # Mean training PSNR over the sampled batches (NaN if none were sampled).
            accuracy.append(sum(psnr_values) / len(psnr_values) if psnr_values else float('nan'))

            g_loss_per_epoch.append(np.average(g_loss_per_batch))
            c_loss_per_epoch.append(np.average(c_loss_per_batch))
            mse_losses_per_epoch.append(np.average(mse_losses_per_batch))
            wgan_losses_per_epoch.append(np.average(wgan_losses_per_batch))

            self.save_arrays(c_loss_per_epoch, g_loss_per_epoch, accuracy, val_losses_g, val_losses_c,
                             mse_losses_per_epoch, wgan_losses_per_epoch, val_losses_mse)

            # Release cached GPU blocks once per epoch.
            torch.cuda.empty_cache()

        # Test PSNR: recompute whenever the models changed in this call;
        # otherwise reuse the saved copy if there is one.
        test_accuracy_path = f'{save_dir}/test_accuracy.npy'
        if start_epoch < self.num_epochs or not os.path.exists(test_accuracy_path):
            test_accuracy = self.test_model(self.test_loader_rgb)
            os.makedirs(save_dir, exist_ok=True)
            np.save(test_accuracy_path, np.array(test_accuracy))
        else:
            test_accuracy = list(np.load(test_accuracy_path))

        return (c_loss_per_epoch, g_loss_per_epoch, accuracy, val_losses_g, val_losses_c,
                mse_losses_per_epoch, wgan_losses_per_epoch, test_accuracy, val_losses_mse)

    def return_arrays(self):
        """Load the saved training history in the same order ``train`` returns it.

        Returns:
            ``(c_loss_per_epoch, g_loss_per_epoch, accuracy, val_losses_g,
            val_losses_c, mse_losses_per_epoch, wgan_losses_per_epoch,
            test_accuracy, val_losses_mse)``; missing files yield empty lists.
        """
        (c_loss_per_epoch, g_loss_per_epoch, accuracy, val_losses_g, val_losses_c,
         val_losses_mse, wgan_losses_per_epoch, mse_losses_per_epoch) = self.load_arrays()
        test_accuracy_path = f'{self.note_book_save_path}/saved_models/test_accuracy.npy'
        test_accuracy = list(np.load(test_accuracy_path)) if os.path.exists(test_accuracy_path) else []
        return (c_loss_per_epoch, g_loss_per_epoch, accuracy, val_losses_g, val_losses_c,
                mse_losses_per_epoch, wgan_losses_per_epoch, test_accuracy, val_losses_mse)


## Main

In [None]:
import gc
# Release cached CUDA memory and run the garbage collector, then print the allocator report.
torch.cuda.empty_cache()
gc.collect()
print(torch.cuda.memory_summary(device=None, abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------

In [None]:
# Build the Flowers102 train / validation / test loaders (the dataset is downloaded on first run).
data = data_loader()
train_loader_rgb, eval_loader_rgb, test_loader_rgb = data
print("Finished data loading!")

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to data/flowers-102/102flowers.tgz


100%|██████████| 344862509/344862509 [00:12<00:00, 28659263.12it/s]


Extracting data/flowers-102/102flowers.tgz to data/flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to data/flowers-102/imagelabels.mat


100%|██████████| 502/502 [00:00<00:00, 464594.13it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to data/flowers-102/setid.mat


100%|██████████| 14989/14989 [00:00<00:00, 12353787.12it/s]

Finished data loading!





In [None]:
# Build the model handler: generator, critic, their optimizers and the data loaders.
# Pretraining runs for 4 epochs; adversarial training runs for EPOCHS epochs.
model_handler = ModelHandler(train_loader_rgb,eval_loader_rgb, test_loader_rgb,
                              batch_size=BATCH_SIZE, num_epochs=EPOCHS, lr_G=LR, lr_C=LR, num_epochs_pre=4)
print("Finished ModelHandler!")

Finished ModelHandler!


In [None]:
# Pretrain the generator with MSE only (resuming from saved checkpoints) and time it.
start_time = time.time()
model_handler.pretrain_generator()
end_time = time.time()
# Calculate elapsed time in seconds
elapsed_time = end_time - start_time
print("Elapsed time in seconds:", elapsed_time)

Finished loading the previous pretrained model
Starts to pretrain!
Elapsed time in seconds: 3.228003740310669


In [None]:
# Adversarial (WGAN-GP) training, resuming from saved checkpoints, timed.
start_time = time.time()
c_loss_per_epoch, g_loss_per_epoch, accuracy, val_losses_g, val_losses_c, mse_losses_per_epoch, wgan_losses_per_epoch, test_accuracy, val_losses_mse = model_handler.train()
end_time = time.time()
# Calculate elapsed time in seconds
elapsed_time = end_time - start_time
print("Elapsed time in seconds:", elapsed_time)

Finished loading the previous trained models!
Elapsed time in seconds: 29.61925959587097


## Save Graphs

In [None]:
# model_handler.results_visualization()
# # plots
# plot_graph(test_accuracy, "Test Accuracy", x_label='Batch', y_label="Accuracy")
# plot_graph(g_loss_per_epoch, "Train Generator Loss", x_label='Epoch', y_label='Loss')
# plot_graph(c_loss_per_epoch, "Train Critic Loss", x_label='Epoch', y_label='Loss')
# plot_graph(accuracy, "Train Accuracy", x_label='Epoch', y_label="Accuracy")
# plot_graph(val_losses_g, "Validation Generator Loss", x_label='Epoch', y_label="Loss")
# plot_graph(val_losses_c, "Validation Critic Loss", x_label='Epoch', y_label="Loss")
# plot_graph(val_losses_mse, "Validation MSE Loss", x_label='Epoch', y_label="Loss")
# plot_graph(mse_losses_per_epoch, "Train MSE Loss", x_label='Epoch', y_label="Loss")
# plot_graph(wgan_losses_per_epoch, "Train Wasserstein Loss", x_label='Epoch', y_label="Loss")

In [None]:
# Average each metric over consecutive windows of epochs (5 by default; test accuracy uses 3 and validation MSE uses 10)

# test_accuracy_avg = average_every_n_epochs(test_accuracy, n=3)
# c_loss_per_epoch_avg = average_every_n_epochs(c_loss_per_epoch)
# g_loss_per_epoch_avg = average_every_n_epochs(g_loss_per_epoch)
# accuracy_avg = average_every_n_epochs(accuracy)
# val_losses_g_avg = average_every_n_epochs(val_losses_g)
# val_losses_c_avg = average_every_n_epochs(val_losses_c)
# val_losses_mse_avg = average_every_n_epochs(val_losses_mse, n=10)
# mse_losses_per_epoch_avg = average_every_n_epochs(mse_losses_per_epoch)
# wgan_losses_per_epoch_avg = average_every_n_epochs(wgan_losses_per_epoch)

# Plotting the averaged data
# plot_graph(test_accuracy_avg, "Averaged Test Accuracy", x_label='Batch', y_label="Accuracy")
# plot_graph(g_loss_per_epoch_avg, "Averaged Train Generator Loss", x_label='Epoch', y_label='Loss')
# plot_graph(c_loss_per_epoch_avg, "Averaged Train Critic Loss", x_label='Epoch', y_label='Loss')
# plot_graph(accuracy_avg, "Averaged Train Accuracy", x_label='Epoch', y_label="Accuracy")
# plot_graph(val_losses_g_avg, "Averaged Validation Generator Loss", x_label='Epoch', y_label="Loss")
# plot_graph(val_losses_c_avg, "Averaged Validation Critic Loss", x_label='Epoch', y_label="Loss")
# plot_graph(val_losses_mse_avg, "Averaged Validation MSE Loss", x_label='Epoch', y_label="Loss")
# plot_graph(mse_losses_per_epoch_avg, "Averaged Train MSE Loss", x_label='Epoch', y_label="Loss")
# plot_graph(wgan_losses_per_epoch_avg, "Averaged Train Wasserstein Loss", x_label='Epoch', y_label="Loss")

## Evaluate on the Test Set and Save Results

In [None]:
# Per-image PSNR on the test loader. Note this rebinds `accuracy`, which
# previously held the per-epoch training PSNR returned by train().
accuracy = model_handler.test_model(test_loader_rgb)
print(f"Test accuracy: {np.mean(accuracy)}")
# Save original / greyscale / generated comparisons for every test image.
model_handler.results_visualization()

In [None]:
# Release (disconnect) the Colab runtime once the notebook has finished running.
from google.colab import runtime
runtime.unassign()