# Task:
* Using Google Image Search, find 10 horse images and convert them to zebra images using CycleGAN (use the pre-trained generator network from the book).
* Plot all images together in a single figure.
* The results, together with the source code, must be committed to your own Git repository.
* The .ipynb files must run without errors and produce the shown outputs.

# Implementation:

In [169]:
# The following code is a Python script for generating images of zebras from images of horses using a pre-trained ResNet-based generator model

In [170]:
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    '''
    Residual block: returns ``x + F(x)``, where the branch F is
    reflection-pad -> 3x3 conv -> instance-norm -> ReLU -> reflection-pad -> 3x3 conv -> instance-norm.

    Both convolutions map ``dim`` channels to ``dim`` channels, and each 1-pixel
    reflection pad compensates for its unpadded 3x3 conv, so F(x) has the same
    shape as x and the skip-connection sum is well-defined.
    '''

    def __init__(self, dim):
        super().__init__()
        # The attribute name and the layer order inside it define the state-dict
        # keys (e.g. "conv_block.1.weight"); they must match the pretrained checkpoint.
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        '''Build the residual branch F over ``dim`` channels as an ``nn.Sequential``.'''
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(dim),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual


class ResNetGenerator(nn.Module):
    '''
    ResNet-based generator used here to translate an image of a horse into an image of a zebra.

    All layers are collected, in this order, into a single ``nn.Sequential`` (``self.model``):
      1. Stem: reflection-pad 3 + 7x7 conv (input_nc -> ngf channels), InstanceNorm, ReLU.
      2. Downsampling: two stride-2 3x3 convs (ngf -> 2*ngf -> 4*ngf channels), each
         followed by InstanceNorm + ReLU; each halves H and W (for even sizes).
      3. Bottleneck: ``n_blocks`` ResNetBlock residual blocks at 4*ngf channels.
      4. Upsampling: two stride-2 transposed convs (4*ngf -> 2*ngf -> ngf channels),
         each followed by InstanceNorm + ReLU; each exactly doubles H and W.
      5. Head: reflection-pad 3 + 7x7 conv (ngf -> output_nc channels), then Tanh,
         so every output value lies in [-1, 1].

    NOTE: the layer order fixes the state-dict keys ("model.<index>.weight", ...),
    which must match the pretrained checkpoint loaded into this model; do not reorder layers.
    '''

    def __init__(self, input_nc: int = 3, output_nc: int = 3, ngf: int = 64, n_blocks: int = 9):
        '''
        Args:
            input_nc: number of channels of the input image (3 for RGB).
            output_nc: number of channels of the output image.
            ngf: number of filters in the first conv layer; deeper layers use multiples of it.
            n_blocks: number of residual blocks in the bottleneck (must be >= 0).
        '''

        assert(n_blocks >= 0)  # NOTE(review): asserts are stripped under `python -O`
        super(ResNetGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        # Stem: ReflectionPad2d(3) before the unpadded 7x7 conv keeps H and W unchanged.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        # Downsampling: each step doubles the channel count and (stride 2) halves H and W.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]

        # Bottleneck: residual blocks at ngf * 2**n_downsampling channels.
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        # Upsampling mirrors the downsampling path; output_padding=1 makes each
        # stride-2 transposed conv produce exactly twice the input H and W.
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        # Head: 7x7 conv back to output_nc channels; Tanh squashes values into [-1, 1].
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        '''Run the generator on an image tensor (presumably N x input_nc x H x W); values of the result lie in [-1, 1].'''
        return self.model(input)
    
    
netG = ResNetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # defaults spelled out: RGB in/out, 64 base filters, 9 residual blocks

In [171]:
# Load the pre-trained horse->zebra weights (a state dict) into the generator.
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a CPU-only machine.
model_data = torch.load('horse2zebra_0.4.0.pth', map_location='cpu')
netG.load_state_dict(model_data)

<All keys matched successfully>

In [172]:
netG.train(False)  # inference mode (what eval() does); returns the module, so the architecture is still displayed

ResNetGenerator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace=True)
    (10): ResNetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): ReLU(inplace=True)
        (4): ReflectionPad2d((1, 1, 1, 1))
     

In [173]:
from PIL import Image
from torchvision import transforms

