In [1]:
def compare_hist_real_vs_generated(model, other_gen_dataset = None, n_img_horiz = 4, n_bins = 20, figsize = None,
                                   discrete_xtick_rotation = 45,
                                   title = None, seed = None, save_path = None, save_dir = None):
    """Plot per-column histograms of the real data against generated data.

    One subplot is drawn per column: first every column listed in
    ``model.columns_num`` (both histograms share ``n_bins`` equally wide bins
    spanning the joint value range), then every column listed in
    ``model.columns_discrete`` (one bar position per distinct value found in
    the real data). Grid cells left over are switched off.

    Parameters
    ----------
    model : object
        Must provide ``data``, ``columns_num``, ``columns_discrete``,
        ``n_columns`` and, when ``other_gen_dataset`` is None,
        ``generate_data()``.
    other_gen_dataset : optional
        Data to plot as "Gen" instead of calling ``model.generate_data()``.
        It is indexed by column name, and discrete columns must support
        ``.map`` (presumably a pandas DataFrame -- confirm with callers).
    n_img_horiz : int
        Number of subplots per row.
    n_bins : int
        Number of bins for numerical columns.
    figsize : tuple, optional
        Figure size; defaults to ``(15, 3 * number_of_rows)``.
    discrete_xtick_rotation : float
        Rotation of the x tick labels of discrete columns.
    title : str, optional
        Passed to ``fig.suptitle``.
    seed : int, optional
        If given, passed to ``tf.random.set_seed`` before generating data.
    save_path : str, optional
        File to save the figure to. Nothing is saved when None.
    save_dir : str, optional
        Directory prepended to ``save_path``; created if missing. Ignored
        when ``save_path`` is None.

    Returns
    -------
    The matplotlib figure, after it has been closed with ``plt.close``.
    """
    if seed is not None:
        tf.random.set_seed(seed)
    gen_data = model.generate_data() if other_gen_dataset is None else other_gen_dataset

    n_img_vert = (model.n_columns - 1) // n_img_horiz + 1
    if figsize is None:
        figsize = (15, 3 * n_img_vert)
    # squeeze=False keeps the axes array 2-D even for a single row or column,
    # so the panel lookup below works for every grid shape.
    fig, ax = plt.subplots(n_img_vert, n_img_horiz, figsize = figsize, squeeze = False)
    fig.suptitle(title)
    panels = ax.ravel()  # row-major: left to right, then top to bottom
    img_counter = 0

    for num_col in model.columns_num:
        panel = panels[img_counter]
        real_values = model.data[num_col]
        gen_values = gen_data[num_col]
        # Common bin edges so the two histograms are directly comparable.
        min_val = min(np.min(gen_values), np.min(real_values))
        max_val = max(np.max(gen_values), np.max(real_values))
        # The builtin float replaces np.float, which newer NumPy no longer has.
        bins = np.linspace(min_val, max_val, n_bins + 1, dtype = float)
        panel.hist(real_values, density = True, alpha = 0.5, bins = bins, label = "Real")
        panel.hist(gen_values, density = True, alpha = 0.5, bins = bins, label = "Gen")
        panel.set_title(num_col)
        panel.legend()
        img_counter += 1

    for discrete_col in model.columns_discrete:
        panel = panels[img_counter]
        unique_values = np.unique(model.data[discrete_col])
        # Categories are encoded as 0..k-1 so they can be histogrammed.
        map_discr_to_int = {s : i for i, s in enumerate(unique_values)}
        # Half-unit bins shifted by a quarter: every integer code falls in
        # the middle of its own bin, with an empty bin between neighbours.
        bins = np.arange(0, unique_values.size, 0.5) - 0.25
        x_ticks = np.arange(0, unique_values.size)
        panel.hist(model.data[discrete_col].map(map_discr_to_int),
                   bins = bins, density = True, alpha = 0.5, label = "Real")
        panel.hist(gen_data[discrete_col].map(map_discr_to_int),
                   bins = bins, density = True, alpha = 0.5, label = "Gen")
        panel.set_title(discrete_col)
        panel.set_xticks(x_ticks)
        panel.set_xticklabels(unique_values)
        panel.tick_params(axis='x', labelrotation = discrete_xtick_rotation)
        panel.legend()
        img_counter += 1

    # Hide the grid cells that did not receive a column.
    for panel in panels[img_counter:]:
        panel.axis("off")
    fig.tight_layout()

    if save_path is not None:
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok = True)
            save_path = os.path.join(save_dir, save_path)
        # Save through the figure itself rather than pyplot's current figure.
        fig.savefig(save_path)

    plt.close(fig)

    return fig

