In [None]:
# default_exp wandb_viz

# Weights and Biases Visualizations

> This module provides helpers for logging charts to Weights & Biases (W&B).

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#hide 
# %load_ext autoreload
# %autoreload 2

In [None]:
#export
import os
import wandb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

## Plot basic charts using W&B

In [None]:
#export
def wandb_plot_basic_charts(chart_type='line', x_data=None, y_data=None, x_name=None, 
                            y_name=None, chart_title=None, wandb_chart_name='basic-chart'):
    """
    wandb_plot_basic_charts(chart_type='line', x_data=None, y_data=None, x_name=None,
                            y_name=None, chart_title=None, wandb_chart_name='basic-chart')
    
    Log a basic line, bar or scatter chart to W&B.
    
    The (x, y) pairs are packed into a two-column ``wandb.Table`` and logged with
    the matching ``wandb.plot`` preset under the key ``wandb_chart_name``. The
    chart is sent with ``wandb.log``, so a W&B run should already be active
    (e.g. via ``wandb.init()``); this function does not start one itself.
    
    Parameters
    ----------
    chart_type : str, default 'line'
        Kind of chart to plot. One of:
        - 'line'
        - 'bar'
        - 'scatter'
    x_data : iterable
        x-axis data. For a bar chart these are the bar labels.
    y_data : iterable
        y-axis data. For a bar chart these are the bar values. Must have the
        same number of elements as `x_data`.
    x_name : str, optional
        x-axis (and table column) name. Defaults to "x" when not given.
    y_name : str, optional
        y-axis (and table column) name. Defaults to "y" when not given.
    chart_title : str, optional
        Chart title.
    wandb_chart_name : str, default 'basic-chart'
        Key under which the chart is logged.
    
    Returns
    -------
    None
    
    Raises
    ------
    ValueError
        If `chart_type` is not one of the supported types, if `x_data` or
        `y_data` is missing, or if they have different lengths.
    
    Examples
    --------
    # Plot a simple line chart.
    wandb_plot_basic_charts(chart_type='line', x_data=[1,2,3,4,5], y_data=[1,2,3,4,5], x_name="x", y_name="y", chart_title="Simple line chart")
    
    # Plot a simple bar chart (x_data holds the labels, y_data the values).
    wandb_plot_basic_charts(chart_type='bar', x_data=['a','b','c','d','e'], y_data=[1,2,3,4,5], x_name="x", y_name="y", chart_title="Simple bar chart")
    
    # Plot a simple scatter plot.
    wandb_plot_basic_charts(chart_type='scatter', x_data=[1,2,3,4,5], y_data=[1,2,3,4,5], x_name="x", y_name="y", chart_title="Simple scatter chart")
    """
    # Validate everything before touching W&B so bad arguments fail fast with a
    # clear message instead of silently logging nothing (unknown chart_type) or
    # silently dropping points (zip truncating mismatched lengths).
    supported = ('line', 'bar', 'scatter')
    if chart_type not in supported:
        raise ValueError(
            f"Unsupported chart_type {chart_type!r}; expected one of {supported}.")
    if x_data is None or y_data is None:
        raise ValueError("Both x_data and y_data must be provided.")

    # Materialize so any iterable (generator, Series, array) can be length-checked.
    x_values, y_values = list(x_data), list(y_data)
    if len(x_values) != len(y_values):
        raise ValueError(
            f"x_data and y_data must have the same length "
            f"(got {len(x_values)} and {len(y_values)}).")

    # The column names double as the axis keys passed to the plot preset, so
    # fall back to real strings rather than using None as a column name.
    x_name = 'x' if x_name is None else x_name
    y_name = 'y' if y_name is None else y_name

    data = [[x, y] for x, y in zip(x_values, y_values)]
    table = wandb.Table(data=data, columns=[x_name, y_name])

    # All three presets share the (table, x_key, y_key, title=...) call shape.
    plotters = {
        'line': wandb.plot.line,
        'bar': wandb.plot.bar,
        'scatter': wandb.plot.scatter,
    }
    chart = plotters[chart_type](table, x_name, y_name, title=chart_title)
    # str() keeps the original f-string behaviour of always logging a string key.
    wandb.log({str(wandb_chart_name): chart})


In [None]:
# x_values = ['a', 'b', 'c', 'd', 'e']
# y_values = [1,2,3,4,5]

# run = wandb.init(entity='ayush-thakur', project='tests')
# wandb_plot_basic_charts(chart_type='bar', x_data=x_values, y_data=y_values, x_name="x", y_name="y", chart_title="Simple scatter chart")
# run.finish()

In [None]:
#hide
# Export the `#export` cells of the project's notebooks into the library package.
from nbdev.export import notebook2script
notebook2script()

Converted 00_data.ipynb.
Converted 01_preprocess.ipynb.
Converted 02_utils.ipynb.
Converted 03_wandb_utils.ipynb.
Converted index.ipynb.
