In [1]:
import os

import cv2
import gdown
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tt
import json
import plotly.express as px
import plotly.subplots
import plotly.graph_objs as go
import pandas as pd

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from torchvision.utils import save_image
from tqdm.notebook import tqdm
#from google.colab import drive, files


from utils import *

In [2]:
# Load the experiment log; `lines=True` reads the file as one JSON record per line.
df = pd.read_json(path_or_buf='logs.json', lines=True)

In [3]:
def short_title(title):
    """Shorten a long label for use in legends and plot titles.

    A value whose string form is at most 20 characters is returned unchanged
    (its original type is preserved). A longer value is converted to text,
    cut at the first space, and the first character of that leading word is
    dropped.

    NOTE(review): for a class repr such as "<class 'torch.optim.adam.Adam'>"
    this rule yields "class" rather than the class name -- confirm that this
    is the intended label.

    Args:
        title: Label to shorten; any object with a string representation.

    Returns:
        `title` itself when its string form is short, otherwise the
        shortened string.
    """
    text = str(title)
    if len(text) <= 20:
        return title
    # Split the string form rather than `title` itself so that long
    # non-string labels (numbers, classes, ...) do not fail on `.split`.
    first_word = text.split(' ')[0]
    return first_word[1:]

In [26]:
def plot(df, exp_title, column_to_check, exp_list=False):
    """Plot the training curves and run times of one experiment.

    Selects the rows of `df` whose 'exp_title' equals `exp_title` and draws
    one line per selected row in each of three stacked subplots ('kl_avg',
    'loss_avg', 'loss_sum_avg'; each cell is expected to hold a sequence of
    values, plotted against its position). A bar chart of 'execution_time'
    per row is shown afterwards. Both figures are displayed with `.show()`.

    Args:
        df: DataFrame with the columns 'exp_title', 'kl_avg', 'loss_avg',
            'loss_sum_avg', 'execution_time' and `column_to_check`.
            The frame passed in is not modified.
        exp_title: Experiment name to select.
        column_to_check: Column whose (stringified) value labels each run in
            the legend and on the bar chart when `exp_list` is not given.
        exp_list: Optional sequence of custom labels, one per selected row in
            row order. When falsy, labels come from `column_to_check`.

    Returns:
        None.

    Raises:
        ValueError: If `exp_list` is given but its length differs from the
            number of selected rows.
    """
    # Explicit copy: columns are converted/added below, and assigning on a
    # filtered slice triggers pandas' SettingWithCopyWarning.
    filtered_df = df[df['exp_title'] == exp_title].copy()
    if filtered_df.empty:
        print(f"No experiments found for {exp_title}")
        return

    if exp_list and len(exp_list) != len(filtered_df):
        raise ValueError(
            f"exp_list has {len(exp_list)} labels but {len(filtered_df)} "
            f"rows match exp_title {exp_title!r}"
        )

    columns = ['kl_avg', 'loss_avg', 'loss_sum_avg']
    fig = plotly.subplots.make_subplots(rows=len(columns),
                                        cols=1,
                                        subplot_titles=['KL',
                                                        'Loss',
                                                        'Sum loss'])

    # The qualitative palette is finite; colours are reused cyclically so
    # that experiments with more runs than colours do not raise IndexError.
    palette = px.colors.qualitative.Plotly

    filtered_df[column_to_check] = filtered_df[column_to_check].astype(str)
    print(filtered_df.head())

    for subplot_row, column in enumerate(columns, start=1):
        for index, (_, row) in enumerate(filtered_df.iterrows()):
            if exp_list:
                # Stringify so the label is valid both for short_title and
                # as a legend group identifier.
                label = str(exp_list[index])
                trace_name = f'{short_title(label)}'
            else:
                label = row[column_to_check]
                trace_name = f'{column_to_check} {short_title(label)}'

            fig.add_trace(go.Scatter(x=list(range(len(row[column]))),
                                     y=row[column],
                                     mode='lines',
                                     name=trace_name,
                                     line=dict(color=palette[index % len(palette)]),
                                     legendgroup=label),
                          row=subplot_row,
                          col=1)

    # Every run appears once per subplot; keep only its first legend entry.
    seen_names = set()

    def _hide_repeated_legend(trace):
        if trace.name in seen_names:
            trace.update(showlegend=False)
        else:
            seen_names.add(trace.name)

    fig.for_each_trace(_hide_repeated_legend)

    for subplot_row, column in enumerate(columns, start=1):
        fig.update_xaxes(title_text='Epoch', row=subplot_row, col=1)
        fig.update_yaxes(title_text=column, row=subplot_row, col=1)
    fig.update_layout(title_text=f'Exp Title: {exp_title}',
                      width=1000,
                      height=1000)

    if exp_list:
        filtered_df['true_labels'] = list(exp_list)
        bar_x = 'true_labels'
    else:
        bar_x = column_to_check
    bar_fig = px.bar(filtered_df,
                     x=bar_x,
                     y='execution_time',
                     title=f'Time {short_title(column_to_check)}')

    fig.show()
    bar_fig.show()