In [3]:
def map_str_to_color(vec, color_map = "viridis"):
    """Return one colour per element of ``vec``, equal values sharing a colour.

    The distinct values of ``vec`` are taken in order of first appearance and
    assigned colours sampled from the colormap at ``np.linspace(0, 1, k)``,
    where ``k`` is the number of distinct values. Ordering by first appearance
    (instead of iterating a ``set``) makes the assignment reproducible from
    run to run.

    Parameters
    ----------
    vec : iterable of hashable
        Values to colour (typically strings).
    color_map : str or callable
        Name of a registered matplotlib colormap, or a colormap object / any
        callable that maps an array of positions in [0, 1] to colours.

    Returns
    -------
    list
        ``len(vec)`` colour entries, in the order of ``vec``.

    Raises
    ------
    ValueError
        If ``color_map`` is a name that matplotlib does not know.
    """
    if callable(color_map):
        cmap = color_map
    else:
        # matplotlib.cm.get_cmap is deprecated; prefer the colormap registry
        # when this matplotlib version has one.
        registry = getattr(matplotlib, "colormaps", None)
        if registry is not None and isinstance(color_map, str):
            try:
                cmap = registry[color_map]
            except KeyError as err:
                # Keep the exception type the legacy lookup raised.
                raise ValueError(f"{color_map!r} is not a known colormap name") from err
        else:
            cmap = matplotlib.cm.get_cmap(color_map)

    unique_values = list(dict.fromkeys(vec))
    colors = cmap(np.linspace(0, 1, len(unique_values)))
    clr_map = dict(zip(unique_values, colors))
    return [clr_map[s] for s in vec]

In [None]:
def compare_evolution_hist_real_vs_generated(tg, epochs = None, name = "compare_hist_real_vs_generated", fps = 1,
                                             mult_time_first_image = 1, mult_time_last_image = 1, **kwargs):
    """Build a GIF showing real-vs-generated histograms across checkpoints.

    For every epoch the matching checkpoint is restored on ``tg``, a histogram
    figure is rendered with ``compare_hist_real_vs_generated`` into a
    temporary directory, and the frames are then assembled into
    ``name + ".gif"``. The temporary directory is removed afterwards, also
    when an error occurs.

    Parameters
    ----------
    tg : object
        Model providing ``restore_checkpoint(epoch=...)`` and, when ``epochs``
        is None, ``ckpt_manager.checkpoints`` and ``ckpt_prefix``.
    epochs : list of int, optional
        Epochs to render. By default they are parsed from the checkpoint
        names by stripping ``tg.ckpt_prefix + "-"``.
    name : str
        Output file name without the ``.gif`` extension.
    fps : number
        Frames per second passed to the GIF writer.
    mult_time_first_image, mult_time_last_image : int
        How many times the first / last frame is written in total, i.e. a
        value of 3 shows that frame three times as long.
    **kwargs
        Forwarded to ``compare_hist_real_vs_generated`` (e.g. ``n_bins``,
        ``figsize``). ``title``, ``save_path`` and ``save_dir`` are controlled
        by this function and are dropped if supplied.

    Returns
    -------
    str
        The GIF file name, including the extension.
    """
    name += ".gif"
    if epochs is None:
        ckpts = tg.ckpt_manager.checkpoints
        epochs = [int(ckpt.replace(tg.ckpt_prefix + "-", "")) for ckpt in ckpts]

    # These three are set per frame below; letting callers override them would
    # either raise a duplicate-keyword error or move frames out of the temp dir.
    frame_kwargs = {key: value for key, value in kwargs.items()
                    if key not in ("title", "save_path", "save_dir")}

    dir_path = ".//temp_gif_compare_hist_real_vs_generated//"
    os.makedirs(dir_path, exist_ok = True)
    try:
        filenames = []
        for epoch in tqdm(epochs):
            tg.restore_checkpoint(epoch = epoch)

            filename = f'{dir_path}{epoch}.jpg'
            filenames.append(filename)

            # The function saves the frame to `filename` and closes the figure.
            compare_hist_real_vs_generated(tg, title = "Epoch %d" % (epoch), save_path = filename,
                                           **frame_kwargs)

        # Index of the final frame (enumerate below is zero-based).
        last_image_i = len(filenames) - 1
        with imageio.get_writer(name, mode='I', fps = fps) as writer:
            for i, filename in enumerate(filenames):
                image = imageio.imread(filename)
                writer.append_data(image)
                # Repeat the first / last frame to hold it on screen longer.
                if i == 0:
                    for _ in range(1, mult_time_first_image):
                        writer.append_data(image)
                if i == last_image_i:
                    for _ in range(1, mult_time_last_image):
                        writer.append_data(image)
    finally:
        # Remove the temporary frames whether or not the GIF was built.
        shutil.rmtree(dir_path, ignore_errors = True)

    return name

