In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import numpy as np

import matplotlib.pyplot as plt
from matplotlib import animation 
from matplotlib import colors
from PIL import Image

from tqdm import tqdm

import os

from IPython.display import HTML

# Google Drive Specific Commands

In [None]:
!pip install pytorch_msssim

In [None]:
from google.colab import drive
# Mount Google Drive so the paths under /content/drive/ used below are reachable.
drive.mount("/content/drive/")

In [None]:
!unzip "/content/drive/MyDrive/CMPUT652_PCGML/data/data.zip" -d "/content/"

In [None]:
!unzip "/content/drive/MyDrive/CMPUT652_PCGML/data/fusions.zip" -d "/content/fusions/"

# Back to Normal

In [None]:
from pytorch_msssim import ssim

# Config

In [None]:
# --- Optimisation settings ---
learning_rate = 1e-4
epochs = 10
batch_size = 64

# Name of this run; it selects the output folder below.
experiment_name = "finetune_fusion_autoencoder_v2"

# --- Autoencoder architecture ---
num_layers = 4
max_filters = 512
image_size = 64
latent_dim = 2048
use_noise_images = True
small_conv = True  # To use the 1x1 convolution layer

# --- Fusion fine-tuning settings ---
fusion_mode = "both"  # encoder, decoder, both
pretrained_model_path = "/content/drive/MyDrive/CMPUT652_PCGML/outputs/convolutional_autoencoder_v8.1/model.pt"
freeze_conv = True
unfreeze_epoch = 100

# --- Input folders: base images ---
data_prefix = "/content"
train_data_folder = f"{data_prefix}/train/"
val_data_folder = f"{data_prefix}/val/"
test_data_folder = f"{data_prefix}/test/"

# --- Input folders: fusion images ---
fusion_data_prefix = "/content/fusions"
train_fusion_data_folder = f"{fusion_data_prefix}/train/"
val_fusion_data_folder = f"{fusion_data_prefix}/val/"
test_fusion_data_folder = f"{fusion_data_prefix}/test/"

# --- Output locations, all below one folder per experiment ---
output_prefix = "/content/drive/MyDrive/CMPUT652_PCGML/outputs/" + experiment_name
output_dir = os.path.join(output_prefix, "generated", "normal")
fusion_output_dir = os.path.join(output_prefix, "generated", "fusions")
model_output_path = os.path.join(output_prefix, "model.pt")
animation_output_path = os.path.join(output_prefix, "animation.mp4")
fusion_animation_output_path = os.path.join(output_prefix, "fusion_animation.mp4")
loss_output_path = os.path.join(output_prefix, "loss.jpg")
fusion_loss_output_path = os.path.join(output_prefix, "fusion_loss.jpg")

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
gpu = torch.cuda.is_available()
device = torch.device("cuda") if gpu else torch.device("cpu")

In [None]:
# Show which device was selected.
print(gpu, device)

In [None]:
# Create both output folders up front. `exist_ok=True` keeps the cell idempotent
# and avoids the check-then-create race of a separate `os.path.exists` test.
os.makedirs(output_dir, exist_ok=True)
os.makedirs(fusion_output_dir, exist_ok=True)

In [None]:
# Seed NumPy and PyTorch so the random sample picks and training are repeatable.
seed = 42
np.random.seed(seed)
_ = torch.manual_seed(seed)  # assigned to `_` so the return value is not echoed as cell output

# Load Dataset

In [None]:
def load_images_from_folder(folder):
    """Load every image file in ``folder`` into memory as a NumPy array.

    Files whose name contains "noise" are skipped when the module-level
    ``use_noise_images`` flag is false.

    Args:
        folder: Path of the directory to read.

    Returns:
        dict mapping file name -> image pixels (``np.ndarray``), inserted in
        sorted file-name order.
    """
    dataset = {}
    # os.listdir returns entries in arbitrary order; sorting keeps the dataset
    # index order (and therefore the seeded sample picks) stable across runs.
    for file in sorted(os.listdir(folder)):
        if "noise" in file and not use_noise_images:
            continue
        # The context manager closes the file handle once the pixels have been
        # copied into the array.
        with Image.open(os.path.join(folder, file)) as image:
            dataset[file] = np.array(image)
    print(f"Loaded {len(dataset)} images.")
    return dataset

In [None]:
# Training split of the base images (file name -> pixel array).
train = load_images_from_folder(train_data_folder)

In [None]:
# Validation split of the base images (file name -> pixel array).
val = load_images_from_folder(val_data_folder)

In [None]:
# Test split of the base images (file name -> pixel array).
test = load_images_from_folder(test_data_folder)

# Visualize Some Examples

In [None]:
# Group the three splits by name for the preview plot below.
data = {
    "train": train,
    "test": test,
    "val": val
}

