# Time Series Style Transfer

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import time
from configs.Metric import Metric
from configs.SimulatedData import Proposed
from dataset.tf_pipeline import convert_dataframe_to_tensorflow_sequences
from datetime import datetime
import io
import os

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpus = tf.config.list_physical_devices('GPU')

if gpus:
    try:
        # Memory growth must be set to the same value on every GPU.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Memory growth can only be changed before the GPUs are initialized.
        print(err)

# Display the detected devices as the cell output.
gpus

In [None]:
config = Proposed()

# Input datasets: D1 feeds the content batches, D2 the style batches.
D1_PATH = "data/simulated_dataset/01 - Source Domain.h5"
D2_PATH = "data/simulated_dataset/amplitude_shift/4.5_4.5.h5"

# Hyper-parameters read from the experiment configuration.
SEQUENCE_LENGTH = config.sequence_lenght_in_sample
GRANUARITY = config.granularity
OVERLAP = config.overlap
BS = config.batch_size
EPOCHS = config.epochs
NUM_SEQUENCE_TO_GENERATE = config.met_params.sequence_to_generate

# Model-specific constants.
STYLE_VECTOR_SIZE = 32  # output size of the style encoder's Dense layer
FEAT_WIENER = 2  # number of channels of the content encoding
# The content encoder halves the time axis twice, hence the // 4.
# NOTE(review): with padding='same' the encoder yields ceil(ceil(L/2)/2)
# steps, so this only matches when SEQUENCE_LENGTH is a multiple of 4.
N_SAMPLE_WIENER = SEQUENCE_LENGTH // 4
NOISE_DIM = (N_SAMPLE_WIENER, FEAT_WIENER)
# NOTE(review): used with Dataset.skip/take, so it counts dataset elements;
# if the datasets are batched this is a number of *batches* — confirm.
N_VALIDATION_SEQUENCE = 500
TRIPLET_R = 10  # margin of the fixed-point triplet style loss

# Load the Datasets

In [None]:
df_d1 = pd.read_hdf(D1_PATH).astype(np.float32)
df_d2 = pd.read_hdf(D2_PATH).astype(np.float32)

# Both domains are windowed identically; the overlap ratio becomes a sample count.
dset_d1, dset_d2 = (
    convert_dataframe_to_tensorflow_sequences(
        frame,
        SEQUENCE_LENGTH,
        GRANUARITY,
        int(OVERLAP * SEQUENCE_LENGTH),
        BS,
    )
    for frame in (df_d1, df_d2)
)

# The first N_VALIDATION_SEQUENCE elements are held out for validation,
# the remainder is used for training.
# NOTE(review): if the pipeline reshuffles on every iteration, skip/take do
# not yield disjoint splits — confirm in convert_dataframe_to_tensorflow_sequences.
dset_d1_train = dset_d1.skip(N_VALIDATION_SEQUENCE)
dset_d1_valid = dset_d1.take(N_VALIDATION_SEQUENCE)

dset_d2_train = dset_d2.skip(N_VALIDATION_SEQUENCE)
dset_d2_valid = dset_d2.take(N_VALIDATION_SEQUENCE)

In [None]:
# Sanity check: statistics along axis 1 of one batch, the same reduction
# (axes=[1], keepdims=True) that AdaIN uses.
x = next(iter(dset_d1))

mean, variance = tf.nn.moments(x, axes=[1], keepdims=True)
standard_dev = tf.math.sqrt(variance)

# Shape of the kept-dims statistics, shown as the cell output.
standard_dev.shape

## Build the Model Components (AdaIN, Encoders, Decoder, Discriminator)

