In [None]:
import base64
import io

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

# Explore the Dataset

In [None]:
#maa = pd.read_csv('../data/metalCover/metal_albums_artwork_images.csv', delimiter=',', quotechar='"')
#maa.describe()

In [None]:
#maa[maa['artist_name']=='Sonata Arctica']

In [None]:
# display an album image
#b64_cover = maa[maa['album_name']=='Winterheart\'s Guild'].iloc[0]['album_cover_image']
#cover_data = base64.b64decode(b64_cover)
#image = Image.open(io.BytesIO(cover_data))

#maxwidth = 128
#maxheight = 128

#ratio_width = maxwidth/image.size[0]
#ratio_height = maxheight/image.size[1]
#image = image.resize((int(image.size[0]*ratio_width),int(image.size[1]*ratio_height)), Image.ANTIALIAS)

#img_array = np.array(image)
#plt.imshow(img_array)

# Clean Dataset

In [None]:
def cleanDataset(metalDataset):
    '''
    Function for cleaning the NA values of artist_main_genre, artist_name and album_cover_image
    and attaching an integer-coded genre label.

    Rows missing any of the three columns are dropped. The remaining genres are
    factorized (codes assigned in order of first appearance) and stored as a pandas
    Categorical in a new 'artist_main_genre_label' column.

    Parameters:
            metalDataset: Dataset from https://www.kaggle.com/benjamnmachn/metal-album-artwork-dataset-intro/data as a df
    Returns:
            A new DataFrame with the extra 'artist_main_genre_label' column.
            The input frame is not modified.
    '''
    required_columns = ['artist_main_genre', 'artist_name', 'album_cover_image']

    # Work on an explicit copy: adding a column to a filtered slice of the caller's
    # frame is what triggers pandas' SettingWithCopyWarning.
    metalDatasetClean = metalDataset.dropna(subset=required_columns).copy()

    # factorize returns (codes, uniques); only the integer codes are kept.
    genre_codes, _ = pd.factorize(metalDatasetClean['artist_main_genre'])
    metalDatasetClean['artist_main_genre_label'] = pd.Categorical(genre_codes)

    return metalDatasetClean



In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

def show_tensor_images(image_tensor, num_images=4, size=(3, 128, 128), nrow=5):
    '''
    Function for visualizing images: Given a tensor of images and a number of images,
    plots the first `num_images` of them in a uniform grid.

    Parameters:
        image_tensor: batch of images in a layout accepted by torchvision's make_grid
        num_images: how many images from the start of the batch to show
        size: unused; kept so existing calls that pass it keep working
        nrow: value forwarded to make_grid's `nrow` argument (default 5, the
              value that used to be hard-coded)
    '''
    # Rescaling from [-1, 1] to [0, 1] is left disabled, as in the original cell:
    #image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)

    # Clip negative values out-of-place. detach()/cpu() share storage with a CPU
    # input, and make_grid may hand back a view of its input (e.g. for a single
    # image), so an in-place masked write could silently alter the caller's tensor.
    image_grid = image_grid.clamp(min=0)

    # make_grid yields channels-first data; matplotlib wants channels-last.
    plt.imshow(image_grid.permute(1, 2, 0).squeeze(), cmap="gray")

    plt.show()

# Helper for debugging

In [None]:
class PrintBlock(nn.Module):
    '''
    Pass-through layer for debugging: prints the shape of its input and
    returns the input unchanged.
    '''

    def forward(self, x):
        # Insert between layers of an nn.Sequential to trace tensor shapes.
        print(f"Printblockoutput {x.shape}")
        return x

# Generator

