In [25]:

from numpy import load
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randint
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from os import listdir

from numpy import vstack

from matplotlib import pyplot as plt
import numpy as np
from PIL import Image

from qartezator.data.dataset import QartezatorDataset
from qartezator.data.datamodule import QartezatorDataModule
from qartezator.data.transforms import get_common_augmentations
from qartezator.data.datautils import load_image, pad_img_to_modulo
import os
import torch.nn.init as init
import random


In [26]:


def define_discriminator(image_shape):
    """Build a PatchGAN-style convolutional discriminator.

    Four stride-2 4x4 convolutions (64, 128, 256, 512 channels) are followed by
    a stride-1 512-channel convolution and a final stride-1 conv to a single
    channel, so the output is a spatial map of real/fake scores rather than one
    scalar. Every conv except the first is followed by InstanceNorm2d, and every
    hidden conv by LeakyReLU(0.2).

    Args:
        image_shape: accepted for API compatibility; not used. The first conv
            is hard-wired to 3 input channels.

    Returns:
        nn.Sequential with Conv/Linear weights xavier-uniform initialised and
        biases set to zero.
    """
    def xavier_init(module):
        kind = type(module).__name__
        if hasattr(module, 'weight') and ('Conv' in kind or 'Linear' in kind):
            init.xavier_uniform_(module.weight.data)
            if module.bias is not None:
                init.constant_(module.bias.data, 0.0)

    # (in_channels, out_channels, stride, apply InstanceNorm)
    hidden_specs = [
        (3, 64, 2, False),
        (64, 128, 2, True),
        (128, 256, 2, True),
        (256, 512, 2, True),
        (512, 512, 1, True),
    ]

    layers = []
    for c_in, c_out, stride, normalize in hidden_specs:
        layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1))
        if normalize:
            layers.append(nn.InstanceNorm2d(c_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    # Patch output: one score channel per spatial location.
    layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))

    model = nn.Sequential(*layers)
    model.apply(xavier_init)
    return model


In [27]:

def resnet_block(n_filters, input_layer):
    """Apply a freshly initialised conv-IN-ReLU-conv-IN block and concatenate with the input.

    The block output is concatenated channel-wise (dim=1) with ``input_layer``
    and the result passed through ReLU, so the returned tensor has
    ``2 * n_filters`` channels.

    NOTE: the layers are created on every call and are not returned, so their
    weights are random and never trained. ``define_generator`` does not use this
    helper; it has its own additive ``ResidualBlock``.

    Args:
        n_filters: channel count of ``input_layer`` and of both convolutions.
        input_layer: tensor, assumed (N, n_filters, H, W); TODO confirm for callers.

    Returns:
        Tensor of shape (N, 2 * n_filters, H, W), non-negative.
    """
    def weights_init_normal(m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("InstanceNorm2d") != -1:
            # InstanceNorm2d defaults to affine=False, in which case weight/bias are
            # None; initialising them unconditionally raised AttributeError.
            if m.weight is not None:
                nn.init.normal_(m.weight.data, 1.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)

    # Define the residual block
    model = nn.Sequential(
        # First convolutional layer
        nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1, padding=1),
        nn.InstanceNorm2d(n_filters),
        nn.ReLU(inplace=True),
        # Second convolutional layer
        nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1, padding=1),
        nn.InstanceNorm2d(n_filters)
    )

    # Initialize the weights
    model.apply(weights_init_normal)
    # The layers are built on the CPU in default dtype; match the input so GPU /
    # double inputs do not hit a device or dtype mismatch.
    model.to(device=input_layer.device, dtype=input_layer.dtype)

    # Concatenate merge channel-wise with input layer
    return nn.ReLU(inplace=True)(torch.cat([model(input_layer), input_layer], dim=1))


In [28]:


def define_generator(image_shape, n_resnet=6):
    """Build a ResNet-style image-to-image generator.

    Layout: c7s1-64, d128, d256, ``n_resnet`` additive residual blocks at 256
    channels, u128, u64, c7s1-3 + tanh.

    Args:
        image_shape: accepted for API compatibility; not used. The network is
            hard-wired to 3 input and 3 output channels.
        n_resnet: number of residual blocks at the bottleneck.

    Returns:
        An ``nn.Module`` mapping (N, 3, H, W) to (N, 3, H, W) with values in
        [-1, 1]. H and W should be divisible by 4 so the two stride-2
        downsamplings are exactly undone by the two transposed convolutions.
        Conv / ConvTranspose / Linear weights are xavier-uniform initialised
        and biases zeroed.
    """
    def weights_init_xavier(m):
        # Matches Conv2d and ConvTranspose2d (both contain 'Conv') and Linear.
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    # Additive residual block: x + IN(conv(ReLU(IN(conv(x)))))
    class ResidualBlock(nn.Module):
        def __init__(self, n_filters):
            super(ResidualBlock, self).__init__()

            self.conv1 = nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1, padding=1)
            self.norm1 = nn.InstanceNorm2d(n_filters)
            self.conv2 = nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1, padding=1)
            self.norm2 = nn.InstanceNorm2d(n_filters)

        def forward(self, x):
            residual = x

            out = F.relu(self.norm1(self.conv1(x)))
            out = self.norm2(self.conv2(out))

            out = out + residual

            return out

    # Define the generator network
    class Generator(nn.Module):
        def __init__(self, image_shape, n_resnet):
            super(Generator, self).__init__()

            self.c7s1_64 = nn.Sequential(
                nn.ReflectionPad2d(3),
                nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=0),
                nn.InstanceNorm2d(64),
                nn.ReLU(inplace=True)
            )

            self.d128 = nn.Sequential(
                nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(128),
                nn.ReLU(inplace=True)
            )

            self.d256 = nn.Sequential(
                nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(256),
                nn.ReLU(inplace=True)
            )

            self.resnet_blocks = nn.ModuleList([ResidualBlock(256) for _ in range(n_resnet)])

            self.u128 = nn.Sequential(
                nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(128),
                nn.ReLU(inplace=True)
            )

            self.u64 = nn.Sequential(
                nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(64),
                nn.ReLU(inplace=True)
            )

            # No normalisation before tanh. An affine-free InstanceNorm here would
            # force every output channel to zero mean / unit variance per image,
            # so the generator could not control overall colour or brightness.
            # InstanceNorm2d(affine=False) has no parameters, so state_dict keys
            # (c7s1_3.1.weight / .bias) are unchanged by removing it.
            self.c7s1_3 = nn.Sequential(
                nn.ReflectionPad2d(3),
                nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=0),
                nn.Tanh()
            )

        def forward(self, x):
            out = self.c7s1_64(x)
            out = self.d128(out)
            out = self.d256(out)

            for resnet_block in self.resnet_blocks:
                out = resnet_block(out)

            out = self.u128(out)
            out = self.u64(out)
            out = self.c7s1_3(out)

            return out

    # Create an instance of the generator
    generator = Generator(image_shape, n_resnet)

    # Initialize the weights
    generator.apply(weights_init_xavier)

    return generator


In [29]:


def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
    """Wrap two generators and a discriminator into a composite for training ``g_model_1``.

    The composite's ``forward(input_gen, input_id)`` returns four tensors:
      * ``output_d``  = d_model(g_model_1(input_gen)): adversarial score of the translation
      * ``output_id`` = g_model_1(input_id): identity mapping of a target-domain image
      * ``output_f``  = g_model_2(g_model_1(input_gen)): forward cycle
      * ``output_b``  = g_model_1(g_model_2(input_id)): backward cycle

    The same function builds both directions, one generator at a time.

    Args:
        g_model_1: generator being trained.
        d_model: discriminator for ``g_model_1``'s output domain.
        g_model_2: the opposite-direction generator.
        image_shape: accepted for API compatibility; not used.

    Returns:
        (composite_model, opt): ``opt`` is an Adam optimiser over
        ``g_model_1``'s parameters only. A backward pass through the composite
        still writes ``.grad`` on ``d_model`` / ``g_model_2`` parameters, so
        zero their grads (or freeze them) before updating those models.
    """
    # Keras-style flags kept for callers that may inspect them. PyTorch ignores
    # ``.trainable``; what actually limits training is the optimiser below.
    g_model_1.trainable = True
    d_model.trainable = False
    g_model_2.trainable = False

    # Define the composite model
    class CompositeModel(nn.Module):
        def __init__(self, g_model_1, d_model, g_model_2):
            super(CompositeModel, self).__init__()

            self.g_model_1 = g_model_1
            self.d_model = d_model
            self.g_model_2 = g_model_2

        def forward(self, input_gen, input_id):
            gen1_out = self.g_model_1(input_gen)
            output_d = self.d_model(gen1_out)

            output_id = self.g_model_1(input_id)

            output_f = self.g_model_2(gen1_out)

            gen2_out = self.g_model_2(input_id)
            output_b = self.g_model_1(gen2_out)

            return output_d, output_id, output_f, output_b

    # Create an instance of the composite model
    composite_model = CompositeModel(g_model_1, d_model, g_model_2)

    # Optimise only the generator of interest. An optimiser over
    # composite_model.parameters() would also step the discriminator and the
    # other generator whenever it is called.
    opt = optim.Adam(g_model_1.parameters(), lr=0.0001, betas=(0.5, 0.999))

    return composite_model, opt


In [30]:
def generate_real_samples(dataset, n_samples, patch_shape):
    """Draw ``n_samples`` random rows of ``dataset`` (with replacement) labelled as real.

    Returns:
        (X, y): ``dataset`` indexed by a random index tensor, and a ones tensor
        of shape (n_samples, 1, patch_shape, patch_shape).
    """
    picks = torch.randint(0, dataset.shape[0], (n_samples,))
    real_labels = torch.ones((n_samples, 1, patch_shape, patch_shape))
    return dataset[picks], real_labels


In [31]:
def generate_fake_samples(g_model, dataset, patch_shape):
    """Translate a batch with ``g_model`` (no grad) and pair it with 'fake' (0) labels.

    Args:
        g_model: generator module.
        dataset: input batch as a NumPy array or torch tensor. It is converted
            to float32 and moved to the device of ``g_model``'s first parameter
            (if it has any).
        patch_shape: side length of the square label map.

    Returns:
        (X, y): generator output (no autograd history) and a zeros tensor of
        shape (len(X), 1, patch_shape, patch_shape).

    The generator's train/eval mode is restored afterwards.
    """
    # Remember the caller's mode so sampling does not silently leave the model in eval().
    was_training = g_model.training
    g_model.eval()

    # torch.as_tensor accepts ndarrays and tensors alike; torch.from_numpy rejects
    # tensors, which summarize_performance passes.
    inputs = torch.as_tensor(dataset).float()
    first_param = next(g_model.parameters(), None)
    if first_param is not None:
        inputs = inputs.to(first_param.device)

    try:
        with torch.no_grad():
            X = g_model(inputs)
    finally:
        g_model.train(was_training)

    # Create 'fake' class labels (0)
    y = torch.zeros(len(X), 1, patch_shape, patch_shape)

    return X, y

In [32]:
def save_models(step, g_model_AtoB, g_model_BtoA):
    """Write both generators' state_dicts to the current directory, tagged with ``step + 1``.

    Files: ``g_model_AtoB_XXXXXX.pth`` and ``g_model_BtoA_XXXXXX.pth`` (zero-padded
    to six digits). Prints the two filenames when done.
    """
    filenames = (
        'g_model_AtoB_%06d.pth' % (step+1),
        'g_model_BtoA_%06d.pth' % (step+1),
    )
    for generator, path in zip((g_model_AtoB, g_model_BtoA), filenames):
        torch.save(generator.state_dict(), path)
    print('> Saved: %s and %s' % filenames)