In [1]:
def play_gif(filename, fps = None):
    """Return an ``Image`` display object for a GIF, optionally re-timed.

    Parameters
    ----------
    filename : str
        Path of the GIF to show.
    fps : number, optional
        When given, the GIF is re-encoded at this frame rate into a temporary
        file, which is what gets displayed; the original file is not modified.

    Returns
    -------
    The ``Image`` object built from the (possibly re-timed) GIF.
    """
    if fps is None:
        return Image(filename)

    gif = imageio.mimread(filename)
    gif_change_speed = "temp_gif_change_speed.gif"
    try:
        imageio.mimsave(gif_change_speed, gif, fps=fps)
        # NOTE(review): the temp file is deleted before the object is used, so
        # this relies on Image reading the file eagerly -- confirm.
        return Image(gif_change_speed)
    finally:
        # Clean up even if writing or loading the temporary GIF failed.
        if os.path.exists(gif_change_speed):
            os.remove(gif_change_speed)

In [1]:
def gif_to_mp4(filename_gif, fps = None, new_filename = None):
    """Convert a GIF file to an MP4 video, optionally changing its speed.

    Parameters
    ----------
    filename_gif : str
        Path of the source GIF.
    fps : number, optional
        When given, the GIF is first re-encoded at this frame rate into a
        temporary file and the video is produced from that file.
    new_filename : str, optional
        Output path. Defaults to ``filename_gif`` with its extension replaced
        by ``.mp4``.

    Returns
    -------
    str
        Path of the written MP4 file.
    """
    if new_filename is None:
        new_filename = f"{os.path.splitext(filename_gif)[0]}.mp4"

    retimed = fps is not None
    if retimed:
        gif = imageio.mimread(filename_gif)
        gif_changed_speed = "temp_gif_change_speed.gif"
        imageio.mimsave(gif_changed_speed, gif, fps=fps)
    else:
        gif_changed_speed = filename_gif

    try:
        clip = moviepy.editor.VideoFileClip(gif_changed_speed)
        try:
            clip.write_videofile(new_filename)
        finally:
            # Release the reader so the temporary file can be deleted.
            clip.close()
    finally:
        # Only the temporary copy is removed, never the caller's GIF.
        if retimed and os.path.exists(gif_changed_speed):
            os.remove(gif_changed_speed)
    return new_filename