In [None]:
class GeneratorConditional(nn.Module):
    '''
    Conditional generator: turns a (noise vector, class label) pair into a
    3-channel 128x128 image with values in [-1, 1] (final Tanh).

    Parameters:
        n_classes: number of distinct label indices the embedding can look up
        embedding_dim: width of the label embedding
        latent_dim: length of the input noise vector
    '''
    def __init__(self,n_classes, embedding_dim, latent_dim):
        super().__init__()

        base_width = 64
        seed_channels = 512  # channels of the 4x4 map built from the noise vector

        # Label branch: index -> embedding -> 16 values (viewed as one 4x4 plane in forward).
        label_embedding = nn.Embedding(n_classes, embedding_dim)
        label_projection = nn.Linear(embedding_dim, 16)
        self.label_conditioned_generator = nn.Sequential(label_embedding, label_projection)

        # Noise branch: noise -> 512*4*4 values (viewed as 512 planes of 4x4 in forward).
        noise_projection = nn.Linear(latent_dim, 4 * 4 * seed_channels)
        self.latent = nn.Sequential(noise_projection, nn.LeakyReLU(0.2, inplace=True))

        # Upsampling trunk. Its input is the 512 noise planes plus the single label
        # plane (513 channels). Every transposed conv (kernel 4, stride 2, padding 1)
        # doubles the spatial size: 4 -> 8 -> 16 -> 32 -> 64 -> 128.
        # NOTE(review): eps=0.8 is unusually large for BatchNorm; kept as-is — confirm it is intended.
        channel_plan = [seed_channels + 1, base_width * 8, base_width * 4, base_width * 2, base_width]
        trunk = []
        for in_channels, out_channels in zip(channel_plan[:-1], channel_plan[1:]):
            trunk.append(nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False))
            trunk.append(nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.8))
            trunk.append(nn.ReLU(True))
        trunk.append(nn.ConvTranspose2d(base_width, 3, 4, 2, 1, bias=False))
        trunk.append(nn.Tanh())
        self.model = nn.Sequential(*trunk)

    def forward(self, inputs):
        '''
        Parameters:
            inputs: a (noise_vector, label) pair
        Returns:
            the generated image batch
        '''
        noise_vector, label = inputs

        label_plane = self.label_conditioned_generator(label).view(-1, 1, 4, 4)
        noise_planes = self.latent(noise_vector).view(-1, 512, 4, 4)

        # Stack along the channel axis: 512 noise planes + 1 label plane.
        stacked = torch.cat((noise_planes, label_plane), dim=1)
        return self.model(stacked)