In [174]:
# Load the horse images horse0.jpg ... horse9.jpg from a directory and store them in a list.
# Images that cannot be read are reported and skipped, so the list may hold fewer than 10 images.
import os

horse_image_list = []
directory = "horse_images/"
for i in range(10):
    filename = "horse" + str(i) + ".jpg"
    filepath = os.path.join(directory, filename)
    try:
        # `with` closes the file handle (Image.open alone keeps it open lazily);
        # convert("RGB") reads the pixels now and guarantees the 3 channels the
        # generator expects (grayscale/CMYK/palette files would otherwise break it).
        with Image.open(filepath) as img:
            horse_image_list.append(img.convert("RGB"))
    except OSError:
        # Covers a missing file (FileNotFoundError) and an unreadable/corrupt image
        # (PIL's UnidentifiedImageError); both are OSError subclasses.
        print("Could not load image at path:", filepath)


In [175]:
# Preprocess the images: Resize(256) scales each image so its SHORTER side is 256 px
# (the aspect ratio is kept, so the result is generally not 256x256). ToTensor then
# yields a float tensor in [0, 1], and unsqueeze(0) adds a batch dimension of size 1.
preprocess = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])

# convert("RGB") guarantees the 3 input channels the generator expects.
batch_t_list = [preprocess(img.convert("RGB")).unsqueeze(0) for img in horse_image_list]

In [176]:
# Apply the generator to each batch and convert each output tensor to a PIL image.
# The results are collected in out_img_list, in the same order as batch_t_list.
out_img_list = []
to_pil = transforms.ToPILImage()
with torch.no_grad():  # inference only: don't build autograd graphs
    for batch_t in batch_t_list:
        batch_out = netG(batch_t)
        # The generator ends with Tanh (values in [-1, 1]); map to [0, 1] for ToPILImage.
        # squeeze(0) drops only the batch dimension, never a spatial one.
        out_t = (batch_out.squeeze(0) + 1.0) / 2.0
        out_img_list.append(to_pil(out_t))

In [178]:
# Plot the original horse images and the generated zebra images side by side in one figure.
import matplotlib.pyplot as plt
import numpy as np

# zip() pairs each horse image with its zebra result; unlike a hard-coded range(10),
# this cannot raise IndexError when fewer than 10 images were loaded.
image_pairs = list(zip(horse_image_list, out_img_list))
n_rows = len(image_pairs)

if n_rows == 0:
    print("No images to plot.")
else:
    # Create the subplot grid: one row per image pair, 5 inches of height per row.
    # squeeze=False keeps `axs` 2-D even when there is only a single row.
    fig, axs = plt.subplots(nrows=n_rows, ncols=2, figsize=(10, 5 * n_rows), squeeze=False)
    fig.subplots_adjust(hspace=0.5, wspace=0.2)
    fig.suptitle("Horse2Zebra Converter", y=0.9)

    # Arrow geometry, in axes-fraction coordinates of the left (horse) panel: the arrow
    # spans roughly the column gap (wspace=0.2), from the panel's right edge (x=1.0) to x=1.2.
    # annotate() draws from xytext to xy, and arrowstyle '<|-' puts the head at the
    # xytext end, so the arrow points from the horse panel towards the zebra panel.
    arrow_tail = (1.0, 0.5)
    arrow_head = (1.2, 0.5)

    # Plot images and arrows
    for (horse_img, zebra_img), (ax_horse, ax_zebra) in zip(image_pairs, axs):
        # Plot horse image
        ax_horse.imshow(horse_img)
        ax_horse.axis('off')
        ax_horse.set_title('Horse')

        # Plot output image
        ax_zebra.imshow(zebra_img)
        ax_zebra.axis('off')
        ax_zebra.set_title('Zebra')

        # Plot arrow from horse image to output image
        ax_horse.annotate('', xy=arrow_tail, xytext=arrow_head,
                          xycoords='axes fraction', textcoords='axes fraction',
                          arrowprops=dict(facecolor='black', arrowstyle='<|-', connectionstyle='arc3'))

    # Save before showing: after plt.show() the figure may already be closed
    # (e.g. with the inline backend), which would save an empty image.
    plt.savefig('result_horse2zebra.png', bbox_inches='tight', pad_inches=0.1)
    plt.show()


SyntaxError: invalid syntax (3214005345.py, line 31)