In [2]:
def compute_nmi_matrix(tgan = None, dataset = None, bins = None, n_q_bins = 40
                       , generated_data = True, retbins = False, average_method = "arithmetic"):
    """Compute the pairwise normalized mutual information (NMI) of all columns.

    Numerical columns (``tgan.columns_num``) are discretized first: with
    ``pd.qcut`` into ``n_q_bins`` quantile bins, or with ``pd.cut`` using the
    edges in ``bins`` when those are supplied. The mutual information of each
    column pair is then divided by an average of the two column entropies.

    Parameters
    ----------
    tgan : object
        Required even when ``dataset`` is given: ``columns``, ``columns_num``
        and ``n_columns`` are always read from it; ``data`` /
        ``generate_data()`` are used when ``dataset`` is None.
    dataset : pandas.DataFrame, optional
        Data to analyse. When None, ``tgan.generate_data()`` is used if
        ``generated_data`` is true, otherwise a copy of ``tgan.data``.
    bins : dict, optional
        Maps numerical column name to bin edges for ``pd.cut``.
    n_q_bins : int
        Number of quantile bins when ``bins`` is None.
    generated_data : bool
        Selects generated vs. real data when ``dataset`` is None.
    retbins : bool
        If true, also return the bin edges used per numerical column.
    average_method : {"arithmetic", "max", "min", "geometric"}
        How the two entropies are combined into the normalizer.

    Returns
    -------
    numpy.ndarray, or (numpy.ndarray, dict) when ``retbins`` is true
        Symmetric ``n_columns x n_columns`` matrix with ones on the diagonal.

    Raises
    ------
    ValueError
        If ``average_method`` is not one of the supported names.
    """
    average_funcs = {
        "arithmetic": lambda x, y: np.mean([x, y]),
        "max": lambda x, y: np.max([x, y]),
        "min": lambda x, y: np.min([x, y]),
        "geometric": lambda x, y: np.sqrt(x * y),
    }
    if average_method not in average_funcs:
        raise ValueError("Average_method given as input is not implemented")
    average_func = average_funcs[average_method]

    if dataset is not None:
        data_binned = dataset.copy()
    elif generated_data:
        data_binned = tgan.generate_data()
    else:
        data_binned = tgan.data.copy()

    # Replace every numerical column by its bin label.
    bins_curr = {}
    for col_num in tgan.columns_num:
        if bins is None:
            cut_series, cut_bins = pd.qcut(data_binned[col_num], q = n_q_bins, retbins = True,
                                           duplicates = "drop")
        else:
            cut_series, cut_bins = pd.cut(data_binned[col_num], bins = bins[col_num], retbins = True,
                                          include_lowest = True)
        data_binned[col_num] = cut_series
        bins_curr[col_num] = cut_bins

    columns = list(tgan.columns)
    n_columns = tgan.n_columns

    # Marginal distribution and entropy of every column.
    probs_dict = {}
    entropy = np.zeros(n_columns)
    for i, col in enumerate(columns):
        col_category_fractions = data_binned[col].value_counts(normalize = True)
        # Empty categories contribute nothing; dropping them avoids 0 * log(0).
        col_category_fractions = col_category_fractions[col_category_fractions > 0]
        probs_dict[col] = col_category_fractions.to_dict()
        entropy[i] = np.sum(- col_category_fractions * np.log(col_category_fractions))

    # NOTE(review): the joint probabilities below are normalized by the total
    # row count, while value_counts above normalizes over non-missing values
    # only; the two differ if binning produced NaNs -- confirm this is intended.
    n_rows = data_binned.shape[0]
    nmi_matrix = np.zeros([n_columns, n_columns])
    for i, col1 in enumerate(columns):
        nmi_matrix[i, i] = 1
        for j in range(i + 1, len(columns)):
            col2 = columns[j]
            # observed=True keeps only the value combinations that occur.
            joint = (data_binned.groupby([col1, col2], observed = True)
                     .size().reset_index(name = "Prob.both"))
            prob_both = joint["Prob.both"] / n_rows
            prob_col1 = joint[col1].map(probs_dict[col1]).astype(float)
            prob_col2 = joint[col2].map(probs_dict[col2]).astype(float)
            mutual_info = np.sum(prob_both * np.log(prob_both / (prob_col1 * prob_col2)))
            # The matrix is symmetric, so fill both triangles at once.
            nmi_matrix[i, j] = nmi_matrix[j, i] = mutual_info / average_func(entropy[i], entropy[j])

    if retbins:
        return nmi_matrix, bins_curr
    return nmi_matrix