In [33]:

def summarize_performance(step, g_model, trainX, name, n_samples=5):
    """Save a 2-row figure of random inputs (top) and their translations (bottom).

    Args:
        step: zero-based step; the file is tagged with ``step + 1``.
        g_model: generator used for the translation.
        trainX: pool of input images, indexable along dim 0 (tensor or ndarray),
            assumed to be (N, C, H, W) scaled to [-1, 1]; TODO confirm for callers.
        name: filename prefix.
        n_samples: number of columns in the figure.

    Writes ``<name>_generated_plot_<step+1 as 6 digits>.png`` to the current directory.
    """
    # Select a sample of input images
    X_in, _ = generate_real_samples(trainX, n_samples, 0)
    # Indexing may yield a tensor (possibly on GPU) or an ndarray depending on
    # trainX. Normalise to a CPU float tensor so the rest has one code path.
    X_in = torch.as_tensor(X_in).detach().float().cpu()

    # generate_fake_samples accepts NumPy input. NOTE(review): it must also place
    # the batch on g_model's device when the generator lives on a GPU.
    X_out, _ = generate_fake_samples(g_model, X_in.numpy(), 0)
    X_out = torch.as_tensor(X_out).detach().float().cpu()

    # Scale from [-1, 1] to [0, 1]. Clamp so imshow does not clip-warn on overshoot.
    X_in = ((X_in + 1) / 2.0).clamp(0.0, 1.0).permute(0, 2, 3, 1).numpy()
    X_out = ((X_out + 1) / 2.0).clamp(0.0, 1.0).permute(0, 2, 3, 1).numpy()

    # squeeze=False keeps axes 2-D even when n_samples == 1.
    fig, axes = plt.subplots(2, n_samples, figsize=(10, 4), squeeze=False)
    for row, images in enumerate((X_in, X_out)):
        for i in range(n_samples):
            axes[row, i].axis('off')
            axes[row, i].imshow(images[i])

    # Save plot to file
    filename1 = '%s_generated_plot_%06d.png' % (name, (step+1))
    fig.savefig(filename1)
    plt.close(fig)


In [34]:
def update_image_pool(pool, images, max_size=50):
    selected = []
    for image in images:
        if len(pool) < max_size:
            # Stock the pool
            pool.append(image)
            selected.append(image)
        elif random.random() < 0.5:
            # Use image, but don't add it to the pool
            selected.append(image)
        else:
            # Replace an existing image and use replaced image
            ix = random.randint(0, len(pool) - 1)
            selected.append(pool[ix])
            pool[ix] = image
    return np.asarray(selected)


In [35]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [46]:


def _set_requires_grad(models, flag):
    """Toggle ``requires_grad`` on every parameter of each module in ``models``."""
    for model in models:
        for param in model.parameters():
            param.requires_grad_(flag)


def _pooled_fakes(g_model, real, pool):
    """Translate ``real`` with ``g_model`` (no grad) and pass the result through the image pool.

    Returns a tensor on ``real.device`` mixing current and historical fakes. It
    carries no gradient, so it is only suitable as discriminator input.
    """
    with torch.no_grad():
        fake = g_model(real)
    mixed = update_image_pool(pool, fake.cpu().numpy())
    return torch.from_numpy(mixed).to(real.device)


def _update_generator(c_model, optimizer, frozen, real_src, real_dst, adversarial_loss, l1_loss):
    """One composite-model step for the generator owned by ``optimizer``.

    ``c_model`` is a composite from ``define_composite_model`` translating
    src -> dst. ``frozen`` holds its discriminator and opposite generator. They
    are excluded from autograd during the step (gradients still flow *through*
    them to the trained generator) and re-enabled afterwards.
    Loss weights: adversarial 1, identity 5, forward cycle 10, backward cycle 10.
    Returns the scalar loss.
    """
    _set_requires_grad(frozen, False)
    try:
        optimizer.zero_grad()
        output_d, output_id, output_f, output_b = c_model(real_src, real_dst)
        loss = (
            adversarial_loss(output_d, torch.ones_like(output_d))
            + 5 * l1_loss(output_id, real_dst)   # identity: G(dst) ~ dst
            + 10 * l1_loss(output_f, real_src)   # forward cycle: G2(G(src)) ~ src
            + 10 * l1_loss(output_b, real_dst)   # backward cycle: G(G2(dst)) ~ dst
        )
        loss.backward()
        optimizer.step()
    finally:
        _set_requires_grad(frozen, True)
    return loss.item()


