# DataFrameのバリデーションの検証

In [1]:
import pandantic
import pandera
from pydantic.types import StrictFloat, StrictInt, StrictStr
import pandas as pd

# Pandanticを使った方法

https://wesselhuising.medium.com/validate-pandas-dataframes-using-pydantic-pandantic-4ec89709d5d

In [2]:
class Schema(pandantic.BaseModel):
    """Per-row schema for the sample DataFrames, checked via pandantic's ``parse_df``.

    Each field uses a pydantic ``Strict*`` type so lax coercion is disabled
    (e.g. a float value in ``group_id`` is rejected). Note: in the recorded run
    an int value in the ``lon`` (StrictFloat) column was still accepted.
    """

    h3: StrictStr  # presumably an H3 cell index string — TODO confirm
    lat: StrictFloat  # presumably latitude in degrees
    lon: StrictFloat  # presumably longitude in degrees
    group_id: StrictInt

In [3]:
# Baseline single-row frame that satisfies every field of the schemas below.
ok_df = pd.DataFrame(
    {
        "h3": ["882f5a32ddfffff"],
        "lat": [35.6812],
        "lon": [139.7671],
        "group_id": [100],
    }
)
# errors="filter": rows that fail validation are dropped from the returned frame
# (the validation error is reported in the cell output) instead of raising.
Schema.parse_df(
    dataframe=ok_df,
    errors="filter",
)

Unnamed: 0,h3,lat,lon,group_id
0,882f5a32ddfffff,35.6812,139.7671,100


In [4]:
# Same row as ok_df, except the StrictInt column `group_id` holds a float.
# Deriving from ok_df (instead of re-typing the whole record) keeps the single
# deviation under test explicit; assign() returns a new frame and leaves ok_df untouched.
int2float_df = ok_df.assign(group_id=100.123)
# group_id=100.123 is not a valid StrictInt, so with errors="filter" the row is
# reported and dropped (the recorded output is an empty frame).
Schema.parse_df(
    dataframe=int2float_df,
    errors="filter",
)

1 validation error for Schema
group_id
  Input should be a valid integer [type=int_type, input_value=100.123, input_type=float]
    For further information visit https://errors.pydantic.dev/2.7/v/int_type


Unnamed: 0,h3,lat,lon,group_id


int型の列（group_id）にfloat値（100.123）を入れた行は、StrictIntの検証エラーとなりフィルターされた（結果は空のDataFrame）。

In [5]:
# Same row as ok_df, except the StrictFloat column `lon` holds an int (int64 column).
# Derived from ok_df so the only difference from the baseline is explicit.
float2int_df = ok_df.assign(lon=139)
# lon=139 is an int in a StrictFloat field; in the recorded run the row was kept.
Schema.parse_df(
    dataframe=float2int_df,
    errors="filter",
)

Unnamed: 0,h3,lat,lon,group_id
0,882f5a32ddfffff,35.6812,139,100


float型の列（lon）にint値（139）を入れた行は、StrictFloat指定にもかかわらずエラーにならず通過した。

In [6]:
# Same row as ok_df with the required `lon` column removed entirely.
# drop() returns a new frame; ok_df itself is not modified.
missing_df = ok_df.drop(columns="lon")
# The required `lon` field is absent, so validation fails ("Field required")
# and, with errors="filter", the row is dropped rather than raising.
Schema.parse_df(
    dataframe=missing_df,
    errors="filter",
)

1 validation error for Schema
lon
  Field required [type=missing, input_value={'h3': '882f5a32ddfffff',...p_id': 100, '_index': 0}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.7/v/missing


Unnamed: 0,h3,lat,group_id


必須列lonが欠けている行は検証エラー（Field required）となり、例外は発生せずにフィルターされた（結果は空のDataFrame）。

# Panderaを使った方法

https://pandera.readthedocs.io/en/stable/pydantic_integration.html

In [7]:
# Object-based pandera schema: one dtype-checked Column per expected field.
schema = pandera.DataFrameSchema(
    {
        column_name: pandera.Column(dtype)
        for column_name, dtype in (
            ("h3", pandera.String),
            ("lat", pandera.Float),
            ("lon", pandera.Float),
            ("group_id", pandera.Int),
        )
    }
)

In [8]:
# Calling the schema is shorthand for .validate(); the explicit form mirrors the class-based cells.
schema.validate(ok_df)

Unnamed: 0,h3,lat,lon,group_id
0,882f5a32ddfffff,35.6812,139.7671,100


In [9]:
# Expected failure: group_id is float64 but the schema requires Int.
# The error is the point of this cell, so catch and show it; an uncaught
# exception would stop "Restart & Run All" here.
try:
    schema(int2float_df)
except pandera.errors.SchemaError as exc:
    print(f"SchemaError: {exc}")
else:
    print("validation passed")

SchemaError: expected series 'group_id' to have type int64, got float64

In [None]:
# Expected failure: lon is int64 but the schema requires Float (pandera checks column dtypes).
# Catch and show the error so "Restart & Run All" continues past this cell.
try:
    schema(float2int_df)
except pandera.errors.SchemaError as exc:
    print(f"SchemaError: {exc}")
else:
    print("validation passed")

SchemaError: expected series 'lon' to have type float64, got int64

In [None]:
# Expected failure: the required `lon` column is absent from the frame.
# Catch and show the error so "Restart & Run All" continues past this cell.
try:
    schema(missing_df)
except pandera.errors.SchemaError as exc:
    print(f"SchemaError: {exc}")
else:
    print("validation passed")

SchemaError: column 'lon' not in dataframe. Columns in dataframe: ['h3', 'lat', 'group_id']

panderaは列のdtype（int64 / float64）を厳密にチェックし、pandanticでは通過したfloat列へのint値（lon）もエラーとして検出した。検証能力としてはpanderaの方が高そう。

In [10]:
class UserDataFrameSchema(pandera.DataFrameModel):
    """Class-based pandera schema with the same columns and dtypes as `schema`."""

    h3: pandera.String
    lat: pandera.Float
    lon: pandera.Float
    group_id: pandera.Int


# Backward-compatible alias for the original, misspelled class name used by later cells.
UserDataFrmameSchema = UserDataFrameSchema

In [11]:
# Class-based API: validate() returns the DataFrame when every column check passes.
UserDataFrmameSchema.validate(ok_df)

Unnamed: 0,h3,lat,lon,group_id
0,882f5a32ddfffff,35.6812,139.7671,100


In [12]:
# Expected failure: the class-based schema rejects float64 in the Int column group_id.
# Catch and show the error so "Restart & Run All" continues past this cell.
try:
    UserDataFrmameSchema.validate(int2float_df)
except pandera.errors.SchemaError as exc:
    print(f"SchemaError: {exc}")
else:
    print("validation passed")

SchemaError: expected series 'group_id' to have type int64, got float64

In [14]:
# The validated frame carries the original values; read one back with explicit label indexing.
df = UserDataFrmameSchema.validate(ok_df)
df.at[0, "lat"]


35.6812