In [None]:
# Preview two randomly chosen images per split: one column per split, two rows.
for column, (split_name, split_images) in enumerate(data.items()):
    filenames = list(split_images.keys())
    for row in range(2):
        index = np.random.randint(0, len(split_images))
        axes = plt.subplot(2, 3, column + row * 3 + 1)
        plt.imshow(split_images[filenames[index]])
        axes.set_title(f"{split_name} ({index})")
plt.tight_layout()

# Make Datasets

In [None]:
# Preprocessing applied to every image: convert to a tensor, then resize to
# `image_size` using bicubic interpolation.
# NOTE(review): bicubic resampling can presumably push values slightly outside
# [0, 1], which would explain the np.clip in make_grid — confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
])

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory ``{filename: image}`` dictionary."""

    def __init__(self, dataset, transform):
        # Split the mapping into two parallel lists so items are addressable
        # by integer position.
        self.keys = list(dataset)
        self.dataset = [dataset[name] for name in self.keys]
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(filename, image)``; the transform is applied when one is set."""
        data = self.dataset[index]
        key = self.keys[index]
        if self.transform is None:
            return key, data
        return key, self.transform(data)

    def get_mapping(self):
        """Return a ``{filename: position}`` lookup table."""
        return dict(zip(self.keys, range(len(self.keys))))

In [None]:
# Wrap each split so items come back as (filename, transformed image).
train_data = CustomDataset(train, transform)
val_data = CustomDataset(val, transform)
test_data = CustomDataset(test, transform)

# Load Fusions

In [None]:
class FusionDataset(torch.utils.data.Dataset):
    """Dataset of fusion images paired with the two base images they combine.

    Each fusion file name is split on ``.`` into four parts
    ``<first>.<second>.<background>.<rest>``. The two ids are zero-padded to
    three characters and looked up in the already-loaded base image
    dictionaries (train, then val, then test).

    Args:
        fusion_dataset_path: Folder containing the fusion image files.
        base_train_images, base_val_images, base_test_images: ``{filename:
            image}`` dictionaries searched for the base images.
        transform: Callable applied to the base, fusee and fusion images.
    """

    def __init__(self, fusion_dataset_path, base_train_images, base_val_images, base_test_images, transform):
        self.dataset_path = fusion_dataset_path
        # Sorted so that index -> file is stable across runs and machines
        # (os.listdir order is arbitrary).
        self.all_images = sorted(os.listdir(fusion_dataset_path))
        self.transform = transform

        self.base_dataset = [base_train_images, base_val_images, base_test_images]

    def to_3_digit(self, num):
        """Left-pad the id string ``num`` with zeros to at least three characters."""
        return num.zfill(3)

    def get_base_image(self, num, background):
        """Find the base image for id ``num`` on the given background.

        Returns:
            ``(image, filename)`` of the first match, or ``(None, None)`` when
            no split contains a matching file.
        """
        # if Base BW is not there, search for female BW
        filenames = [
            f"{num}_base_bw_{background}_0rotation.png",
            f"{num}_base_female_bw_{background}_0rotation.png"
        ]
        for dataset in self.base_dataset:
            for filename in filenames:
                if filename in dataset:
                    return dataset[filename], filename
        return None, None

    def _load_base(self, num, background, fusion_filename):
        """Return ``(transformed image, filename)`` or raise a descriptive KeyError."""
        image, filename = self.get_base_image(num, background)
        if image is None:
            # Fail here with a clear message instead of passing None into the transform.
            raise KeyError(
                f"No base image for id {num} on background {background} "
                f"(needed by fusion file {fusion_filename})"
            )
        return self.transform(image), filename

    def __getitem__(self, index):
        """Return ``(base_name, fusee_name, fusion_name), (base, fusee, fusion)``.

        Raises:
            ValueError: If the fusion file name does not have four
                dot-separated parts.
            KeyError: If either base image cannot be found.
        """
        fusion_filename = self.all_images[index]
        # Get two base names
        parts = fusion_filename.split('.')
        if len(parts) != 4:
            raise ValueError(
                f"Unexpected fusion file name {fusion_filename}: expected 4 dot-separated parts"
            )
        first, second, background, _ = parts
        # Get Base
        base, base_filename = self._load_base(self.to_3_digit(first), background, fusion_filename)
        # Get Fusee
        fusee, fusee_filename = self._load_base(self.to_3_digit(second), background, fusion_filename)
        # Get Fusion
        fusion_loc = os.path.join(self.dataset_path, fusion_filename)
        # convert() returns a new image, so the file handle can be closed right away.
        with Image.open(fusion_loc) as handle:
            image = handle.convert("RGB")
        fusion = self.transform(image)
        return (base_filename, fusee_filename, fusion_filename), (base, fusee, fusion)

    def __len__(self):
        return len(self.all_images)

In [None]:
# One fusion dataset per split; every one searches all three base splits for
# the base images it needs.
train_fusions = FusionDataset(train_fusion_data_folder, train, val, test, transform)
val_fusions = FusionDataset(val_fusion_data_folder, train, val, test, transform)
test_fusions = FusionDataset(test_fusion_data_folder, train, val, test, transform)