In [1]:
def compare_nmi_matrices(tgans, extra_datasets = None, include_true_data = True, n_q_bins = 40, ncol = None, nrow = None,
                        average_method = "arithmetic", subplot_title_true_dataset = "True dataset",
                         subplot_titles_tgans = None, subplot_titles_extra_datasets = None, figsize = (14, 5),
                         compute_diff_nmi_matrices = False, save_dir = None, save_name = None, title = None):
    """Plot NMI matrices of several models (and datasets) side by side.

    One panel is drawn per entry of ``tgans`` (using data generated by that
    model) and per entry of ``extra_datasets``; optionally a first panel shows
    the real data of ``tgans[0]``. Every matrix comes from
    ``compute_nmi_matrix`` and is binned with quantiles of its own data. All
    panels share one colour scale and one colorbar.

    Parameters
    ----------
    tgans : list
        Models to compare; must not be empty. ``tgans[0]`` supplies the real
        data and the column metadata used for ``extra_datasets``.
    extra_datasets : list, optional
        Additional datasets to show after the models.
    include_true_data : bool
        Whether the first panel shows the real data.
    n_q_bins : int
        Number of quantile bins passed to ``compute_nmi_matrix``.
    ncol, nrow : int, optional
        Grid shape. A missing value is derived from the other one; if both
        are missing a single row is used.
    average_method : str
        Entropy averaging passed to ``compute_nmi_matrix``.
    subplot_title_true_dataset : str
        Title of the real-data panel.
    subplot_titles_tgans, subplot_titles_extra_datasets : list, optional
        Panel titles, one per model / extra dataset.
    figsize : sequence
        Figure size.
    compute_diff_nmi_matrices : bool
        If true, plot each matrix minus the real-data matrix on a red-to-blue
        scale from -1 to 1; otherwise plot the matrices on a 0-to-1 scale.
    save_dir : str, optional
        Directory to save the figure in (created if missing). Nothing is
        saved when None.
    save_name : str, optional
        File name inside ``save_dir``; defaults to ``"nmi_diff_matrices"`` or
        ``"nmi_matrices"``.
    title : str, optional
        Figure-level title.

    Returns
    -------
    The matplotlib figure, after it has been closed with ``plt.close``.

    Raises
    ------
    ValueError
        If a list of titles has the wrong length, or if ``nrow * ncol`` is
        smaller than the number of panels.
    """
    if save_dir is not None and save_name is None:
        save_name = "nmi_diff_matrices" if compute_diff_nmi_matrices else "nmi_matrices"

    tgans = list(tgans)
    if subplot_titles_tgans is None:
        subplot_titles_tgans = [None] * len(tgans)
    elif len(subplot_titles_tgans) != len(tgans):
        raise ValueError("Number of tgan subplot titles must be equal to number of tgans")
    else:
        subplot_titles_tgans = list(subplot_titles_tgans)

    if include_true_data:
        subplot_titles_tgans = [subplot_title_true_dataset] + subplot_titles_tgans

    if extra_datasets is not None:
        extra_datasets = list(extra_datasets)
        datasets = [None] * len(tgans) + extra_datasets
        # Extra datasets borrow the column metadata of the first model.
        tgans = tgans + [tgans[0]] * len(extra_datasets)
        if subplot_titles_extra_datasets is None:
            subplot_titles_extra_datasets = [None] * len(extra_datasets)
        elif len(subplot_titles_extra_datasets) != len(extra_datasets):
            raise ValueError("Number of subplot titles for the extra datasets must be equal to the number of extra datasets")
        subplot_titles = subplot_titles_tgans + list(subplot_titles_extra_datasets)
    else:
        subplot_titles = subplot_titles_tgans
        datasets = [None] * len(tgans)

    n_subplots = len(tgans) + (1 if include_true_data else 0)

    if ncol is None and nrow is None:
        nrow = 1
        ncol = n_subplots
    elif ncol is None:
        ncol = ceil(n_subplots / nrow)
    elif nrow is None:
        nrow = ceil(n_subplots / ncol)
    elif nrow * ncol < n_subplots:
        raise ValueError("ncol times nrow must be larger than number of subfigures to plot")

    # squeeze=False always yields a 2-D axes array (also for a 1x1 grid), so
    # panels can be addressed uniformly through the row-major flat view.
    fig, axes = plt.subplots(nrow, ncol, figsize = figsize, squeeze = False)
    fig.tight_layout()
    flat_axes = axes.ravel()

    # A figure-level title; an axes title would be overwritten by the
    # per-panel titles set below.
    if title is not None:
        fig.suptitle(title)

    if compute_diff_nmi_matrices:
        colors_blue = plt.cm.Blues(np.linspace(0., 1, 128))
        # Reverse only the colour order (axis 0), not the RGBA channels.
        colors_red = np.flip(plt.cm.Reds(np.linspace(0, 1, 128)), axis = 0)
        colors = np.vstack((colors_red, colors_blue))
        cmap_diff_nmi = mcolors.LinearSegmentedColormap.from_list('my_blue_red_colormap', colors)
        imshow_kwargs = {"cmap": cmap_diff_nmi, "vmin": -1, "vmax": 1}
    else:
        # A fixed 0..1 range makes the single shared colorbar valid for
        # every panel instead of only the last one drawn.
        imshow_kwargs = {"cmap": plt.cm.Blues, "vmin": 0, "vmax": 1}

    nmi_matrix_truth = None
    if include_true_data or compute_diff_nmi_matrices:
        nmi_matrix_truth = compute_nmi_matrix(tgans[0], bins = None, n_q_bins = n_q_bins, generated_data = False,
                                              retbins = False, average_method = average_method)

    im = None
    if include_true_data:
        panel = flat_axes[0]
        if compute_diff_nmi_matrices:
            # The reference minus itself: an all-zero panel.
            im = panel.imshow(nmi_matrix_truth - nmi_matrix_truth, **imshow_kwargs)
        else:
            im = panel.imshow(nmi_matrix_truth, **imshow_kwargs)
        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_title(subplot_titles[0])

    first_model_panel = 1 if include_true_data else 0
    for curr_fig in range(first_model_panel, nrow * ncol):
        panel = flat_axes[curr_fig]
        if curr_fig >= n_subplots:
            # Grid cell without content.
            panel.axis("off")
            continue
        curr_tgan = curr_fig - first_model_panel
        nmi_matrix = compute_nmi_matrix(tgans[curr_tgan], dataset = datasets[curr_tgan], n_q_bins = n_q_bins,
                                        generated_data = True, retbins = False, average_method = average_method)
        if compute_diff_nmi_matrices:
            nmi_matrix -= nmi_matrix_truth
        im = panel.imshow(nmi_matrix, **imshow_kwargs)
        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_title(subplot_titles[curr_fig])

    if im is not None:
        fig.colorbar(im, ax = flat_axes.tolist())
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok = True)
        fig.savefig(os.path.join(save_dir, save_name))
    plt.close(fig)
    return fig
    

