Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ After completing a milestone, create a pull request with your changes for review

## PR4: Data Visualization Module

- [ ] Set up visualization framework
- [ ] Implement histogram/density plots
- [ ] Create scatter plot functionality
- [ ] Add bar chart and pie chart generators
- [ ] Implement box plots and violin plots
- [ ] Create heatmap functionality
- [ ] Add visualization customization options
- [ ] Implement visualization export capability
- [ ] Write tests for all visualization functions
- [ ] Test visualization rendering with different data inputs
- [x] Set up visualization framework
- [x] Implement histogram/density plots
- [x] Create scatter plot functionality
- [x] Add bar chart and pie chart generators
- [x] Implement box plots and violin plots
- [x] Create heatmap functionality
- [x] Add visualization customization options
- [x] Implement visualization export capability
- [x] Write tests for all visualization functions
- [x] Test visualization rendering with different data inputs

## PR5: Model Training - Classification

Expand Down
53 changes: 53 additions & 0 deletions tests/test_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pandas as pd
from utils import viz


def sample_df():
return pd.DataFrame({
'num1': [1, 2, 3, 4, 5],
'num2': [5, 4, 3, 2, 1],
'cat': ['a', 'b', 'a', 'b', 'a'],
})


def test_histogram_and_density():
df = sample_df()
fig = viz.histogram(df, 'num1', bins=2, title='Hist')
assert fig.layout.title.text == 'Hist'
fig = viz.histogram(df, 'num1', density=True)
assert fig.data[0].histnorm == 'probability density'


def test_scatter_plot():
df = sample_df()
fig = viz.scatter_plot(df, 'num1', 'num2', color='cat', title='Scatter')
assert fig.layout.title.text == 'Scatter'
assert fig.data[0].marker.color is not None


def test_bar_and_pie_charts():
df = sample_df()
bar = viz.bar_chart(df, 'cat', 'num1')
pie = viz.pie_chart(df, names='cat', values='num1')
assert bar.data and pie.data


def test_box_and_violin():
df = sample_df()
box = viz.box_plot(df, x='cat', y='num1')
violin = viz.violin_plot(df, x='cat', y='num1')
assert box.data and violin.data


def test_heatmap():
df = sample_df()
fig = viz.heatmap(df, title='Heat')
assert fig.layout.title.text == 'Heat'


def test_export_figure(tmp_path):
df = sample_df()
fig = viz.bar_chart(df, 'cat', 'num1')
out = tmp_path / 'chart.html'
viz.export_figure(fig, out)
assert out.exists() and out.stat().st_size > 0
3 changes: 2 additions & 1 deletion utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from . import config
from . import data
from . import eda
from . import viz

__all__ = ["config", "data", "eda"]
__all__ = ["config", "data", "eda", "viz"]
103 changes: 103 additions & 0 deletions utils/viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Data visualization utilities."""

from __future__ import annotations

from pathlib import Path
from typing import Optional

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go


def histogram(
df: pd.DataFrame,
column: str,
*,
bins: int = 20,
density: bool = False,
title: Optional[str] = None,
) -> go.Figure:
"""Return a histogram or density plot for a column."""
histnorm = "probability density" if density else None
fig = px.histogram(df, x=column, nbins=bins, histnorm=histnorm, title=title)
return fig


def scatter_plot(
df: pd.DataFrame,
x: str,
y: str,
*,
color: Optional[str] = None,
title: Optional[str] = None,
) -> go.Figure:
"""Return a scatter plot."""
fig = px.scatter(df, x=x, y=y, color=color, title=title)
return fig


def bar_chart(
df: pd.DataFrame,
x: str,
y: str,
*,
title: Optional[str] = None,
) -> go.Figure:
"""Return a bar chart."""
fig = px.bar(df, x=x, y=y, title=title)
return fig


def pie_chart(
df: pd.DataFrame,
names: str,
values: str,
*,
title: Optional[str] = None,
) -> go.Figure:
"""Return a pie chart."""
fig = px.pie(df, names=names, values=values, title=title)
return fig


def box_plot(
df: pd.DataFrame,
x: str,
y: str,
*,
title: Optional[str] = None,
) -> go.Figure:
"""Return a box plot."""
fig = px.box(df, x=x, y=y, title=title)
return fig


def violin_plot(
df: pd.DataFrame,
x: str,
y: str,
*,
title: Optional[str] = None,
) -> go.Figure:
"""Return a violin plot."""
fig = px.violin(df, x=x, y=y, box=True, title=title)
return fig


def heatmap(
df: pd.DataFrame,
*,
columns: Optional[list[str]] = None,
title: Optional[str] = None,
) -> go.Figure:
"""Return a correlation heatmap for the given columns."""
cols = columns or df.select_dtypes(include="number").columns.tolist()
corr = df[cols].corr()
fig = px.imshow(corr, text_auto=True, title=title)
return fig


def export_figure(fig: go.Figure, path: Path) -> None:
"""Export a figure to an HTML file."""
fig.write_html(str(path))