In [None]:
!pip3 install -U sagemaker
!pip3 install polars

In [1]:
import xgboost as xgb
import pandas as pd
import sys
import os

In [2]:
# Restore the trained regressor from its JSON dump.
# `verbosity=0` silences XGBoost logging; the legacy `silent` flag is not
# passed because `verbosity` already covers it.
model = xgb.XGBRegressor(verbosity=0)
model.load_model(os.path.join("models", "model_xgboost.json"))

In [3]:
# Dimension tables: each CSV exposes generic `id` / `name` columns, which are
# renamed to table-specific names so they stay distinguishable after merging.
stations = pd.read_csv(os.path.join("staging_data", "station.csv"))
stations = stations.rename(columns={"id": "stationid", "name": "station_name"})

locations = pd.read_csv(os.path.join("staging_data", "location.csv"))
locations = locations.rename(columns={"id": "locationid", "name": "location_name"})

locationcategories = pd.read_csv(os.path.join("staging_data", "locationcategory.csv"))
locationcategories = locationcategories.rename(
    columns={"id": "locationcategoryid", "name": "locationcategory_name"}
)

# Relation table joined with all three dimension tables. Inner joins keep only
# the relations whose station, location and category are all present.
station_relations = (
    pd.read_csv(os.path.join("staging_data", "stationrelation.csv"))
    .merge(stations, on="stationid", how="inner")
    .merge(locations, on="locationid", how="inner")
    .merge(locationcategories, on="locationcategoryid", how="inner")
)

In [4]:
def predict(locationcategory_name, location_name, start_date, duration, station_relations):
    """Predict daily weather for every location matching a name query.

    Stations are selected from ``station_relations`` by a case-insensitive,
    literal substring match on both the location category and the location
    name. One feature row is built per (date, station) pair and scored with
    the module-level ``model``; the per-station predictions are then averaged
    per location and date.

    Parameters
    ----------
    locationcategory_name : str
        Substring searched for in ``station_relations["locationcategory_name"]``.
    location_name : str
        Substring searched for in ``station_relations["location_name"]``.
    start_date : str or datetime-like
        First day to predict; parsed with ``pd.to_datetime``.
    duration : int or str
        Number of days to cover after ``start_date``. The end date is
        inclusive, so ``duration + 1`` dates are predicted.
    station_relations : pandas.DataFrame
        Must provide the columns ``stationid``, ``latitude``, ``longitude``,
        ``elevation``, ``location_name`` and ``locationcategory_name``.

    Returns
    -------
    pandas.DataFrame
        One row per (location, date) with the columns ``Location``, ``Date``,
        ``TMAX (C)``, ``TMIN (C)``, ``PRCP (cm)``, ``SNOW (cm)`` and
        ``SNWD (cm)``. Empty (same columns) when no station matches.

    Raises
    ------
    ValueError
        If ``duration`` is negative.
    """
    category_query = locationcategory_name.lower()
    location_query = location_name.lower()
    first_day = pd.to_datetime(start_date)
    n_days = int(duration)
    if n_days < 0:
        raise ValueError(f"duration must be non-negative, got {n_days}")

    target_columns = ["TMAX (C)", "TMIN (C)", "PRCP (cm)", "SNOW (cm)", "SNWD (cm)"]

    # regex=False: the query is user text, not a pattern, so characters such
    # as "." or "(" must match literally. na=False: rows with a missing name
    # simply do not match instead of producing a non-boolean mask.
    category_mask = (
        station_relations["locationcategory_name"]
        .str.lower()
        .str.contains(category_query, regex=False, na=False)
    )
    location_mask = (
        station_relations["location_name"]
        .str.lower()
        .str.contains(location_query, regex=False, na=False)
    )
    filtered_stations = station_relations.loc[
        category_mask & location_mask,
        ["stationid", "latitude", "longitude", "elevation", "location_name"],
    ].drop_duplicates()

    # Nothing to score: return an empty frame shaped like a normal result
    # rather than handing an empty feature matrix to the model.
    if filtered_stations.empty:
        return pd.DataFrame(columns=["Location", "Date"] + target_columns)

    # Both endpoints are included, hence duration + 1 calendar days.
    date_index = pd.date_range(first_day, first_day + pd.Timedelta(days=n_days), freq="D")

    # Calendar features; the key order here, followed by the station columns,
    # defines the feature column order passed to the model.
    dates = pd.DataFrame(
        [
            {
                "date": day.strftime("%Y-%m-%d"),
                "year": day.year - 2010,
                "quarter": day.quarter,
                "month": day.month,
                "week": day.week,
                "day_of_year": day.day_of_year,
                "is_leap_year": int(day.is_leap_year),
            }
            for day in date_index
        ]
    )

    # Every date paired with every matching station.
    ref = dates.join(filtered_stations, how="cross")
    X_test = ref.drop(columns=["date", "stationid", "location_name"])

    # NOTE(review): raw model outputs are divided by 10 - presumably the
    # targets were stored in tenths of the labelled unit; confirm the units.
    Y_pred = pd.DataFrame(model.predict(X_test) / 10, columns=target_columns)

    labels = ref[["date", "location_name"]].rename(
        columns={"date": "Date", "location_name": "Location"}
    )
    out = pd.concat([labels, Y_pred], axis=1)

    # Average across the stations that belong to the same location.
    return out.groupby(["Location", "Date"]).mean().reset_index()

In [5]:
# Example query: cities whose name contains "hyderabad", starting on
# 2023-12-04 with a duration of 7 (passed as a string; `predict` casts it).
locationcategory_name, location_name = "city", "hyderabad"
start_date, duration = "2023-12-04", "7"
predict(
    locationcategory_name=locationcategory_name,
    location_name=location_name,
    start_date=start_date,
    duration=duration,
    station_relations=station_relations,
)

  if is_sparse(dtype):
  elif is_categorical_dtype(dtype) and enable_categorical:
  if is_categorical_dtype(dtype)
  return is_int or is_bool or is_float or is_categorical_dtype(dtype)


Unnamed: 0,Location,Date,TMAX (C),TMIN (C),PRCP (cm),SNOW (cm),SNWD (cm)
0,"Hyderabad, IN",2023-12-04,22.1364,15.782339,0.431047,-0.028488,0.157913
1,"Hyderabad, IN",2023-12-05,22.1364,15.782339,0.431047,-0.028488,0.157913
2,"Hyderabad, IN",2023-12-06,22.1364,15.333679,0.801259,-0.008641,0.20172
3,"Hyderabad, IN",2023-12-07,22.1364,15.333679,0.801259,-0.008641,0.20172
4,"Hyderabad, IN",2023-12-08,22.1364,15.333679,0.801259,-0.008641,0.20172
5,"Hyderabad, IN",2023-12-09,22.1364,15.333679,0.833513,0.006329,0.20172
6,"Hyderabad, IN",2023-12-10,22.1364,15.333679,0.771845,0.006329,0.230897
7,"Hyderabad, IN",2023-12-11,22.1364,15.20947,0.63871,0.014953,0.121719