In [None]:
def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Draw a batch of noise vectors from the standard normal distribution.
    Parameters:
      n_samples: the number of samples to generate, a scalar
      z_dim: the dimension of the noise vector, a scalar
      device: the device type
    Returns:
      a tensor of shape (n_samples, z_dim) on the requested device
    '''
    noise_shape = (n_samples, z_dim)
    return torch.randn(noise_shape, device=device)

# Critic

In [None]:
class CriticConditional(nn.Module):
    '''
    Conditional critic/discriminator: scores a (128x128 image, class label) pair
    with a single value in (0, 1) (final Sigmoid).

    Parameters:
        n_classes: number of distinct label indices the embedding can look up
        embedding_dim: width of the label embedding
    '''
    def __init__(self, n_classes, embedding_dim):
        super().__init__()

        # Label branch: index -> embedding -> 3*128*128 values, viewed in forward
        # as three 128x128 planes that are stacked onto the image channels.
        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3 * 128 * 128),
        )

        # Convolutional trunk on the 6-channel (image + label planes) input.
        # Spatial size shrinks 128 -> 64 -> 22 -> 8 -> 3, so the flattened
        # feature vector has 512 * 3 * 3 = 4608 entries.
        trunk = [
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        widths = (64, 64 * 2, 64 * 4, 64 * 8)
        for in_channels, out_channels in zip(widths[:-1], widths[1:]):
            trunk += [
                nn.Conv2d(in_channels, out_channels, 4, 3, 2, bias=False),
                nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        trunk += [
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*trunk)

    def forward(self, inputs):
        '''
        Parameters:
            inputs: an (img, label) pair; img may be a 3-D batch without a channel
                    axis, channels-last, or channels-first
        Returns:
            one score per image
        '''
        img, label = inputs
        label_planes = self.label_condition_disc(label).view(-1, 3, 128, 128)

        # A 3-D batch has no channel axis: append one and repeat it three times.
        if img.dim() == 3:
            img = img.unsqueeze(-1)
            img = img.expand(img.shape[0], img.shape[1], img.shape[2], 3)

        # Anything whose second axis is not 3 is treated as channels-last.
        if img.shape[1] != 3:
            img = img.permute(0, 3, 1, 2)

        stacked = torch.cat((img, label_planes), dim=1)
        return self.model(stacked)


# Hyperparameters

In [None]:
# Training configuration. Comments note which values the cells visible in this
# notebook actually read.
n_epochs = 1000        # not read by the training cell, which sets its own num_epochs
z_dim = 100            # only referenced by the commented-out generation cell
display_step = 5000    # show generated/real samples every this many training steps
batch_size = 4         # NOTE(review): the DataLoader cell hard-codes batch_size=1 instead
lr = 0.0005            # Adam learning rate for both networks
beta_1 = 0.5           # Adam betas
beta_2 = 0.999
c_lambda = 10          # weight of the gradient penalty in the critic loss
crit_repeats = 10      # not read by the training cell
label_dim = 2          # not read by the cells below
latent_dim = 32        # length of the noise vector sampled in the training loop

# Fall back to the CPU when no CUDA device is present so the notebook still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): this transform is defined but not passed to CustomImageDataset in
# the DataLoader cell — confirm whether it should be.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Dataset

In [None]:
import os
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

class CustomImageDataset(Dataset):
    '''
    Dataset of album covers stored as base64 strings in a CSV file.

    Each item is a (image, label) pair: the cover decoded, converted to RGB and
    resized to 128x128, and the integer genre label created by cleanDataset.

    Parameters:
            metalDatasetPath: path to the CSV file
            imageAmount: number of CSV rows to read (before cleaning/filtering)
            transform: optional callable applied to the image array
            target_transform: optional callable applied to the label
    '''
    def __init__(self, metalDatasetPath, imageAmount = 5000, transform=None, target_transform=None):
        self.metalDataset = pd.read_csv(metalDatasetPath, delimiter=',', quotechar='"', nrows = imageAmount)
        self.metalDataset = self.cleanDataset(self.metalDataset)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.metalDataset)

    def __getitem__(self, idx):
        # Decode the base64 payload into a PIL image.
        b64_cover = self.metalDataset['album_cover_image'].iloc[idx]
        cover_data = base64.b64decode(b64_cover)

        # Normalise every cover to 3-channel RGB: greyscale, palette or RGBA files
        # would otherwise produce arrays with a different number of channels, which
        # breaks batching and the 3-channel assumption of the networks.
        image = Image.open(io.BytesIO(cover_data)).convert('RGB')

        # Resize to the fixed 128x128 input size (aspect ratio is not preserved).
        image = image.resize((128,128))

        img_array = np.array(image).astype(int)

        label = self.metalDataset['artist_main_genre_label'].iloc[idx]

        if self.transform:
            img_array = self.transform(img_array)
        if self.target_transform:
            label = self.target_transform(label)
        return img_array, label

    def cleanDataset(self, metalDataset):
        '''
        Function for cleaning the NA values of artist_main_genre, artist_name and album_cover_image,
        adding an integer genre label and keeping only the rows labelled 0 or 3.
        Parameters:
                metalDataset: Dataset from https://www.kaggle.com/benjamnmachn/metal-album-artwork-dataset-intro/data as a df
        Returns:
                A new DataFrame; the frame passed in is not modified.
        '''
        required_columns = ['artist_main_genre', 'artist_name', 'album_cover_image']

        # Use the frame that was passed in (not self.metalDataset) and work on an
        # explicit copy so adding a column does not raise SettingWithCopyWarning.
        metalDatasetClean = metalDataset.dropna(subset=required_columns).copy()

        # Integer codes are assigned in order of first appearance of each genre.
        genre_codes, _ = pd.factorize(metalDatasetClean['artist_main_genre'])
        metalDatasetClean['artist_main_genre_label'] = pd.Categorical(genre_codes)

        # Only the genres coded 0 and 3 are kept for training.
        keep_rows = metalDatasetClean['artist_main_genre_label'].isin([0,3])
        metalDatasetClean = metalDatasetClean.loc[keep_rows]

        return metalDatasetClean

# Dataloader

In [None]:
from torch.utils.data import DataLoader

# Build the dataset from the CSV of base64-encoded covers. No transform is passed,
# so samples come back exactly as CustomImageDataset.__getitem__ produces them.
train_dataset = CustomImageDataset('../data/metalCover/metal_albums_artwork_images.csv')

# NOTE(review): batch_size is hard-coded to 1 here rather than using the
# `batch_size` hyperparameter defined earlier — confirm that this is intended.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)

In [None]:
# Display image and label.
# Sanity check: pull one batch from the loader, print its shapes and show the
# first image together with its label.
train_features, train_labels = next(iter(train_loader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
# squeeze() drops any size-1 axes from the selected image.
img = train_features[0].squeeze()
label = train_labels[0]
# NOTE(review): cmap presumably only matters for single-channel images — confirm.
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

# Initialize Generator and Critic

In [None]:
# Build both networks and their Adam optimizers.
# n_classes is 4 because the dataset keeps the genre labels 0 and 3, and the
# embedding tables must be able to look up index 3.
# latent_dim comes from the hyperparameter cell so the generator's input width
# always matches the noise the training loop samples with that same variable.
generator = GeneratorConditional(n_classes = 4, embedding_dim = 16, latent_dim = latent_dim).to(device)
G_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta_1, beta_2))
discriminator = CriticConditional(n_classes = 4, embedding_dim = 16).to(device)
D_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta_1, beta_2))

In [None]:
from torch.autograd import Variable

In [None]:
def generator_loss(label, fake_output):
    '''
    Binary cross-entropy used for the generator update.

    Note the argument roles: despite the parameter names, the FIRST argument is
    used as the prediction (BCE `input`) and the SECOND as the target.
    Parameters:
        label: predicted probabilities, i.e. the critic's output
        fake_output: target values for those predictions
    Returns:
        the mean binary cross-entropy as a scalar tensor
    '''
    return torch.nn.functional.binary_cross_entropy(input=label, target=fake_output)

In [None]:
def discriminator_loss(label, output):
    '''
    Binary cross-entropy used for the critic/discriminator update.

    Note the argument roles: despite the parameter names, the FIRST argument is
    used as the prediction (BCE `input`) and the SECOND as the target.
    Parameters:
        label: predicted probabilities, i.e. the critic's output
        output: target values for those predictions
    Returns:
        the mean binary cross-entropy as a scalar tensor
    '''
    return torch.nn.functional.binary_cross_entropy(input=label, target=output)


In [None]:
def get_gradient(crit, real, fake, epsilon, label):
    '''
    Return the gradient of the critic's scores with respect to mixes of real and fake images.
    Parameters:
        crit: the critic model; called with a (images, label) tuple
        real: a batch of real images
        fake: a batch of fake images
        epsilon: a vector of the uniformly random proportions of real/fake per mixed image
        label: the labels passed to the critic together with the mixed images
    Returns:
        gradient: the gradient of the critic's scores, with respect to the mixed image
    '''
    # Per-image blend of real and fake: epsilon of the real, the rest from the fake.
    interpolated = real * epsilon + fake * (1 - epsilon)

    # Score the blended images under the given labels.
    interpolated_scores = crit((interpolated, label))

    # d(scores)/d(interpolated). create_graph keeps the result differentiable so a
    # penalty built from it can itself be backpropagated.
    (gradient,) = torch.autograd.grad(
        outputs=interpolated_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(interpolated_scores),
        create_graph=True,
        retain_graph=True,
    )
    return gradient

In [None]:
def gradient_penalty(gradient):
    '''
    Return the gradient penalty, given a gradient.
    The L2 norm of each image's gradient is computed and the mean squared
    distance of those norms from 1 is returned.
    Parameters:
        gradient: the gradient of the critic's scores, with respect to the mixed image
    Returns:
        penalty: the gradient penalty
    '''
    # One row per image, all remaining axes flattened together.
    n_images = len(gradient)
    per_image = gradient.reshape(n_images, -1)

    # L2 norm of every row.
    norms = per_image.norm(2, dim=1)

    # Mean squared deviation of the norms from 1.
    deviation = norms - 1
    return torch.mean(deviation ** 2)

# Train the model

In [None]:
# Adversarial training: for every batch, first update the critic (BCE on real and
# generated images plus a weighted gradient penalty), then update the generator.
num_epochs = 200

# Global step counter across epochs, so the sample preview fires every
# `display_step` optimizer steps regardless of how many batches an epoch has.
cur_step = 1

for epoch in range(1, num_epochs+1):
    # Per-epoch loss history, stored as plain floats so no autograd graph is kept alive.
    D_loss_list, G_loss_list = [], []

    # A single pass over the loader per epoch. (The loader used to be iterated
    # inside another full iteration of itself, making every epoch N*N steps.)
    for real_images, labels in tqdm(train_loader):
        cur_batch_size = len(real_images)

        # ---- Critic update ----
        D_optimizer.zero_grad()
        real_images = real_images.to(device)
        labels = labels.to(device).unsqueeze(1).long()

        real_target = torch.ones(cur_batch_size, 1, device=device)
        fake_target = torch.zeros(cur_batch_size, 1, device=device)

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        noise_vector = torch.randn(cur_batch_size, latent_dim, device=device)
        generated_image = generator((noise_vector, labels))

        # detach(): the critic step must not backpropagate into the generator.
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        # Per-image mixing proportions for the gradient penalty.
        epsilon = torch.rand(cur_batch_size, 1, 1, 1, device=device, requires_grad=True)

        # Channels-first view of the real batch for mixing with the generated
        # images. It is always defined, whatever layout the loader delivered.
        real_gradient_image = real_images
        if real_gradient_image.dim() == 3:
            # No channel axis: append one and repeat it three times.
            real_gradient_image = real_gradient_image[:, :, :, None].expand(-1, -1, -1, 3)
        if real_gradient_image.shape[1] != 3:
            real_gradient_image = real_gradient_image.permute(0, 3, 1, 2)

        gradient = get_gradient(discriminator, real_gradient_image, generated_image.detach(), epsilon, labels)
        gp = gradient_penalty(gradient)

        D_total_loss = ((D_real_loss + D_fake_loss) / 2) + c_lambda*gp
        D_loss_list.append(D_total_loss.item())

        D_total_loss.backward()
        D_optimizer.step()

        # ---- Generator update: make the critic label the fakes as real ----
        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss.item())

        G_loss.backward()
        G_optimizer.step()

        ### Visualization code ###
        if cur_step % display_step == 0:
            print('################FAKE#################')
            show_tensor_images(generated_image)
            print('################REAL#################')
            show_tensor_images(real_gradient_image)

        cur_step += 1


# Save and Load the model

In [None]:
#PATHS where to save/load

#PATH_GENERATOR = "../resources/results/model_128x128/model128x128_1000epochs_generator.pth"
#PATH_CRITIC = "../results/saved_models/model_128x128/model128x128_1000epochs_critic.pth"

In [None]:
#Save model

#torch.save(generator.state_dict(), PATH_GENERATOR)
#torch.save(discriminator.state_dict(), PATH_CRITIC)

In [None]:
#Load Model

#myLoadedGenerator = GeneratorConditional(n_classes = 4, embedding_dim = 16, latent_dim = 32).to(device)
#myLoadedGenerator.load_state_dict(torch.load(PATH_GENERATOR))

In [None]:
#Generate fake image and display it

#predictedFake = myLoadedGenerator((get_noise(cur_batch_size, latent_dim, device=device), labels))
#show_tensor_images(predictedFake)
