In [None]:
# Importing the libraries
from __future__ import print_function
import torch
from torch.utils.data import ConcatDataset
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import os
from PIL import Image


# Output folder for sample images. exist_ok=True so re-running this cell does
# not raise FileExistsError once the folder already exists.
os.makedirs('/kaggle/working/results', exist_ok=True)

In [None]:
# Pick the compute device: the first CUDA GPU when one is available, else CPU.
# `cuda_checked` ends up True only when the CUDA branch was taken.
cuda_checked = False
if not torch.cuda.is_available() or cuda_checked:
    device = torch.device("cpu")
    print("Using CPU for computation")
else:
    device = torch.device("cuda")
    print("CUDA device:", torch.cuda.get_device_name(0))
    cuda_checked = True



class CustomDataset(torch.utils.data.Dataset):
    """Unlabeled image dataset over every .jpeg/.jpg/.png file under ``root``.

    ``root`` is walked recursively. Each item is the image converted to RGB
    and, if ``transform`` is given, passed through it. No label is returned.

    Args:
        root: Directory to search for images.
        transform: Optional callable applied to each RGB ``PIL.Image``.
    """

    # File extensions accepted by ``load_samples`` (matched case-insensitively).
    IMAGE_EXTENSIONS = (".jpeg", ".jpg", ".png")

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.samples = self.load_samples()

    def _load(self, index):
        """Open ``self.samples[index]`` as RGB and apply the transform, if any."""
        # Index first, so an out-of-range index raises IndexError before PIL is touched.
        sample_path = self.samples[index]
        # Context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(sample_path) as img:
            sample = img.convert("RGB")
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __getitem__(self, index):
        """Return the (transformed) image at ``index``.

        An unreadable file is reported and skipped by falling back to the
        following samples (wrapping around). Returning ``None`` instead would
        make DataLoader's default collate function fail on the whole batch.

        Raises:
            IndexError: if ``index`` is out of range (or the dataset is empty).
            OSError: if no file in the dataset can be read.
        """
        n = len(self.samples)
        # The requested index first (so out-of-range still raises IndexError),
        # then every other sample in order, wrapping around.
        candidates = [index] + [(index + k) % n for k in range(1, n)]
        last_error = None
        for current in candidates:
            try:
                return self._load(current)
            except OSError as e:
                print(f"Error processing image at index {current}: {e}")
                last_error = e
        raise OSError(f"No readable image found under {self.root}") from last_error

    def __len__(self):
        return len(self.samples)

    def load_samples(self):
        """Return the sorted paths of all image files under ``self.root``.

        Sorting makes the index -> file mapping independent of the
        filesystem's directory-listing order.
        """
        return sorted(
            os.path.join(dirpath, filename)
            for dirpath, _, filenames in os.walk(self.root)
            for filename in filenames
            if filename.lower().endswith(self.IMAGE_EXTENSIONS)
        )





# Data pipeline: black-metal and power-metal album covers, resized to 64x64
# RGB and scaled to [-1, 1] to match the generator's Tanh output range.
batchSize = 64
imageSize = (64, 64)

# Resize -> tensor in [0, 1] -> per-channel (x - 0.5) / 0.5, i.e. [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(imageSize),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

dataset1 = CustomDataset(
    root='/kaggle/input/metal-album-art-by-subgenre/data/black_metal',
    transform=transform,
)
dataset2 = CustomDataset(
    root='/kaggle/input/metal-album-art-by-subgenre/data/power_metal',
    transform=transform,
)

# Train on both subgenres as one pool of images.
merged_dataset = ConcatDataset([dataset1, dataset2])
merged_length = len(merged_dataset)
print("Length of the merged dataset:", merged_length)

dataloader = torch.utils.data.DataLoader(
    merged_dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
)