# Make Dataloaders

In [None]:
# All three base-image loaders share the same settings.
_base_loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=gpu)
train_dataloader = torch.utils.data.DataLoader(train_data, **_base_loader_options)
val_dataloader = torch.utils.data.DataLoader(val_data, **_base_loader_options)
test_dataloader = torch.utils.data.DataLoader(test_data, **_base_loader_options)

In [None]:
# All three fusion loaders share the same settings.
_fusion_loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=gpu)
train_fusion_dataloader = torch.utils.data.DataLoader(train_fusions, **_fusion_loader_options)
val_fusion_dataloader = torch.utils.data.DataLoader(val_fusions, **_fusion_loader_options)
test_fusion_dataloader = torch.utils.data.DataLoader(test_fusions, **_fusion_loader_options)

# Visualize Again

In [None]:
def make_grid(images, height, width, axis):
    """Draw a batch of images onto a ``height`` x ``width`` grid of axes.

    The first argument is a ``(label, images)`` pair, so the function can be
    used directly as a ``FuncAnimation`` callback whose frames are
    ``enumerate(...)`` items.

    Args:
        images: ``(text, batch)`` tuple. ``text`` becomes the figure title; an
            int is rendered as ``"Epoch: <n>"``. ``batch`` is an iterable of
            tensors, each permuted with ``(1, 2, 0)`` and clipped to [0, 1]
            before display (assumes channel-first images — TODO confirm).
        height: Number of grid rows.
        width: Number of grid columns.
        axis: 2-D array of matplotlib axes indexed as ``axis[row, col]``.

    Returns:
        The ``axis`` array that was passed in.
    """
    text, batch = images
    for num, image in enumerate(batch):
        # Images beyond the grid capacity are ignored.
        if num == height * width:
            break
        row, col = divmod(num, width)
        axis[row, col].imshow(np.clip(np.asarray(image.permute(1, 2, 0)), 0, 1))
    if isinstance(text, int):
        text = f"Epoch: {text}"
    # Take the figure from the axes instead of the global `fig` / pyplot's
    # current figure: the notebook rebinds `fig` many times, so a global lookup
    # titles the wrong figure when an animation is rendered later.
    figure = axis[0, 0].figure
    figure.suptitle(text, va="baseline")
    figure.tight_layout()
    return axis

In [None]:
def get_samples_from_data(data, sample_size, fusion=False):
    """Draw ``sample_size`` distinct random items from ``data`` as one tensor.

    Args:
        data: Indexable dataset whose items are ``(key, value)`` pairs.
        sample_size: Number of items to draw without replacement.
        fusion: When true, ``value`` is treated as a sequence of images and
            all of them are kept, adding one extra dimension after the sample
            dimension. When false, ``value`` is a single image.

    Returns:
        ``torch.Tensor`` stacking the sampled values along dimension 0.
    """
    indices = np.random.choice(len(data), size=sample_size, replace=False)
    if fusion:
        sample = [[np.asarray(x) for x in data[i][1]] for i in indices]
    else:
        sample = [np.asarray(data[i][1]) for i in indices]
    # Stack in NumPy first: building a tensor directly from a (nested) list of
    # ndarrays is very slow and makes PyTorch emit a UserWarning.
    return torch.as_tensor(np.asarray(sample))

In [None]:
# Fixed sample sets, drawn once, that are re-rendered every epoch to show the
# model's training progress (val) and again at evaluation time (test).
# Fusion samples keep all images of each item (base, fusee, fusion).
sample = get_samples_from_data(val_data, 16, fusion=False)
test_sample = get_samples_from_data(test_data, 16, fusion=False)
fusion_sample = get_samples_from_data(val_fusions, 4, fusion=True)
fusion_test_sample = get_samples_from_data(test_fusions, 4, fusion=True)

In [None]:
# Show the 16 validation sample images on a 4x4 grid.
fig, axis = plt.subplots(4, 4, figsize=(8, 6), dpi=80)
plt.tight_layout()
_ = make_grid(("Sample", sample), 4, 4, axis)

In [None]:
# Show the 4 validation fusion samples, one row each; the nested comprehension
# flattens every sample's images into a single sequence for the 4x3 grid.
fig, axis = plt.subplots(4, 3, figsize=(8, 6), dpi=80)
plt.tight_layout()
_ = make_grid(("Fusion Sample", [x for y in fusion_sample for x in y]), 4, 3, axis)

# Model Time