In [None]:
class Timer:
    """A set of named stopwatches based on ``time.perf_counter``."""

    def __init__(self):
        # name -> perf_counter reading taken by start(); set to None by stop()
        self._start_time = dict()
        # name -> seconds measured by the latest stop() for that name
        self.elapsed_time = dict()

    def start(self, name):
        """Start (or restart) the timer called ``name``."""
        self._start_time[name] = time.perf_counter()

    def stop(self, name):
        """Stop the timer ``name``, store its elapsed time and print it."""
        curr_elapsed_time = time.perf_counter() - self._start_time[name]
        self.elapsed_time[name], self._start_time[name] = curr_elapsed_time, None
        print(f"Elapsed time for {name}: {curr_elapsed_time:0.3f} seconds")

    def save(self, path, save_dir = None):
        """Pickle this timer to ``path``, placed inside ``save_dir`` if given.

        ``save_dir`` is created when it does not exist yet.
        """
        target = path
        if save_dir is not None:
            target = os.path.join(save_dir, path)
            os.makedirs(save_dir, exist_ok = True)
        with open(target, "wb") as sink:
            pickle.dump(self, sink)

    def load(self, path, save_dir = None):
        """Unpickle and return the object stored at ``path``.

        The instance this is called on is not modified; the loaded object is
        returned instead. Pickle files must come from a trusted source.
        """
        source = path if save_dir is None else os.path.join(save_dir, path)
        with open(source, "rb") as stream:
            restored = pickle.load(stream)
        return restored

def load_timer(path, save_dir = None):
    """Unpickle and return the object stored at ``path``.

    When ``save_dir`` is given, ``path`` is taken relative to it. Pickle files
    must come from a trusted source.
    """
    full_path = path if save_dir is None else os.path.join(save_dir, path)
    with open(full_path, "rb") as stream:
        loaded = pickle.load(stream)
    return loaded