def weights_init(m):
    """DCGAN-style initializer for a single module; intended for ``nn.Module.apply``.

    - Modules whose class name contains ``'Conv'`` (e.g. Conv2d,
      ConvTranspose2d): weights drawn from N(0, 0.02).
    - Modules whose class name contains ``'BatchNorm'``: weights drawn from
      N(1, 0.02), biases set to 0.
    - Every other module is left unchanged.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)

        
        
        
# Defining the generator
class G(nn.Module):
    """DCGAN generator: latent noise (N, nz, 1, 1) -> image (N, nc, 64, 64).

    Five ConvTranspose2d stages upsample 1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels.
    Every stage except the last is followed by BatchNorm + ReLU. The output
    passes through Tanh, so pixel values lie in [-1, 1].

    Args:
        nz: Number of latent (input) channels. The default of 100 matches the
            noise the training loop samples.
        ngf: Base feature width. The stages use 8*ngf, 4*ngf, 2*ngf and ngf
            channels; the default 64 gives 512/256/128/64.
        nc: Number of output image channels (default 3, RGB).
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super(G, self).__init__()
        self.main = nn.Sequential(
            # (N, nz, 1, 1) -> (N, 8*ngf, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (N, 4*ngf, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (N, 2*ngf, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, nc, 64, 64), squashed to [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Map noise of shape (N, nz, 1, 1) to images of shape (N, nc, 64, 64)."""
        return self.main(input)

# Creating the generator and applying the DCGAN weight initialization.
# Module.apply returns the module itself, so construction and init chain.
netG = G().apply(weights_init)


# Defining the discriminator
class D(nn.Module):
    """DCGAN discriminator: image (N, nc, 64, 64) -> probability of "real" (N,).

    Four stride-2 Conv2d stages downsample 64 -> 32 -> 16 -> 8 -> 4 pixels.
    Each is followed by LeakyReLU(0.2), with BatchNorm on all but the first.
    A final 4x4 convolution without padding collapses the map to one value
    per image, which a Sigmoid squashes to (0, 1).

    Inputs must be 64x64. For other sizes the last feature map is not 1x1, and
    ``view(-1)`` would return more than one value per image.

    Args:
        nc: Number of input image channels (default 3, RGB).
        ndf: Base feature width. The stages use ndf, 2*ndf, 4*ndf and 8*ndf
            channels; the default 64 gives 64/128/256/512.
    """

    def __init__(self, nc=3, ndf=64):
        super(D, self).__init__()
        self.main = nn.Sequential(
            # (N, nc, 64, 64) -> (N, ndf, 32, 32); no BatchNorm on the input stage
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 2*ndf, 16, 16)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 4*ndf, 8, 8)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 8*ndf, 4, 4)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 1, 1, 1), squashed to (0, 1)
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return a flat tensor of per-image probabilities, shape (N,)."""
        output = self.main(input)
        return output.view(-1)

# Creating the discriminator and applying the DCGAN weight initialization.
netD = D()
netD.apply(weights_init)

# Training setup for the DCGAN.

# Move both networks to the target device *before* building their optimizers,
# as the torch.optim documentation advises, so each optimizer is constructed
# from the parameters that will actually be trained.
netG = netG.to(device)
netD = netD.to(device)

# BCELoss built without a `weight` holds no tensors, so it needs no device move.
criterion = nn.BCELoss()
# Adam with lr=2e-4 and beta1=0.5 for both networks.
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))


# Training the DCGAN.
num_epochs = 700
nz = 100  # latent size; must match the generator's first ConvTranspose2d input
results_dir = "/kaggle/working/results"

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):

        # 1st Step: update the discriminator on one real and one fake batch.
        netD.zero_grad()

        # Real images from the dataset are labelled 1.
        real = data.to(device)
        batch_size = real.size(0)  # the last batch of an epoch may be smaller
        target = torch.ones(batch_size, device=device)
        output = netD(real)
        errD_real = criterion(output, target)

        # Generated images are labelled 0. detach() keeps this loss from
        # backpropagating into the generator.
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        target = torch.zeros(batch_size, device=device)
        output = netD(fake.detach())
        errD_fake = criterion(output, target)

        # Backpropagate the summed loss and step the discriminator.
        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        # 2nd Step: update the generator so that the (just updated)
        # discriminator labels its fakes as real (target 1).
        netG.zero_grad()
        target = torch.ones(batch_size, device=device)
        output = netD(fake)
        errG = criterion(output, target)
        errG.backward()
        optimizerG.step()

        # Every 100 steps: print the losses and save the real batch plus images
        # generated from the same noise by the updated generator.
        if i % 100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                  % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item()))
            vutils.save_image(real, '%s/real_samples.png' % results_dir, normalize=True)
            # Sampling only: no_grad avoids building an unused autograd graph.
            # netG stays in train mode, so BatchNorm still uses batch statistics.
            with torch.no_grad():
                fake = netG(noise)
            vutils.save_image(fake, '%s/fake_samples_epoch_%03d.png' % (results_dir, epoch), normalize=True)