31 changes: 19 additions & 12 deletions _posts/2024-04-07-unit-testing-datascience.md
@@ -2,9 +2,9 @@
layout: post
title: Enhancing Data Science Workflow with Unit Testing
subtitle: Enhancing Data Science Workflow with Unit Testing
cover-img: /assets/img/path.jpg
thumbnail-img: /assets/img/thumb.png
share-img: /assets/img/path.jpg
cover-img: /assets/img/pytest_1.jpg
thumbnail-img: /assets/img/pytest_thumb.png
share-img: /assets/img/pytest_2.jpg
gh-repo: arpithub/arpithub.github.io
gh-badge: [star, fork, follow]
tags: [datascience,testing,pytest,ml]
@@ -25,8 +25,7 @@ data-science-project/
├── src/
│ ├── preprocessing.py
│ └── model.py
├── tests/
│── tests/
│ ├── test_preprocessing.py
│ └── test_model.py
@@ -81,14 +80,19 @@ Now, create `tests/test_preprocessing.py` to write unit tests for the preprocess

```python
import pytest
from pathlib import Path
from src.preprocessing import load_iris_dataset, preprocess_data

data_dir = Path(__file__).parent.parent / 'data' # Navigate up to the project root
iris_path = data_dir / 'iris.csv'


@pytest.fixture
def iris_data():
    return load_iris_dataset('data/iris.csv')
    return load_iris_dataset(iris_path)

def test_load_iris_dataset():
    df = load_iris_dataset('data/iris.csv')
    df = load_iris_dataset(iris_path)
    assert not df.empty

def test_preprocess_data(iris_data):
@@ -97,16 +101,16 @@ def test_preprocess_data(iris_data):
    assert 'species' in preprocessed_df.columns

def test_missing_values():
    df = iris_data()
    df = load_iris_dataset(iris_path)
    assert not df.isnull().values.any(), "Dataset contains missing values"

def test_no_duplicates():
    df = iris_data()
    df = load_iris_dataset(iris_path)
    assert not df.duplicated().any(), "Dataset contains duplicate records"


def test_column_datatypes():
    df = iris_data()
    df = load_iris_dataset(iris_path)
    expected_datatypes = {
        'sepal length (cm)': 'float64',
        'sepal width (cm)': 'float64',
@@ -145,12 +149,15 @@ Now, create `tests/test_model.py` to write unit tests for the model training and

```python
import pytest
from pathlib import Path
from src.model import train_and_evaluate_model
from src.preprocessing import load_iris_dataset, preprocess_data
data_dir = Path(__file__).parent.parent / 'data' # Navigate up to the project root
iris_path = data_dir / 'iris.csv'

@pytest.fixture
def preprocessed_iris_data():
    df = load_iris_dataset('data/iris.csv')
    df = load_iris_dataset(iris_path)
    return preprocess_data(df)

def test_train_and_evaluate_model(preprocessed_iris_data):
@@ -159,7 +166,7 @@ def test_train_and_evaluate_model(preprocessed_iris_data):
```
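For illustration only, here is a minimal sketch of one further check that could live in `tests/test_model.py`. It reuses the `preprocessed_iris_data` fixture above and assumes that `train_and_evaluate_model` accepts the preprocessed DataFrame and returns an accuracy score between 0 and 1 (the diff does not show its actual signature, so adjust to match your implementation):

```python
# Hypothetical extra test for tests/test_model.py (reuses the fixture above).
# Assumption: train_and_evaluate_model takes the preprocessed DataFrame and
# returns an accuracy score between 0 and 1; adapt if its signature differs.
def test_model_accuracy_threshold(preprocessed_iris_data):
    accuracy = train_and_evaluate_model(preprocessed_iris_data)
    assert 0.0 <= accuracy <= 1.0, "Accuracy should be a fraction between 0 and 1"
    assert accuracy > 0.9, "Model accuracy fell below the expected threshold"
```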

#### Running the Tests
To run the tests using Pytest, navigate to the project directory and execute:
To run the tests using Pytest, navigate to the `tests` directory and execute:

```bash
pytest
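# Not from the original post -- a couple of common variations, for reference:
#   pytest -v                # verbose output: one line per test with its result
#   pytest test_model.py     # run only the model tests (from the tests directory)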
Binary file added assets/img/pytest_1.jpg
Binary file added assets/img/pytest_2.jpg
Binary file added assets/img/pytest_thumb.png