In [None]:
# Define AdaIN Layers for Time Series
class AdaIN(tf.keras.layers.Layer):
    """Adaptive Instance Normalization for sequences.

    ``content_input`` is normalized with its own mean/std computed along
    axis 1 (per sample and channel, dims kept), then re-scaled and shifted
    with the mean/std of ``style_input`` computed the same way.

    Args:
        eps: added to the variance before the square root to avoid a
            division by zero.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (``name``,
            ``dtype``, ...), so the layer can be named and rebuilt from its
            config like a built-in layer.
    """

    def __init__(self, eps: float = 1e-5, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def get_mean_std(self, x, eps=None):
        """Return ``(mean, std)`` of ``x`` along axis 1 with dims kept.

        ``eps`` defaults to the value given at construction (1e-5 unless
        overridden).
        """
        if eps is None:
            eps = self.eps
        _mean, _variance = tf.nn.moments(x, axes=[1], keepdims=True)
        standard_dev = tf.sqrt(_variance + eps)
        return _mean, standard_dev

    def call(self, content_input, style_input):
        content_mean, content_std = self.get_mean_std(content_input)
        style_mean, style_std = self.get_mean_std(style_input)
        adain_res = style_std * (content_input - content_mean) / content_std + style_mean
        return adain_res

    def get_config(self):
        # Required for model saving / cloning with a custom constructor arg.
        config = super().get_config()
        config.update({"eps": self.eps})
        return config

In [None]:
def make_content_encoder(seq_length:int, n_feat:int, feat_wiener:int):
    """Build the content encoder.

    Two ``Conv1D(kernel=5, stride=2, padding='same')`` + ``LeakyReLU``
    stages (128 then ``feat_wiener`` filters) map a ``(seq_length, n_feat)``
    sequence to a ``feat_wiener``-channel encoding whose time axis is
    ``ceil(ceil(seq_length / 2) / 2)`` long.
    """
    inputs = tf.keras.Input(shape=(seq_length, n_feat))

    hidden = inputs
    for filters in (128, feat_wiener):
        hidden = tf.keras.layers.Conv1D(filters, 5, 2, padding='same')(hidden)
        hidden = tf.keras.layers.LeakyReLU()(hidden)

    return tf.keras.Model(inputs, hidden)

def make_style_encoder(seq_length:int, n_feat:int, vector_output_shape:int):
    """Build the style encoder.

    Two ``Conv1D(128, kernel=5, stride=2, padding='same')`` + ``LeakyReLU``
    stages, then ``Flatten`` and a linear ``Dense`` projection to a style
    vector of length ``vector_output_shape``.
    """
    inputs = tf.keras.Input(shape=(seq_length, n_feat))

    hidden = inputs
    for _ in range(2):
        hidden = tf.keras.layers.Conv1D(128, 5, 2, padding='same')(hidden)
        hidden = tf.keras.layers.LeakyReLU()(hidden)

    flat = tf.keras.layers.Flatten()(hidden)
    style_vector = tf.keras.layers.Dense(vector_output_shape)(flat)

    return tf.keras.Model(inputs, style_vector)


def make_decoder(n_sample_wiener:int, feat_wiener:int, style_vector_size:int, out_feat:int):
    """Build the decoder mapping (content code, style vector) to a sequence.

    Two stages, each: ``Conv1DTranspose(stride=2)`` -> ``AdaIN`` with the
    style tensor -> ``Conv1DTranspose(stride=1)`` -> ``LeakyReLU``. The first
    stage has 128 filters, the second ``out_feat``. The output therefore has
    ``4 * n_sample_wiener`` steps and ``out_feat`` channels.

    NOTE(review): the style input is declared as ``(style_vector_size, 1)``
    while the style encoder outputs ``(batch, style_vector_size)``; this
    relies on Keras adapting the rank at call time — confirm.
    """
    content_in = tf.keras.Input((n_sample_wiener, feat_wiener))
    style_in = tf.keras.Input((style_vector_size, 1))

    # NOTE(review): AdaIN only uses the mean/std along axis 1, which
    # repeating each step (UpSampling1D) leaves unchanged.
    style_upsampled = tf.keras.layers.UpSampling1D()(style_in)

    hidden = content_in
    for filters, style_ref in ((128, style_in), (out_feat, style_upsampled)):
        hidden = tf.keras.layers.Conv1DTranspose(filters, 5, 2, padding='same')(hidden)
        hidden = AdaIN()(hidden, style_ref)
        hidden = tf.keras.layers.Conv1DTranspose(filters, 5, 1, padding='same')(hidden)
        hidden = tf.keras.layers.LeakyReLU()(hidden)

    return tf.keras.Model([content_in, style_in], hidden)


def make_discriminator(seq_length:int, n_feat:int):
    """Build a small convolutional critic producing one logit per sequence.

    Two ``Conv1D(16, kernel=5, stride=2, padding='same')`` + ``LeakyReLU`` +
    ``Dropout(0.3)`` stages, then ``Flatten`` and a linear ``Dense(1)``. No
    sigmoid is applied; the output is a raw logit.
    """
    inputs = tf.keras.Input(shape=(seq_length, n_feat))

    hidden = inputs
    for _ in range(2):
        hidden = tf.keras.layers.Conv1D(16, 5, 2, padding='same')(hidden)
        hidden = tf.keras.layers.LeakyReLU()(hidden)
        hidden = tf.keras.layers.Dropout(0.3)(hidden)

    flat = tf.keras.layers.Flatten()(hidden)
    logit = tf.keras.layers.Dense(1)(flat)

    return tf.keras.Model(inputs, logit)

In [None]:
# Preview instance (rebuilt before training).
content_encoder = make_content_encoder(
    seq_length=SEQUENCE_LENGTH, n_feat=df_d1.shape[1], feat_wiener=FEAT_WIENER
)
# content_encoder.summary()

In [None]:
# Preview instance (rebuilt before training).
style_encoder = make_style_encoder(
    seq_length=SEQUENCE_LENGTH,
    n_feat=df_d1.shape[1],
    vector_output_shape=STYLE_VECTOR_SIZE,
)
# style_encoder.summary()

In [None]:
# Preview instance (rebuilt before training).
decoder = make_decoder(
    n_sample_wiener=N_SAMPLE_WIENER,
    feat_wiener=FEAT_WIENER,
    style_vector_size=STYLE_VECTOR_SIZE,
    out_feat=df_d1.shape[1],
)
# decoder.summary()

In [None]:
# Discriminator used by train_step / valid_step (not re-created later).
global_discriminator = make_discriminator(seq_length=SEQUENCE_LENGTH, n_feat=df_d1.shape[1])

In [None]:
class StyleTransferModel(tf.keras.Model):
    """Bundle a content encoder, a style encoder and a decoder in one model.

    ``model(content_sequence, style_sequence)`` encodes the content of the
    first argument and the style of the second, then decodes a new sequence
    from the two encodings.
    """

    def __init__(self, seq_length:int, n_feat:int, style_vector_shape:int, n_sample_wiener:int, feat_wiener:int, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.content_encoder = make_content_encoder(seq_length, n_feat, feat_wiener)
        self.style_encoder = make_style_encoder(seq_length, n_feat, style_vector_shape)
        self.decoder = make_decoder(n_sample_wiener, feat_wiener, style_vector_shape, n_feat)

    def call(self, content_sequence, style_sequence, training=None):
        # Override `call`, not `__call__`: Keras' `__call__` routes here and
        # keeps building, the `training` flag and `Model.fit/save` working.
        # (The previous debug print of the encoding shapes was removed.)
        encoded_content = self.content_encoder(content_sequence, training=training)
        encoded_style = self.style_encoder(style_sequence, training=training)
        return self.decoder([encoded_content, encoded_style], training=training)

In [None]:
def generate(content_batch, style_batch):
    """Stylize ``content_batch`` with ``style_batch`` in inference mode.

    Uses the module-level ``content_encoder``, ``style_encoder`` and
    ``decoder`` as bound at call time. All three run with ``training=False``.
    """
    return decoder(
        [
            content_encoder(content_batch, training=False),
            style_encoder(style_batch, training=False),
        ],
        training=False,
    )

In [None]:
# First validation batch of each domain, reused to visualize generations.
seed_content_batch, seed_style_batch = (
    next(iter(dataset)) for dataset in (dset_d1_valid, dset_d2_valid)
)

generated_sequence = generate(seed_content_batch, seed_style_batch)

In [None]:
def plot_style_transfered_time_series(content_ts, style_ts, generated_ts, show=True, save_to:str=None, n_series:int=3):
    """Plot content, style and generated sequences side by side.

    Row ``i`` shows ``content_ts[i]``, ``style_ts[i]`` and ``generated_ts[i]``.
    Every panel shares the same y-range, which spans all values of the three
    batches plus a margin of 1, so amplitudes are directly comparable.

    Args:
        content_ts, style_ts, generated_ts: batches of sequences (array-like,
            indexable along the first axis).
        show: call ``plt.show()`` after drawing.
        save_to: optional path the figure is saved to.
        n_series: maximum number of rows (examples) to draw. It is capped at
            the smallest batch size, so small batches no longer raise
            ``IndexError``.

    Returns:
        The matplotlib ``Figure``.
    """
    content_np = np.asarray(content_ts)
    style_np = np.asarray(style_ts)
    generated_np = np.asarray(generated_ts)

    _min = min(content_np.min(), style_np.min(), generated_np.min()) - 1
    _max = max(content_np.max(), style_np.max(), generated_np.max()) + 1

    # Rows = examples, columns = (content, style, generated). The original
    # used the column count for both, which only worked for exactly 3 rows.
    n_rows = min(n_series, len(content_np), len(style_np), len(generated_np))
    n_cols = 3
    panels = (
        ("Content Time Series >", content_np),
        ("Style Time Series. >", style_np),
        ("Generated Time Series.", generated_np),
    )

    fig = plt.figure(figsize=(18, 10))
    plt.suptitle("Visualization of the generations", fontsize=18)
    for i in range(n_rows):
        for col, (label, batch) in enumerate(panels, start=1):
            ax = plt.subplot(n_rows, n_cols, n_cols * i + col)
            ax.set_title(f"*[{i}]* {label}")
            ax.plot(batch[i])
            ax.set_ylim(_min, _max)
            ax.grid(True)

    plt.tight_layout()

    if save_to is not None:
        fig.savefig(save_to)

    if show:
        plt.show()

    return fig

# Visualize the untrained model's generations for the seed batches.
generation_figure = plot_style_transfered_time_series(
    content_ts=seed_content_batch,
    style_ts=seed_style_batch,
    generated_ts=generated_sequence,
)

In [None]:
def fig_to_buff(fig):
    """Render ``fig`` to an in-memory PNG and close the figure.

    Returns an ``io.BytesIO`` rewound to position 0, ready to be read.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png')
    plt.close(fig)
    png_buffer.seek(0)
    return png_buffer

## Define losses.

In [None]:
# Binary cross-entropy applied to raw logits (no sigmoid beforehand).
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def recontruction_loss(true:tf.Tensor, generated:tf.Tensor):
    """Mean squared error between ``generated`` and ``true``."""
    return tf.math.reduce_mean(tf.math.squared_difference(generated, true))


def discriminator_loss(real_output, fake_output):
    """GAN critic loss on logits: real targets are 1, fake targets are 0."""
    real_term = cross_entropy(tf.ones_like(real_output), real_output)
    fake_term = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_term + fake_term


def generator_loss(fake_output):
    """Non-saturating generator loss: push the fake logits towards 1.

    ``fake_output`` is expected to be the discriminator's logits for the
    generated samples (same role as ``fake_output`` in ``discriminator_loss``).
    """
    return cross_entropy(tf.ones_like(fake_output), fake_output)


def fixed_point_content(encoded_content_real, encoded_content_fake):
    """Mean squared error between two content encodings."""
    return tf.reduce_mean(
        tf.math.squared_difference(encoded_content_fake, encoded_content_real)
    )



## TensorBoard Logs

In [None]:
# Running means of each loss. The train() loop resets them at the start of
# every epoch and logs their results to TensorBoard.

# Train metrics
met_generator_train = tf.keras.metrics.Mean(name="Train generator Loss")
met_style_encoder_train = tf.keras.metrics.Mean(name="Train Style Encoder Loss")
met_content_encoder_train = tf.keras.metrics.Mean(name="Train Content Encoder Loss")
met_disc_loss_train = tf.keras.metrics.Mean(name="Train Discriminatir Loss")

# Validation metrics
met_generator_valid = tf.keras.metrics.Mean(name="valid generator Loss")
met_style_encoder_valid = tf.keras.metrics.Mean(name="valid Style Encoder Loss")
met_content_encoder_valid = tf.keras.metrics.Mean(name="valid Content Encoder Loss")
met_disc_loss_valid = tf.keras.metrics.Mean(name="valid Discriminatir Loss")

In [None]:
# One timestamped run directory holding the train/valid TensorBoard logs
# and the per-epoch generation plots.
date_str = datetime.now().strftime('%Y-%m-%d_%H_%M_%S')

BASE_DIR = f"logs/{date_str} - Style Transfer Algorithm"
TRAIN_LOGS_DIR_PATH, VALID_LOGS_DIR_PATH, GENERATION_LOG = (
    f"{BASE_DIR}/{leaf}" for leaf in ("train", "valid", "Generations")
)
os.makedirs(GENERATION_LOG)

TRAIN_SUMMARY_WRITER = tf.summary.create_file_writer(TRAIN_LOGS_DIR_PATH)
VALID_SUMMARY_WRITER = tf.summary.create_file_writer(VALID_LOGS_DIR_PATH)

def log_train_losses(epoch, plot_buf):
    """Write the epoch's training losses and the generation plot to TensorBoard.

    ``plot_buf`` is a PNG byte buffer. It is decoded as a 4-channel image and
    logged with a leading batch dimension of 1.
    """
    image = tf.expand_dims(tf.image.decode_png(plot_buf.getvalue(), channels=4), 0)

    scalars = (
        ("Generator Loss", met_generator_train),
        ("Style Loss", met_style_encoder_train),
        ("Content Loss", met_content_encoder_train),
        ("Discriminator Loss", met_disc_loss_train),
    )
    with TRAIN_SUMMARY_WRITER.as_default():
        for tag, metric in scalars:
            tf.summary.scalar(tag, metric.result(), step=epoch)

        tf.summary.image("Training Generations", image, step=epoch)

def log_valid_losses(epoch):
    """Write the epoch's validation losses to TensorBoard.

    The tags match those of the training writer, so the curves overlay.
    """
    scalars = (
        ("Generator Loss", met_generator_valid),
        ("Style Loss", met_style_encoder_valid),
        ("Content Loss", met_content_encoder_valid),
        ("Discriminator Loss", met_disc_loss_valid),
    )
    with VALID_SUMMARY_WRITER.as_default():
        for tag, metric in scalars:
            tf.summary.scalar(tag, metric.result(), step=epoch)


def reset_metric_states():
    """Reset the running means of the four training metrics."""
    train_metrics = (
        met_generator_train,
        met_style_encoder_train,
        met_content_encoder_train,
        met_disc_loss_train,
    )
    for metric in train_metrics:
        metric.reset_states()

def reset_valid_states():
    """Reset the running means of the four validation metrics."""
    valid_metrics = (
        met_generator_valid,
        met_style_encoder_valid,
        met_content_encoder_valid,
        met_disc_loss_valid,
    )
    for metric in valid_metrics:
        metric.reset_states()
        

## Instantiate model for training.

In [None]:
# Fresh encoder/decoder instances for training. They rebind the module-level
# names used by generate / train_step / valid_step.
content_encoder = make_content_encoder(
    seq_length=SEQUENCE_LENGTH, n_feat=df_d1.shape[1], feat_wiener=FEAT_WIENER
)
style_encoder = make_style_encoder(
    seq_length=SEQUENCE_LENGTH,
    n_feat=df_d1.shape[1],
    vector_output_shape=STYLE_VECTOR_SIZE,
)
decoder = make_decoder(
    n_sample_wiener=N_SAMPLE_WIENER,
    feat_wiener=FEAT_WIENER,
    style_vector_size=STYLE_VECTOR_SIZE,
    out_feat=df_d1.shape[1],
)

# One RMSprop optimizer per sub-network.
# NOTE(review): the decoder's learning rate is 1000x that of the encoders
# and the discriminator — confirm this is intentional.
opt_content_encoder = tf.keras.optimizers.RMSprop(learning_rate=1e-6)
opt_style_encoder = tf.keras.optimizers.RMSprop(learning_rate=1e-6)
opt_decoder = tf.keras.optimizers.RMSprop(learning_rate=1e-3)
opt_discr = tf.keras.optimizers.RMSprop(learning_rate=1e-6)

## Train the model

In [None]:
from itertools import product

# Every (content, style) pair of a batch: row k is (k // BS, k % BS), so
# the flat position of pair (c_i, s_j) is i * BS + j.
indexes = np.array(list(product(range(BS), range(BS))))
# Flat positions of the diagonal pairs (c0, s0), (c1, s1), ..., (c_{BS-1}, s_{BS-1}):
# i * BS + i == i * (BS + 1). The previous `np.arange(BS) * BS` picked
# (c_i, s_0), i.e. every content stylized with the first style sample only.
other_index = np.arange(BS) * (BS + 1)

In [None]:
# test anchor, positive and negative sample calculation...
# anchor_indexes = np.array([ i* BS+i for i in range(BS) for _ in range(BS-1) ])
# pos_indexes = np.array([ BS*j + i for i in range(BS) for j in range(BS) if i !=j ])
# neg_indexes = np.array([ j+BS*i  for i in range(BS) for j in range(BS) if i !=j] )

# print("Anchor Samples:\n", indexes[anchor_indexes, :])
# print("Positive Samples:\n", indexes[pos_indexes, :])
# print("Negative Samples:\n", indexes[neg_indexes, :])

In [None]:
def get_style_voctor_for_dis_loss(style_vector:tf.Tensor, batch_size:int):
    """Gather the diagonal rows ``i * batch_size + i`` of ``style_vector``.

    Each diagonal row is repeated ``batch_size - 1`` times in a row, which is
    the same ordering as the triplet anchors.
    """
    diagonal = [i * (batch_size + 1) for i in range(batch_size)]
    anchor_indexes = np.array([row for row in diagonal for _ in range(batch_size - 1)])
    return tf.gather(style_vector, anchor_indexes)

def get_anchor_positive_negative_from_batch(style_from_style_ts:tf.Tensor, style_of_generations:tf.Tensor, batch_size:int):
    """Build the (anchor, positive, negative) style triplets of one batch.

    Rows of both inputs are laid out as ``content * batch_size + style``.
    For every ``i`` and every ``j != i`` (outer loop over ``i``):

    * anchor   -- ``style_from_style_ts`` at ``(c_i, s_i)``;
    * positive -- ``style_of_generations`` at ``(c_j, s_i)``: different
      content, same style;
    * negative -- ``style_of_generations`` at ``(c_i, s_j)``: same content,
      different style.

    Returns ``(anchors, pos_vector, neg_vector)``, each with
    ``batch_size * (batch_size - 1)`` rows.
    """
    off_diagonal = [(i, j) for i in range(batch_size) for j in range(batch_size) if i != j]

    anchor_indexes = np.array([i * batch_size + i for i, _ in off_diagonal])
    pos_indexes = np.array([j * batch_size + i for i, j in off_diagonal])
    neg_indexes = np.array([i * batch_size + j for i, j in off_diagonal])

    anchors = tf.gather(style_from_style_ts, anchor_indexes)
    pos_vector = tf.gather(style_of_generations, pos_indexes)
    neg_vector = tf.gather(style_of_generations, neg_indexes)

    return anchors, pos_vector, neg_vector

def get_dissantanglement_loss_component(style_of_generations, style_of_style, batch_size:int):
    """Gather the three style codes used by the disentanglement loss.

    Rows are laid out as ``content * batch_size + style``. For every ``i``
    and every ``j != i``:

    * ``es_y``    -- ``style_of_style`` at ``(c_i, s_i)``;
    * ``es_x1_y`` -- ``style_of_generations`` at ``(c_i, s_i)``;
    * ``es_x2_y`` -- ``style_of_generations`` at ``(c_j, s_i)``.

    Returns ``(es_y, es_x1_y, es_x2_y)`` in that order.
    """
    pairs = [(i, j) for i in range(batch_size) for j in range(batch_size) if i != j]
    anchor_indexes = np.array([i * (batch_size + 1) for i, _ in pairs])
    pos_indexes = np.array([j * batch_size + i for i, j in pairs])

    es_y = tf.gather(style_of_style, anchor_indexes)
    es_x1_y = tf.gather(style_of_generations, anchor_indexes)
    es_x2_y = tf.gather(style_of_generations, pos_indexes)

    return es_y, es_x1_y, es_x2_y
    

def l2(x:tf.Tensor, y:tf.Tensor):
    """Euclidean distance between ``x`` and ``y`` along the last axis.

    NOTE(review): ``sqrt`` has an infinite derivative at 0, so rows where
    ``x == y`` exactly would produce NaN gradients.
    """
    return tf.sqrt(tf.reduce_sum(tf.square(x - y), axis=-1))


def fixed_point_triplet_style_loss(anchor_encoded_style, positive_encoded_style, negative_encoded_style):
    """Triplet margin loss on style codes.

    The loss pulls ``positive_encoded_style`` towards ``anchor_encoded_style``.
    It pushes ``negative_encoded_style`` until it is at least ``TRIPLET_R``
    farther from the anchor than the positive is:

        mean(max(0, TRIPLET_R + ||a - p|| - ||a - n||))

    The inputs are row-aligned style codes. Distances are taken along the
    last axis.
    """
    # Both distances are measured from the anchor. The earlier version used
    # ||p - n|| - ||p - a||, which rewarded moving away from the target style.
    positive_distance = l2(anchor_encoded_style, positive_encoded_style)
    negative_distance = l2(anchor_encoded_style, negative_encoded_style)

    triplet = TRIPLET_R + positive_distance - negative_distance
    triplet = tf.math.maximum(triplet, tf.zeros_like(triplet))

    return tf.reduce_mean(triplet)

def fixed_point_disentanglement(
        es_x1_y:tf.Tensor,
        es_x2_y:tf.Tensor,
        es_y:tf.Tensor
        ):
    """Hinge loss on style codes for disentanglement.

    Computes ``mean(max(0, ||es_x1_y - es_x2_y|| - ||es_x1_y - es_y||))``,
    with distances along the last axis. Judging by the names, ``es_x1_y``
    and ``es_x2_y`` are the styles of two generations that share style
    ``y``. The loss asks them to be no farther apart than ``es_x1_y`` is
    from ``es_y``.

    NOTE: the parameter order is (es_x1_y, es_x2_y, es_y). Pass the
    arguments by keyword to avoid mix-ups.
    """
    between_generations = l2(es_x1_y, es_x2_y)
    to_style = l2(es_x1_y, es_y)

    gap = between_generations - to_style
    hinge = tf.math.maximum(gap, tf.zeros_like(gap))
    return tf.reduce_mean(hinge)

In [None]:
@tf.function
def train_step(content_batch:tf.Tensor, style_batch:tf.Tensor):
    """Run one optimization step on a (content, style) batch pair.

    Every content sample is paired with every style sample through the global
    ``indexes`` table. The resulting generations are re-encoded to compute the
    fixed-point content, triplet-style and disentanglement losses. Four
    separate tapes/optimizers then update the content encoder, the style
    encoder, the decoder and ``global_discriminator`` from their own losses,
    and the running training metrics are updated.

    NOTE(review): ``indexes`` / ``other_index`` are built for exactly ``BS``
    samples. This assumes every batch is full — confirm that the dataset
    pipeline drops the last partial batch.
    """
    lambda_reconstr= 1
    lambda_realness= 1
    lambda_adv= 1
    lambda_content= 1
    lambda_triplet= 1
    lambda_dis= 1

    with tf.GradientTape() as content_tape, tf.GradientTape() as style_tape, tf.GradientTape() as decoder_tape, tf.GradientTape() as discr_tape:
        # All BS * BS (content, style) pairs: row k = (c_{k // BS}, s_{k % BS}).
        extended_content_images = tf.gather(content_batch, indexes[:, 0])
        extended_style_images = tf.gather(style_batch, indexes[:, 1])

        # Get the content from the content batch
        content_of_content = content_encoder(extended_content_images, training=True)
        # Get the style from the style batch
        style_of_style = style_encoder(extended_style_images, training=True)

        # Generate the time series given the content and the style.
        generated_ts = decoder([content_of_content, style_of_style], training=True)

        # Re-encode the generations for the fixed-point losses.
        content_of_generations = content_encoder(generated_ts, training=True)
        style_of_generations = style_encoder(generated_ts, training=True)

        # One generation per content sample, selected through `other_index`.
        reduced_stylized = tf.gather(generated_ts, other_index)

        crit_on_fake = global_discriminator(reduced_stylized, training=True)
        # NOTE(review): the "real" reference is the content-domain batch, not
        # the style-domain batch — confirm this is the intended target domain.
        crit_on_real = global_discriminator(content_batch, training=True)

        reconstr_from_content = recontruction_loss(extended_content_images, generated_ts)
        # The adversarial generator term must score the discriminator's logits
        # for the fakes. Passing the raw sequences (as before) gave the
        # generator no signal from the discriminator at all.
        realness = generator_loss(crit_on_fake)

        global_dicriminator_loss = discriminator_loss(crit_on_real, crit_on_fake)

        content_similarity = fixed_point_content(content_of_content, content_of_generations)

        anchors, positive_vectors, negative_vectors = get_anchor_positive_negative_from_batch(style_of_style, style_of_generations, BS)
        es_y, es_x1_y, es_x2_y = get_dissantanglement_loss_component(style_of_generations, style_of_style, BS)

        triplet_style = fixed_point_triplet_style_loss(anchors, positive_vectors, negative_vectors)

        # Pass by keyword: the signature is (es_x1_y, es_x2_y, es_y), which
        # differs from the order the components are unpacked in above.
        dis_loss = fixed_point_disentanglement(es_x1_y=es_x1_y, es_x2_y=es_x2_y, es_y=es_y)

        d_loss = lambda_adv* global_dicriminator_loss
        content_encoder_loss = lambda_content* content_similarity
        style_encoder_loss = lambda_triplet* triplet_style + lambda_dis* dis_loss
        g_loss = lambda_reconstr* reconstr_from_content+ lambda_realness* realness + lambda_content* content_similarity + lambda_triplet* triplet_style + lambda_dis* dis_loss

    content_grad = content_tape.gradient(content_encoder_loss, content_encoder.trainable_variables)
    style_grad = style_tape.gradient(style_encoder_loss, style_encoder.trainable_variables)
    decoder_grad = decoder_tape.gradient(g_loss, decoder.trainable_variables)
    discr_grads = discr_tape.gradient(d_loss, global_discriminator.trainable_variables)

    opt_content_encoder.apply_gradients(zip(content_grad, content_encoder.trainable_variables))
    opt_style_encoder.apply_gradients(zip(style_grad, style_encoder.trainable_variables))
    opt_decoder.apply_gradients(zip(decoder_grad, decoder.trainable_variables))
    opt_discr.apply_gradients(zip(discr_grads, global_discriminator.trainable_variables))

    met_generator_train(g_loss)
    met_style_encoder_train(style_encoder_loss)
    met_content_encoder_train(content_encoder_loss)
    met_disc_loss_train(d_loss)

@tf.function
def valid_step(content_batch:tf.Tensor, style_batch:tf.Tensor):
    """Evaluate the training losses on a batch pair without updating weights.

    All models run with ``training=False``, which disables the
    discriminator's dropout. The losses are combined with the same formulas
    as in ``train_step``, so the train/valid TensorBoard curves logged under
    the same tags are comparable.
    """
    lambda_reconstr= 1
    lambda_realness= 1
    lambda_adv= 1
    lambda_content= 1
    lambda_triplet= 1
    lambda_dis= 1

    # All BS * BS (content, style) pairs: row k = (c_{k // BS}, s_{k % BS}).
    extended_content_images = tf.gather(content_batch, indexes[:, 0])
    extended_style_images = tf.gather(style_batch, indexes[:, 1])

    # Get the content from the content batch
    content_of_content = content_encoder(extended_content_images, training=False)
    # Get the style from the style batch
    style_of_style = style_encoder(extended_style_images, training=False)

    # Generate the time series given the content and the style.
    generated_ts = decoder([content_of_content, style_of_style], training=False)

    # Re-encode the generations for the fixed-point losses.
    content_of_generations = content_encoder(generated_ts, training=False)
    style_of_generations = style_encoder(generated_ts, training=False)

    # One generation per content sample, selected through `other_index`.
    reduced_stylized = tf.gather(generated_ts, other_index)

    crit_on_fake = global_discriminator(reduced_stylized, training=False)
    crit_on_real = global_discriminator(content_batch, training=False)

    reconstr_from_content = recontruction_loss(extended_content_images, generated_ts)
    # Score the discriminator's logits for the fakes, not the raw sequences.
    realness = generator_loss(crit_on_fake)

    global_dicriminator_loss = discriminator_loss(crit_on_real, crit_on_fake)

    content_similarity = fixed_point_content(content_of_content, content_of_generations)

    anchors, positive_vectors, negative_vectors = get_anchor_positive_negative_from_batch(style_of_style, style_of_generations, BS)
    es_y, es_x1_y, es_x2_y = get_dissantanglement_loss_component(style_of_generations, style_of_style, BS)

    triplet_style = fixed_point_triplet_style_loss(anchors, positive_vectors, negative_vectors)

    # Pass by keyword: the signature is (es_x1_y, es_x2_y, es_y).
    dis_loss = fixed_point_disentanglement(es_x1_y=es_x1_y, es_x2_y=es_x2_y, es_y=es_y)

    # Same composition as train_step's g_loss (it previously omitted the
    # content/triplet/disentanglement terms while sharing the TensorBoard tag).
    g_loss = lambda_reconstr* reconstr_from_content+ lambda_realness* realness + lambda_content* content_similarity + lambda_triplet* triplet_style + lambda_dis* dis_loss
    d_loss = lambda_adv* global_dicriminator_loss
    content_encoder_loss = lambda_content* content_similarity
    style_encoder_loss = lambda_triplet* triplet_style + lambda_dis* dis_loss

    met_generator_valid(g_loss)
    met_style_encoder_valid(style_encoder_loss)
    met_content_encoder_valid(content_encoder_loss)
    met_disc_loss_valid(d_loss)



In [None]:
def train():
    """Train for ``EPOCHS`` epochs, validating and logging after each one.

    Each epoch:

    * resets the running metrics;
    * runs ``train_step`` over the zipped training datasets;
    * runs ``valid_step`` over the zipped validation datasets;
    * saves a plot of the seed-batch generations to ``GENERATION_LOG``;
    * writes the train/valid losses and the plot to TensorBoard.
    """
    # Training batches per epoch; unknown until the first epoch has run.
    total_batch = "?"
    for e in range(EPOCHS):
        reset_metric_states()
        reset_valid_states()
        filename = f'{GENERATION_LOG}/{e}.png'

        print("[+] Train Step...")
        # Tracked explicitly so an empty dataset does not leave `i` undefined.
        n_train_batches = 0
        # `zip` stops at the shorter of the two datasets.
        for i, (content_batch, style_batch) in enumerate(zip(dset_d1_train, dset_d2_train), start=1):
            train_step(content_batch, style_batch)
            n_train_batches = i
            print(f"\r e:{e}/{EPOCHS}; {i}/{total_batch}. G_loss {met_generator_train.result():0.2f} style loss {met_style_encoder_train.result():0.2f} content loss {met_content_encoder_train.result():0.2f} discr loss {met_disc_loss_train.result():0.2f}        ", end="")

        print()
        print("[+] Validation Step...")
        for vb, (content_batch, style_batch) in enumerate(zip(dset_d1_valid, dset_d2_valid), start=1):
            valid_step(content_batch, style_batch)
            print(f"\r e:{e}/{EPOCHS}; {vb}/{N_VALIDATION_SEQUENCE}. G_loss {met_generator_valid.result():0.2f} style loss {met_style_encoder_valid.result():0.2f} content loss {met_content_encoder_valid.result():0.2f} discr loss {met_disc_loss_valid.result():0.2f}        ", end="")

        # Visualize the fixed seed batches with the current weights.
        generations = generate(seed_content_batch, seed_style_batch)
        vis_fig = plot_style_transfered_time_series(seed_content_batch, seed_style_batch, generations, show=False, save_to=filename)
        plot_buff = fig_to_buff(vis_fig)
        log_train_losses(e, plot_buff)
        log_valid_losses(e)
        print()

        if e == 0:
            # A count, not the last 0-based index (previously off by one).
            total_batch = n_train_batches


train()

In [None]:
# Generations of the trained model for the seed batches; the returned
# figure is the cell output.
generated_sequence = generate(seed_content_batch, seed_style_batch)

plot_style_transfered_time_series(
    content_ts=seed_content_batch,
    style_ts=seed_style_batch,
    generated_ts=generated_sequence,
    show=False,
)