def _update_discriminator(d_model, optimizer, real, fake, adversarial_loss):
    """One least-squares discriminator step: real -> 1, fake -> 0, averaged. Returns the loss."""
    optimizer.zero_grad()
    pred_real = d_model(real)
    pred_fake = d_model(fake)
    loss = 0.5 * (
        adversarial_loss(pred_real, torch.ones_like(pred_real))
        + adversarial_loss(pred_fake, torch.zeros_like(pred_fake))
    )
    loss.backward()
    optimizer.step()
    return loss.item()


def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, train_dataloader, epochs=1):
    """Train a CycleGAN for ``epochs`` passes over ``train_dataloader``.

    Each batch is ``(source, target)``. Source is treated as domain B and
    target as domain A. Both are rescaled with (x - 0.5) / 0.5, which assumes
    loader values in [0, 1] (TODO confirm), to match the generators' tanh range.

    Per batch:
      1. Make pooled fakes for both discriminators.
      2. Update g_BtoA through ``c_model_BtoA``, then d_A.
      3. Update g_AtoB through ``c_model_AtoB``, then d_B.
    Each network has exactly one Adam optimiser (lr=1e-4, betas=(0.5, 0.999)).
    The composites are expected to wrap these same generator/discriminator
    objects, as built by ``define_composite_model``. All four networks are moved
    to ``device`` and put in train mode. The last batch's losses are printed
    once per epoch.
    """
    n_epochs = epochs

    for net in (d_model_A, d_model_B, g_model_AtoB, g_model_BtoA):
        net.to(device)
        net.train()

    # History pools of fake images fed to the discriminators.
    poolA, poolB = [], []

    adversarial_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()

    def adam(model):
        return optim.Adam(model.parameters(), lr=0.0001, betas=(0.5, 0.999))

    optimizer_g_AtoB = adam(g_model_AtoB)
    optimizer_g_BtoA = adam(g_model_BtoA)
    optimizer_d_A = adam(d_model_A)
    optimizer_d_B = adam(d_model_B)

    for epoch in range(n_epochs):
        last = None
        for real_B, real_A in train_dataloader:
            real_A = ((real_A - 0.5) / 0.5).to(device)
            real_B = ((real_B - 0.5) / 0.5).to(device)

            fake_A = _pooled_fakes(g_model_BtoA, real_B, poolA)
            fake_B = _pooled_fakes(g_model_AtoB, real_A, poolB)

            loss_g_BtoA = _update_generator(
                c_model_BtoA, optimizer_g_BtoA, (d_model_A, g_model_AtoB),
                real_B, real_A, adversarial_loss, l1_loss)
            loss_d_A = _update_discriminator(d_model_A, optimizer_d_A, real_A, fake_A, adversarial_loss)

            loss_g_AtoB = _update_generator(
                c_model_AtoB, optimizer_g_AtoB, (d_model_B, g_model_BtoA),
                real_A, real_B, adversarial_loss, l1_loss)
            loss_d_B = _update_discriminator(d_model_B, optimizer_d_B, real_B, fake_B, adversarial_loss)

            last = (loss_g_AtoB, loss_g_BtoA, loss_d_A, loss_d_B)

        if last is None:
            print(f"Epoch [{epoch+1}/{n_epochs}] | no batches in train_dataloader")
            continue
        print(
            f"Epoch [{epoch+1}/{n_epochs}] | Generator Loss: AtoB {last[0]:.4f}, BtoA {last[1]:.4f}"
            f" | Discriminator Loss: A {last[2]:.4f}, B {last[3]:.4f}"
        )

            