In [None]:
# Ref: https://github.com/sksq96/pytorch-vae/blob/master/vae-cnn.ipynb
class ConvolutionalAE(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    Encoder: repeated Conv2d -> BatchNorm2d -> ReLU groups, then Flatten, a
    Linear projection to ``latent_dim`` and a Sigmoid.
    Decoder: Linear -> Sigmoid -> Unflatten, then ConvTranspose2d ->
    BatchNorm2d groups with ReLU after every group except the last, which is
    followed by a Sigmoid.

    NOTE: the position of each module inside the two ``nn.Sequential``
    containers determines its ``state_dict`` key, so reordering layers breaks
    loading of previously saved checkpoints.

    Args:
        image_channels: Channel count of the input image.
        max_filters: Channel count used when unflattening in the decoder; the
            last encoder convolution produces this many channels when
            ``max_filters`` is divisible by ``2 ** (num_layers - 1)``.
        num_layers: Number of convolution groups; incremented by one when
            ``small_conv`` is true.
        kernel_size: Kernel size of the non-1x1 convolutions.
        stride: Stride of the non-1x1 convolutions.
        padding: Padding of the non-1x1 convolutions.
        latent_dim: Size of the bottleneck vector.
        input_image_dimensions: Side length of the input image (the size
            calculation uses one value for both height and width).
        small_conv: When true, the first encoder group uses a 1x1 convolution
            and the last decoder group a 1x1 transposed convolution.
    """
    def __init__(self, image_channels=3, max_filters=512, num_layers=4, kernel_size=2, stride=2, 
                 padding=0, latent_dim=128, input_image_dimensions=96, small_conv=False):
        super(ConvolutionalAE, self).__init__()
        # The 1x1 group is an additional layer on top of the requested count.
        if small_conv:
            num_layers += 1
        # List of (in_channels, out_channels) pairs, one per encoder group.
        channel_sizes = self.calculate_channel_sizes(image_channels, max_filters, num_layers)

        # Encoder
        encoder_layers = nn.ModuleList()
        # Encoder Convolutions
        for i, (in_channels, out_channels) in enumerate(channel_sizes):
            if small_conv and i == 0:
                # 1x1 Convolution
                encoder_layers.append(
                    nn.Conv2d(
                        in_channels=in_channels, out_channels=out_channels, kernel_size=1, 
                        stride=1, padding=0
                    )
                )
            else:
                # Convolutional Layer
                encoder_layers.append(
                    nn.Conv2d(
                        in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 
                        stride=stride, padding=padding, bias=False
                    )
                )
            # Batch Norm
            encoder_layers.append(nn.BatchNorm2d(out_channels))
            # ReLU
            encoder_layers.append(nn.ReLU())
        # Flatten Encoder Output
        encoder_layers.append(nn.Flatten())
        
        # Calculate shape of the flattened image
        # (`image_size` here is the side length of the final feature map)
        hidden_dim, image_size = self.get_flattened_size(input_image_dimensions, encoder_layers)

        # Hidden Dim -> Latent Dim
        encoder_layers.append(nn.Linear(hidden_dim, latent_dim))
        encoder_layers.append(nn.Sigmoid())
        self.encoder = nn.Sequential(*encoder_layers)
        
        # Decoder
        decoder_layers = nn.ModuleList()
        # Latent Dim -> Hidden Dim
        decoder_layers.append(nn.Linear(latent_dim, hidden_dim))
        decoder_layers.append(nn.Sigmoid())
        # Unflatten to a shape of (Channels, Height, Width)
        decoder_layers.append(nn.Unflatten(1, (max_filters, image_size, image_size)))
        # Decoder Convolutions
        # The encoder pairs are walked in reverse with in/out swapped, so each
        # transposed convolution undoes the channel change of its encoder twin.
        for i, (out_channels, in_channels) in enumerate(channel_sizes[::-1]):
            if small_conv and i == num_layers - 1:
                # 1x1 Transposed Convolution
                decoder_layers.append(
                    nn.ConvTranspose2d(
                        in_channels=in_channels, out_channels=out_channels, kernel_size=1, 
                        stride=1, padding=0
                    )
                )
            else:
                # Add Transposed Convolutional Layer
                decoder_layers.append(
                    nn.ConvTranspose2d(
                        in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 
                        stride=stride, padding=padding, bias=False
                    )
                )
            # Batch Norm
            decoder_layers.append(nn.BatchNorm2d(out_channels))
            # ReLU if not final layer
            if i != num_layers - 1:
                decoder_layers.append(nn.ReLU())
            # Sigmoid if final layer
            else:
                decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers) 
        
    def calculate_layer_size(self, input_size, kernel_size, stride, padding=0):
        """Return the output side length of a convolution:
        ``(input_size - kernel_size + 2 * padding) // stride + 1``."""
        numerator = input_size - kernel_size + (2 * padding)
        denominator = stride
        return (numerator // denominator) + 1
        
    def get_flattened_size(self, image_size, encoder_layers):
        """Propagate ``image_size`` through the convolutions in ``encoder_layers``.

        Only layers whose string representation contains "Conv2d" affect the
        size; the first element of their kernel size, stride and padding is used.

        Returns:
            ``(filters * size * size, size)`` where ``filters`` is the output
            channel count of the last convolution seen and ``size`` the final
            side length. ``filters`` is unbound if no convolution is present.
        """
        for layer in encoder_layers:
            if "Conv2d" in str(layer):
                kernel_size = layer.kernel_size[0]
                stride = layer.stride[0]
                padding = layer.padding[0]
                filters = layer.out_channels
                image_size = self.calculate_layer_size(image_size, kernel_size, stride, padding)
        return filters * image_size * image_size, image_size
    
    def calculate_channel_sizes(self, image_channels, max_filters, num_layers):
        """Return ``num_layers`` pairs of ``(in_channels, out_channels)``.

        The first pair maps ``image_channels`` to
        ``max_filters // 2 ** (num_layers - 1)``; every following pair doubles
        the previous output channel count.
        """
        channel_sizes = [(image_channels, max_filters // np.power(2, num_layers - 1))]
        for i in range(1, num_layers):
            prev = channel_sizes[-1][-1]
            new = prev * 2
            channel_sizes.append((prev, new))
        return channel_sizes
        
    def forward(self, x):
        """Encode ``x`` to the latent vector and decode it back to an image."""
        # Encode
        hidden_state = self.encoder(x)
        # Decode
        reconstructed = self.decoder(hidden_state)
        return reconstructed

# Training Time

In [None]:
# Build the autoencoder with the configured architecture and start from the
# pretrained weights; the architecture settings must match the checkpoint.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source (consider `weights_only=True` if the installed torch supports it).
model = ConvolutionalAE(max_filters=max_filters, num_layers=num_layers, input_image_dimensions=image_size, latent_dim=latent_dim, small_conv=small_conv)
model.load_state_dict(torch.load(pretrained_model_path, map_location=device))
model.to(device)

In [None]:
def get_freezable_layers(model):
    """Collect the encoder and decoder layers that may be frozen.

    A layer qualifies when its string representation does not contain
    "Linear"; encoder layers come first, then decoder layers.

    Args:
        model: Object exposing iterable ``encoder`` and ``decoder`` attributes.

    Returns:
        list of the qualifying layer modules.
    """
    return [
        layer
        for section in (model.encoder, model.decoder)
        for layer in section
        if "Linear" not in str(layer)
    ]

In [None]:
def toggle_layer_freezing(layers, trainable):
    """Switch gradient tracking on or off for every layer, in place.

    Args:
        layers: Iterable of modules.
        trainable: Value passed to each module's ``requires_grad_``.
    """
    for module in layers:
        module.requires_grad_(trainable)

In [None]:
# Freeze every non-Linear layer so only the Linear layers are updated.
# `freezable_layers` is only defined when `freeze_conv` is true; the training
# loop guards its use with the same flag.
if freeze_conv:
    freezable_layers = get_freezable_layers(model)
    toggle_layer_freezing(freezable_layers, trainable=False)

In [None]:
# Adam over all model parameters (frozen ones are passed in as well) and a
# mean-reduced MSE loss.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss(reduction='mean')

In [None]:
# Per-epoch history used by the animations and the loss plots.
all_samples = []
all_fusion_samples = []
all_train_fusion_loss = []
all_val_loss = []
all_val_fusion_loss = []

# Which fusion objectives are active for this run.
use_encoder_loss = fusion_mode in ("encoder", "both")
use_decoder_loss = fusion_mode in ("decoder", "both")

for epoch in range(epochs):
    if freeze_conv and epoch == unfreeze_epoch:
        toggle_layer_freezing(freezable_layers, trainable=True)

    val_loss = 0
    train_fusion_loss = 0
    val_fusion_loss = 0

    # Training Loop - Fusions
    # Re-enter training mode explicitly: the validation pass below switches
    # the model to eval mode.
    model.train()
    for batch in tqdm(train_fusion_dataloader):
        # Reset gradients back to zero for this iteration
        optimizer.zero_grad()

        # Move batch to device
        _, (base, fusee, fusion) = batch  # (filenames), (base, fusee, fusion) images
        base = base.to(device)
        fusee = fusee.to(device)
        fusion = fusion.to(device)

        # The blended embedding is built without gradient tracking: it is the
        # fixed target of the encoder loss and the fixed input of the decoder loss.
        with torch.no_grad():
            # Get Encoder Output
            base_embedding = model.encoder(base)
            fusee_embedding = model.encoder(fusee)
            # Weighted blend of the two embeddings
            midpoint_embedding = (base_embedding * 0.4) + (fusee_embedding * 0.6)

        if use_encoder_loss:
            # Run our model & get outputs
            fusion_embedding = model.encoder(fusion)
            # Calculate reconstruction loss: Midpoint Embedding vs Fusion Embedding
            batch_loss = criterion(midpoint_embedding, fusion_embedding)
            # Backprop. This call was missing, so the encoder objective never
            # produced gradients and was never counted in the training loss.
            batch_loss.backward()
            # Add the batch's loss to the total loss for the epoch
            train_fusion_loss += batch_loss.item()

        if use_decoder_loss:
            # Run our model & get outputs
            fusion_output = model.decoder(midpoint_embedding)
            # Calculate reconstruction loss: Midpoint Output vs Original Fusion
            batch_loss = criterion(fusion, fusion_output)
            # Backprop
            batch_loss.backward()
            # Add the batch's loss to the total loss for the epoch
            train_fusion_loss += batch_loss.item()

        # Update our optimizer parameters
        # We call it out here instead of inside the if because
        # the gradients are accumulated in case both conditions are true
        optimizer.step()

    # Evaluate in eval mode so BatchNorm uses its running statistics and does
    # not update them with validation data.
    model.eval()

    # Validation Loop - Standard
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            # Move batch to device
            _, batch = batch  # Returns key, value for each Pokemon
            batch = batch.to(device)

            # Run our model & get outputs
            reconstructed = model(batch)

            # Calculate reconstruction loss
            batch_loss = criterion(batch, reconstructed)

            # Add the batch's loss to the total loss for the epoch
            val_loss += batch_loss.item()

    # Validation Loop - Fusions
    with torch.no_grad():
        for batch in tqdm(val_fusion_dataloader):
            # Move batch to device
            _, (base, fusee, fusion) = batch  # (filenames), (base, fusee, fusion) images
            base = base.to(device)
            fusee = fusee.to(device)
            fusion = fusion.to(device)

            # Get Encoder Output
            base_embedding = model.encoder(base)
            fusee_embedding = model.encoder(fusee)
            # Weighted blend of the two embeddings
            midpoint_embedding = (base_embedding * 0.4) + (fusee_embedding * 0.6)

            if use_encoder_loss:
                # Run our model & get outputs
                fusion_embedding = model.encoder(fusion)
                # Calculate reconstruction loss: Midpoint Embedding vs Fusion Embedding
                batch_loss = criterion(midpoint_embedding, fusion_embedding)
                # Add the batch's loss to the total loss for the epoch
                val_fusion_loss += batch_loss.item()

            if use_decoder_loss:
                # Run our model & get outputs
                fusion_output = model.decoder(midpoint_embedding)
                # Calculate reconstruction loss: Midpoint Output vs Original Fusion
                batch_loss = criterion(fusion, fusion_output)
                # Add the batch's loss to the total loss for the epoch
                val_fusion_loss += batch_loss.item()

    # Get Sample Outputs for the animation
    with torch.no_grad():
        # Get reconstruction of our normal Pokemon
        epoch_sample = model(sample.to(device))

        # Get example fusions
        fusion_sample_base, fusion_sample_fusee, fusion_sample_fusion = fusion_sample[:, 0], fusion_sample[:, 1], fusion_sample[:, 2]
        # Sample Fusion
        fusion_sample_midpoint_embedding = (model.encoder(fusion_sample_base.to(device)) * 0.4) + (model.encoder(fusion_sample_fusee.to(device)) * 0.6)
        fusion_sample_fusion = model.decoder(fusion_sample_midpoint_embedding.to(device))
        # Sample Base Images
        fusion_sample_base = model(fusion_sample_base.to(device))
        fusion_sample_fusee = model(fusion_sample_fusee.to(device))
        # Interleave as base, fusee, fusion per sample so each grid row shows one fusion.
        fusion_epoch_sample = torch.stack((fusion_sample_base, fusion_sample_fusee, fusion_sample_fusion), dim=1).flatten(end_dim=1)

    # Add sample reconstruction to our list
    all_samples.append(epoch_sample.detach().cpu())
    all_fusion_samples.append(fusion_epoch_sample.detach().cpu())

    # Compute the average losses for this epoch
    train_fusion_loss = train_fusion_loss / len(train_fusion_dataloader)
    all_train_fusion_loss.append(train_fusion_loss)

    val_loss = val_loss / len(val_dataloader)
    all_val_loss.append(val_loss)

    val_fusion_loss = val_fusion_loss / len(val_fusion_dataloader)
    all_val_fusion_loss.append(val_fusion_loss)

    # Print Metrics
    print(
        f"\nEpoch: {epoch+1}/{epochs}:\
        \nVal Loss = {val_loss}\
        \nTrain Fusion Loss = {train_fusion_loss}\
        \nVal Fusion Loss = {val_fusion_loss}"
    )

# Visualize Training

In [None]:
# Plot the original validation sample images for comparison with the animation.
fig, axis = plt.subplots(4, 4, figsize=(8, 6), dpi=80)
plt.tight_layout()
_ = make_grid(("Sample", sample), 4, 4, axis)

In [None]:
    fig, axis = plt.subplots(4, 4, figsize=(8, 6), dpi=80)
plt.tight_layout()
plt.close()
anim = animation.FuncAnimation(fig=fig, func=make_grid, frames=list(enumerate(all_samples)), 
                               fargs=(4, 4, axis), interval=100, repeat=False)

In [None]:
# Render the reconstruction animation inline as an HTML5 video.
HTML(anim.to_html5_video())

In [None]:
# Plot the original validation fusion samples, one row per sample, for
# comparison with the fusion animation.
fig, axis = plt.subplots(4, 3, figsize=(8, 6), dpi=80)
plt.tight_layout()
_ = make_grid(("Fusion Sample", [x for y in fusion_sample for x in y]), 4, 3, axis)

In [None]:
# Animate the per-epoch fusion outputs. Each frame is an (epoch index, images)
# pair taken from all_fusion_samples.
fig, axis = plt.subplots(4, 3, figsize=(8, 6), dpi=80)
plt.tight_layout()
plt.close()
fusion_anim = animation.FuncAnimation(fig=fig, func=make_grid, frames=list(enumerate(all_fusion_samples)), 
                               fargs=(4, 3, axis), interval=100, repeat=False)

In [None]:
# Render the fusion animation inline as an HTML5 video.
HTML(fusion_anim.to_html5_video())

# Evaluation

In [None]:
# Switch the model to evaluation mode for everything that follows.
model.eval()

In [None]:
# Plot the original test sample images.
fig, axis = plt.subplots(4, 4, figsize=(8, 6), dpi=80)
plt.tight_layout()
_ = make_grid(("Test Sample", test_sample), 4, 4, axis)

In [None]:
# Plot the original test fusion samples, one row per sample.
fig, axis = plt.subplots(4, 3, figsize=(8, 6), dpi=80)
plt.tight_layout()
_ = make_grid(("Test Fusion Sample", [x for y in fusion_test_sample for x in y]), 4, 3, axis)

In [None]:
with torch.no_grad():
    # Plain autoencoder pass over the base test sample.
    reconstructed = model(test_sample.to(device)).cpu().detach()

    # Split the fusion test sample into its base, fusee and fusion images.
    fusion_sample_base, fusion_sample_fusee, fusion_sample_fusion = (
        fusion_test_sample[:, part] for part in range(3)
    )
    # Generated fusion: decode the weighted blend of the two embeddings.
    base_code = model.encoder(fusion_sample_base.to(device))
    fusee_code = model.encoder(fusion_sample_fusee.to(device))
    fusion_sample_midpoint_embedding = (base_code * 0.4) + (fusee_code * 0.6)
    fusion_sample_fusion = model.decoder(fusion_sample_midpoint_embedding.to(device))
    # Autoencoded versions of the two input images.
    fusion_sample_base = model(fusion_sample_base.to(device))
    fusion_sample_fusee = model(fusion_sample_fusee.to(device))
    # Interleave as base, fusee, fusion per sample so each grid row shows one fusion.
    stacked = torch.stack((fusion_sample_base, fusion_sample_fusee, fusion_sample_fusion), dim=1)
    fusion_test_sample_reconstruction = stacked.flatten(end_dim=1).cpu().detach()

In [None]:
# Plot the model's reconstruction of the test sample images.
fig, axis = plt.subplots(4, 4, figsize=(8, 6), dpi=80)
plt.tight_layout()
_ = make_grid(("Reconstructed Test", reconstructed), 4, 4, axis)

In [None]:
# Plot the model outputs for the test fusion samples: autoencoded base,
# autoencoded fusee and generated fusion per row. The title now says
# "Reconstructed" so this figure is distinguishable from the plot of the inputs.
fig, axis = plt.subplots(4, 3, figsize=(8, 6), dpi=80)
plt.tight_layout()
_ = make_grid(("Reconstructed Test Fusion", fusion_test_sample_reconstruction), 4, 3, axis)

# Compute Metrics & Save Outputs

In [None]:
all_mse = []
all_ssim = []
batch_sizes = []  # number of images behind each per-batch metric

# Testing Loop - Standard
with torch.no_grad():
    for filenames, image in tqdm(test_dataloader):
        # Move batch to device
        image = image.to(device)

        # Run our model & get outputs
        reconstructed = model(image)

        # Calculate Metrics
        mse = nn.functional.mse_loss(reconstructed, image)
        ssim_score = ssim(reconstructed, image, data_range=1.0, win_size=11, win_sigma=1.5, K=(0.01, 0.03))

        # Add metrics to tracking list
        all_mse.append(mse.detach().cpu().numpy())
        all_ssim.append(ssim_score.detach().cpu().numpy())
        batch_sizes.append(image.shape[0])

        # Save each reconstruction under its input file name; imsave expects
        # channel-last arrays, hence the permute.
        reconstructed = reconstructed.permute(0, 2, 3, 1).detach().cpu().numpy()
        for output_image, filename in zip(reconstructed, filenames):
            plt.imsave(os.path.join(output_dir, filename), output_image)

# Weight every batch by its size: the last batch is usually smaller, so a plain
# mean of per-batch means would over-weight its images.
mse = np.average(np.asarray(all_mse), weights=batch_sizes)
ssim_score = np.average(np.asarray(all_ssim), weights=batch_sizes)
print(f"\nMSE = {mse}, SSIM = {ssim_score}")

In [None]:
# Testing Loop - Fusions
all_mse = []
all_mse_autoencoded = []
all_ssim = []
all_ssim_autoencoded = []
batch_sizes = []  # number of images behind each per-batch metric
with torch.no_grad():
    for batch in tqdm(test_fusion_dataloader):
        # Move batch to device
        (base_filenames, fusee_filenames, fusion_filenames), (base, fusee, fusion) = batch
        base = base.to(device)
        fusee = fusee.to(device)
        fusion = fusion.to(device)

        # Run Model
        # Get Encoder Output
        base_embedding = model.encoder(base)
        fusee_embedding = model.encoder(fusee)
        # Weighted blend of the two embeddings
        midpoint_embedding = (base_embedding * 0.4) + (fusee_embedding * 0.6)
        # Get Output Fusion of combining two Pokemon
        fusion_fused_output = model.decoder(midpoint_embedding)
        # Get output of autoencoder on fusion
        fusion_ae_output = model(fusion)

        # Calculate Metrics
        # Generated fusion vs input fusion
        mse = nn.functional.mse_loss(fusion_fused_output, fusion)
        ssim_score = ssim(fusion_fused_output, fusion, data_range=1.0, win_size=11, win_sigma=1.5, K=(0.01, 0.03))
        # Generated fusion vs autoencoded fusion
        mse_autoencoded = nn.functional.mse_loss(fusion_fused_output, fusion_ae_output)
        ssim_score_autoencoded = ssim(fusion_fused_output, fusion_ae_output, data_range=1.0, win_size=11, win_sigma=1.5, K=(0.01, 0.03))

        # Add metrics to tracking list
        all_mse.append(mse.detach().cpu().numpy())
        all_ssim.append(ssim_score.detach().cpu().numpy())
        all_mse_autoencoded.append(mse_autoencoded.detach().cpu().numpy())
        all_ssim_autoencoded.append(ssim_score_autoencoded.detach().cpu().numpy())
        batch_sizes.append(fusion.shape[0])

        # Save each generated fusion under the fusion's file name; imsave
        # expects channel-last arrays, hence the permute.
        fusion_fused_output = fusion_fused_output.permute(0, 2, 3, 1).detach().cpu().numpy()
        for output_image, filename in zip(fusion_fused_output, fusion_filenames):
            plt.imsave(os.path.join(fusion_output_dir, filename), output_image)

# Weight every batch by its size: the last batch is usually smaller, so a plain
# mean of per-batch means would over-weight its images.
mse = np.average(np.asarray(all_mse), weights=batch_sizes)
ssim_score = np.average(np.asarray(all_ssim), weights=batch_sizes)
mse_autoencoded = np.average(np.asarray(all_mse_autoencoded), weights=batch_sizes)
ssim_score_autoencoded = np.average(np.asarray(all_ssim_autoencoded), weights=batch_sizes)
print(f"\nFusion vs Input Fusion:\nMSE = {mse}, SSIM = {ssim_score}")
print(f"\nFusion vs Autoencoded Fusion\nMSE = {mse_autoencoded}, SSIM = {ssim_score_autoencoded}")

# Save Loss Graph

In [None]:
# Validation reconstruction loss per epoch, saved next to the other outputs.
# A dedicated figure name is used so the notebook-wide `fig` is left untouched.
loss_fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
ax.plot(list(all_val_loss), label="Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
loss_fig.tight_layout()
loss_fig.savefig(loss_output_path)

In [None]:
# Fusion training and validation loss per epoch, saved next to the other outputs.
# A dedicated figure name is used so the notebook-wide `fig` is left untouched.
fusion_loss_fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
ax.plot(list(all_train_fusion_loss), label="Fusion Train Loss")
ax.plot(list(all_val_fusion_loss), label="Fusion Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
fusion_loss_fig.tight_layout()
fusion_loss_fig.savefig(fusion_loss_output_path)

# Save Model

In [None]:
# Persist only the weights (state_dict) of the fine-tuned model.
torch.save(model.state_dict(), model_output_path)

# Save Animation

In [None]:
# Write the reconstruction animation to disk with the ffmpeg writer.
Writer = animation.writers['ffmpeg']
writer = Writer()
anim.save(animation_output_path, writer=writer)

In [None]:
# Write the fusion animation to disk with the ffmpeg writer.
Writer = animation.writers['ffmpeg']
writer = Writer()
fusion_anim.save(fusion_animation_output_path, writer=writer)

In [None]:
# Echo the output folder of this experiment.
output_prefix

In [None]:
ls /content/drive/MyDrive/CMPUT652_PCGML/outputs/finetune_fusion_autoencoder_v1

In [None]:
# Flush pending writes to Google Drive and unmount it.
drive.flush_and_unmount()