In [27]:
def plot_png(df, exp_title, column_to_check, figsize=(15, 15), columns=2):
    """Show the saved PNG of every run of one experiment in a grid.

    For each row of `df` whose 'exp_title' equals `exp_title`, the file
    f"{exp_title}_{exp_index}.png" (resolved against the current working
    directory) is drawn in its own panel, titled with the experiment name and
    the row's `column_to_check` value. Panels are filled row by row.

    Args:
        df: DataFrame with the columns 'exp_title', 'exp_index' and
            `column_to_check`.
        exp_title: Experiment name to select; also the image file prefix.
        column_to_check: Column whose value is appended to each panel title.
        figsize: Figure size passed to `plt.subplots`.
        columns: Number of panels per grid row; must be at least 1.

    Returns:
        None.

    Raises:
        ValueError: If `columns` is smaller than 1.
    """
    if columns < 1:
        raise ValueError(f"columns must be at least 1, got {columns}")

    filtered_df = df[df['exp_title'] == exp_title]
    num_plots = len(filtered_df)

    if num_plots == 0:
        print(f"No images found for {exp_title}")
        return

    # Ceiling division that is valid for any column count.
    rows = (num_plots + columns - 1) // columns
    # squeeze=False always yields a 2-D array of axes, so grids with a single
    # row or a single column are handled like any other.
    _, axs = plt.subplots(rows, columns, figsize=figsize, squeeze=False)
    panels = axs.ravel()  # row-major order: panel i sits at (i // columns, i % columns)

    for ax, (_, row) in zip(panels, filtered_df.iterrows()):
        image_file = f"{exp_title}_{row['exp_index']}.png"
        if os.path.exists(image_file):
            ax.imshow(plt.imread(image_file))
            ax.set_title(f"{exp_title} {row[column_to_check]}")
        else:
            # Report the gap and hide the panel rather than leaving blank axes.
            print(f"Image not found: {image_file}")
            ax.axis('off')

    # Hide the unused panels at the end of the grid.
    for ax in panels[num_plots:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [28]:
# Compare the runs of the 'test_lr' experiment, labelled by their `lr` value.
plot(df, exp_title='test_lr', column_to_check='lr')
#plot_png(df, 'test_epochs', 'epochs')

   epochs      lr  batch_size  latent_dims model_name  \
0      30  0.0001         128            2        vae   
1      30  0.0005         128            2        vae   
2      30   0.001         128            2        vae   
3      30   0.003         128            2        vae   
4      30   0.005         128            2        vae   

                         optimizer exp_title  exp_index  execution_time  \
0  <class 'torch.optim.adam.Adam'>   test_lr          0      324.922077   
1  <class 'torch.optim.adam.Adam'>   test_lr          1      326.814675   
2  <class 'torch.optim.adam.Adam'>   test_lr          2      339.708420   
3  <class 'torch.optim.adam.Adam'>   test_lr          3      335.285763   
4  <class 'torch.optim.adam.Adam'>   test_lr          4      327.791238   

                                              kl_avg  \
0  [574.6667537124935, 483.0266427566756, 443.313...   
1  [409.48552253810584, 416.1337645963819, 445.05...   
2  [382.144795100572, 430.509170825039