In [37]:
# Dataset root and per-split index files (relative to the working directory at use time).
root_path = './data/maps'
train_txt_path, val_txt_path, test_txt_path = (
    './assets/train.txt',
    './assets/val.txt',
    './assets/test.txt',
)

In [38]:
# Project root used as the working directory so the relative data/asset paths resolve.
# Set the QARTEZATOR_ROOT environment variable instead of editing this absolute local path.
new_directory = os.environ.get('QARTEZATOR_ROOT', r'C:\Users\User\Desktop\qartezator')
os.chdir(new_directory)

In [39]:
# Training split with the common 256-pixel augmentations.
ds = QartezatorDataset(root_path=root_path, split_file_path=train_txt_path, common_transform=get_common_augmentations(256))

In [40]:
# Test split. NOTE(review): reuses the training augmentations — confirm they are
# appropriate (e.g. deterministic) for evaluation.
ds_test = QartezatorDataset(root_path=root_path, split_file_path=test_txt_path, common_transform=get_common_augmentations(256))

In [42]:
# Data module wiring all three splits at a 256-pixel input size.
dm = QartezatorDataModule(
    root_path=root_path,
    train_txt_path=train_txt_path,
    val_txt_path=val_txt_path,
    test_txt_path=test_txt_path,
    input_size=256,
)
train_dataloader, val_dataloader, test_dataloader = (
    dm.train_dataloader(),
    dm.val_dataloader(),
    dm.test_dataloader(),
)

In [43]:
# Inspect the shapes of the first training batch only.
for batch in train_dataloader:
    source, target = batch
    print(f'Source batch shape: {source.shape}\nTarget batch shape: {target.shape}\n')
    break

Source batch shape: torch.Size([32, 3, 256, 256])
Target batch shape: torch.Size([32, 3, 256, 256])



In [44]:
# Inspect the shapes of the first test batch only.
for batch in test_dataloader:
    source_t, target_t = batch
    print(f'Source batch shape: {source_t.shape}\nTarget batch shape: {target_t.shape}\n')
    break

Source batch shape: torch.Size([32, 3, 608, 608])
Target batch shape: torch.Size([32, 3, 608, 608])



In [47]:
image_shape = source.size()[1:]  # drop the batch dimension of the peeked training batch

In [48]:
# Display the per-sample shape (batch dimension dropped) passed to the model builders.
image_shape

torch.Size([3, 256, 256])

In [49]:
# generator: A -> B
g_model_AtoB = define_generator(image_shape, n_resnet=6)

In [50]:
# generator: B -> A
g_model_BtoA = define_generator(image_shape=image_shape)

In [51]:
# discriminator: A -> [real/fake]
d_model_A = define_discriminator(image_shape=image_shape)

In [52]:
# discriminator: B -> [real/fake]
d_model_B = define_discriminator(image_shape=image_shape)

In [53]:
# Composite models, one per training direction. NOTE: train() builds its own
# optimisers; the optimisers returned here are not passed to it.
# composite: A -> B -> [real/fake, A]
c_model_AtoB, optimizer_c_AtoB = define_composite_model(
    g_model_1=g_model_AtoB, d_model=d_model_B, g_model_2=g_model_BtoA, image_shape=image_shape)
# composite: B -> A -> [real/fake, B]
c_model_BtoA, optimizer_c_BtoA = define_composite_model(
    g_model_1=g_model_BtoA, d_model=d_model_A, g_model_2=g_model_AtoB, image_shape=image_shape)

In [None]:
from datetime import datetime

# Wall-clock timing of the full training run.
start1 = datetime.now()
train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA,
      train_dataloader=train_dataloader, epochs=5)
stop1 = datetime.now()

execution_time = stop1 - start1
print("Execution time is: ", execution_time)