In [1]:
import joblib
import pandas
import session_info
import sklearn

from sklearn.compose import make_column_transformer
from sklearn.datasets import fetch_openml
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.tree import DecisionTreeRegressor

In [2]:
session_info.show()

In [3]:
sklearn.set_config(display='diagram')

## Data Collection

In [4]:
# Fetch OpenML dataset 43355 (the diamonds data described by dataset.DESCR)
# as pandas objects.
dataset = fetch_openml(
    data_id=43355,
    as_frame=True,
    parser='auto',
)

In [5]:
dataset.keys()

dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])

In [6]:
print(dataset.DESCR)

Context
Buying a diamond can be frustrating and expensive.  
It inspired me to create this dataset of 119K natural and lab-created diamonds from brilliantearth.com to demystify the value of the 4 Cs  cut, color, clarity, carat.
This data was scraped using DiamondScraper.
Content



Attribute
Description
Data Type




id
Diamond identification number provided by Brilliant Earth
int


url
URL for the diamond details page
string


shape
External geometric appearance of a diamond
string/categorical


price
Price in U.S. dollars
int


carat
Unit of measurement used to describe the weight of a diamond
float


cut
Facets, symmetry, and reflective qualities of a diamond
string/categorical


color
Natural color or lack of color visible within a diamond, based on the GIA grade scale
string/categorical


clarity
Visibility of natural microscopic inclusions and imperfections within a diamond
string/categorical


report
Diamond certificate or grading report provided by an independent gemology lab
s

In [7]:
print(dataset.data)

              id                                                url    shape  \
0       10086429  https://www.brilliantearth.com//loose-diamonds...    Round   
1       10016334  https://www.brilliantearth.com//loose-diamonds...  Emerald   
2        9947216  https://www.brilliantearth.com//loose-diamonds...  Emerald   
3       10083437  https://www.brilliantearth.com//loose-diamonds...    Round   
4        9946136  https://www.brilliantearth.com//loose-diamonds...  Emerald   
...          ...                                                ...      ...   
119302  10081678  https://www.brilliantearth.com//lab-diamonds-s...    Round   
119303   9521564  https://www.brilliantearth.com//lab-diamonds-s...  Cushion   
119304   9896730  https://www.brilliantearth.com//lab-diamonds-s...  Cushion   
119305   9756570  https://www.brilliantearth.com//lab-diamonds-s...     Oval   
119306   9293400  https://www.brilliantearth.com//lab-diamonds-s...  Cushion   

         price  carat            cut co

In [8]:
diamond_prices = dataset.data

In [9]:
diamond_prices

Unnamed: 0,id,url,shape,price,carat,cut,color,clarity,report,type,date_fetched
0,10086429,https://www.brilliantearth.com//loose-diamonds...,Round,400,0.30,'Very Good',J,SI2,GIA,natural,'2020-11-29 12-26 PM'
1,10016334,https://www.brilliantearth.com//loose-diamonds...,Emerald,400,0.31,Ideal,I,SI1,GIA,natural,'2020-11-29 12-26 PM'
2,9947216,https://www.brilliantearth.com//loose-diamonds...,Emerald,400,0.30,Ideal,I,VS2,GIA,natural,'2020-11-29 12-26 PM'
3,10083437,https://www.brilliantearth.com//loose-diamonds...,Round,400,0.30,Ideal,I,SI2,GIA,natural,'2020-11-29 12-26 PM'
4,9946136,https://www.brilliantearth.com//loose-diamonds...,Emerald,400,0.30,Ideal,I,SI1,GIA,natural,'2020-11-29 12-26 PM'
...,...,...,...,...,...,...,...,...,...,...,...
119302,10081678,https://www.brilliantearth.com//lab-diamonds-s...,Round,99040,5.71,'Super Ideal',D,VVS2,GCAL,lab,'2020-11-29 12-26 PM'
119303,9521564,https://www.brilliantearth.com//lab-diamonds-s...,Cushion,107330,15.32,'Very Good',G,SI2,IGI,lab,'2020-11-29 12-26 PM'
119304,9896730,https://www.brilliantearth.com//lab-diamonds-s...,Cushion,110110,10.05,Ideal,D,SI2,IGI,lab,'2020-11-29 12-26 PM'
119305,9756570,https://www.brilliantearth.com//lab-diamonds-s...,Oval,126030,10.33,Fair,D,VS2,IGI,lab,'2020-11-29 12-26 PM'


In [10]:
diamond_prices.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 119307 entries, 0 to 119306
Data columns (total 11 columns):
 #   Column        Non-Null Count   Dtype  
---  ------        --------------   -----  
 0   id            119307 non-null  int64  
 1   url           119307 non-null  object 
 2   shape         119307 non-null  object 
 3   price         119307 non-null  int64  
 4   carat         119307 non-null  float64
 5   cut           119307 non-null  object 
 6   color         119307 non-null  object 
 7   clarity       119307 non-null  object 
 8   report        119307 non-null  object 
 9   type          119307 non-null  object 
 10  date_fetched  119307 non-null  object 
dtypes: float64(1), int64(2), object(8)
memory usage: 10.0+ MB


In [11]:
# Column roles used by the rest of the notebook.
# `target` is a list, so selecting with it yields a one-column DataFrame.
target = ['price']
numeric_features = ['carat']
categorical_features = [
    'shape',
    'cut',
    'color',
    'clarity',
    'report',
    'type',
]

## Exploratory Data Analysis (EDA)

In [12]:
diamond_prices.head()

Unnamed: 0,id,url,shape,price,carat,cut,color,clarity,report,type,date_fetched
0,10086429,https://www.brilliantearth.com//loose-diamonds...,Round,400,0.3,'Very Good',J,SI2,GIA,natural,'2020-11-29 12-26 PM'
1,10016334,https://www.brilliantearth.com//loose-diamonds...,Emerald,400,0.31,Ideal,I,SI1,GIA,natural,'2020-11-29 12-26 PM'
2,9947216,https://www.brilliantearth.com//loose-diamonds...,Emerald,400,0.3,Ideal,I,VS2,GIA,natural,'2020-11-29 12-26 PM'
3,10083437,https://www.brilliantearth.com//loose-diamonds...,Round,400,0.3,Ideal,I,SI2,GIA,natural,'2020-11-29 12-26 PM'
4,9946136,https://www.brilliantearth.com//loose-diamonds...,Emerald,400,0.3,Ideal,I,SI1,GIA,natural,'2020-11-29 12-26 PM'


In [13]:
diamond_prices.loc[:, target].describe()

Unnamed: 0,price
count,119307.0
mean,3286.843
std,9114.695
min,270.0
25%,900.0
50%,1770.0
75%,3490.0
max,1348720.0


In [14]:
diamond_prices.loc[:, numeric_features].describe()

Unnamed: 0,carat
count,119307.0
mean,0.884169
std,0.671141
min,0.25
25%,0.4
50%,0.7
75%,1.1
max,15.32


In [15]:
diamond_prices.loc[:, categorical_features].describe()

Unnamed: 0,shape,cut,color,clarity,report,type
count,119307,119307,119307,119307,119307,119307
unique,10,5,7,8,4,2
top,Round,'Super Ideal',E,VS1,GIA,natural
freq,76080,55244,24730,27259,68782,70313


## Model Building

In [16]:
# Split the frame into the label column(s) and everything else.
y = diamond_prices.loc[:, target]
X = diamond_prices.drop(columns=target)

In [17]:
# Hold out 20% of the rows for testing; the fixed seed keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=20130810,
)

In [18]:
X_train.shape

(95445, 10)

In [19]:
X_test.shape

(23862, 10)

In [20]:
y_train.shape

(95445, 1)

In [21]:
y_test.shape

(23862, 1)

In [23]:
import os
from pathlib import Path

In [24]:
path = Path.cwd()
path

WindowsPath('d:/ANNACONDA/projects/Projects_Challenges/personal_projects/mlops')

In [25]:
path.exists()

True

In [29]:
# Folder under `path` that receives the exported CSV splits; created if missing.
data_folder = path / 'data'
data_folder.mkdir(exist_ok=True, parents=True)

In [30]:
X_train.to_csv(data_folder/'20240525_training_features.csv', index=False)
y_train.to_csv(data_folder/'20240525_training_target.csv', index=False)

In [31]:
numeric_features

['carat']

In [33]:
# Model v1 preprocessing: standardise the numeric feature column(s).
preprocessor = make_column_transformer(
    (StandardScaler(), numeric_features),
)
preprocessor

In [34]:
# Model v1: preprocessing followed by a decision-tree regressor.
# The tree is seeded (same seed as the train/test split) so that refitting on
# Restart & Run All yields the same model; scikit-learn permutes features at
# every split, so an unseeded tree may differ between runs when splits tie.
model_pipeline = make_pipeline(
    preprocessor,
    DecisionTreeRegressor(random_state=20130810),
)
model_pipeline

In [35]:
model_pipeline.fit(X_train, y_train)

In [36]:
X_test.iloc[:1]

Unnamed: 0,id,url,shape,carat,cut,color,clarity,report,type,date_fetched
107681,10060752,https://www.brilliantearth.com//lab-diamonds-s...,Oval,2.0,Ideal,J,SI2,IGI,lab,'2020-11-29 12-26 PM'


In [37]:
model_pipeline.predict(X_test.iloc[:1])

array([7468.89838557])

In [38]:
y_test.iloc[:1]

Unnamed: 0,price
107681,3770


In [44]:
(path/"models").mkdir(exist_ok=True)

In [45]:
joblib.dump(model_pipeline, 'models/models-v1.joblib')

['models/models-v1.joblib']

In [46]:
saved_model = joblib.load('models/models-v1.joblib')
saved_model

In [47]:
saved_model.predict(X_test.iloc[:1])

array([7468.89838557])

In [49]:
import pandas as pd

In [50]:
samples = {'carat':0.02}
pd.DataFrame([samples])

Unnamed: 0,carat
0,0.02


In [51]:
saved_model.predict(pd.DataFrame([samples]))

array([634.87256372])

In [52]:
categorical_features_model2 = ['type']

In [53]:
# Model v2 preprocessing: scale the numeric column(s) and one-hot encode the
# columns in `categorical_features_model2`.
# handle_unknown='ignore' makes the encoder emit all zeros for a category it
# did not see during fit instead of raising at predict time.
preprocessor = make_column_transformer(
    (StandardScaler(), numeric_features),
    (OneHotEncoder(handle_unknown='ignore'), categorical_features_model2)
)

In [54]:
preprocessor

In [55]:
# Model v2: preprocessing followed by a decision-tree regressor.
# The tree is seeded (same seed as the train/test split) so that refitting on
# Restart & Run All yields the same model; scikit-learn permutes features at
# every split, so an unseeded tree may differ between runs when splits tie.
model_pipeline = make_pipeline(
    preprocessor,
    DecisionTreeRegressor(random_state=20130810),
)
model_pipeline

In [62]:
model_pipeline.fit(X_train, y_train)

In [63]:
joblib.dump(model_pipeline, 'models/model-v2.joblib')

['models/model-v2.joblib']

In [64]:
saved_model = joblib.load('models/model-v2.joblib')

In [65]:
saved_model

In [70]:
diamond_prices.type.value_counts()

type
natural    70313
lab        48994
Name: count, dtype: int64

In [71]:
samples = {'carat':0.02, 'type':'Z'}

In [72]:
pd.DataFrame([samples])

Unnamed: 0,carat,type
0,0.02,Z


In [73]:
saved_model.predict(pd.DataFrame([samples]))

array([305.])

In [76]:
saved_model

In [75]:
saved_model[:-1]

In [74]:
saved_model[:-1].transform(pd.DataFrame([samples]))

array([[-1.28466235,  0.        ,  0.        ]])

In [77]:
# Model v3 preprocessing: scale the numeric column(s) and one-hot encode every
# column listed in `categorical_features`.
# handle_unknown='ignore' makes the encoder emit all zeros for a category it
# did not see during fit instead of raising at predict time.
preprocessor = make_column_transformer(
    (StandardScaler(), numeric_features),
    (OneHotEncoder(handle_unknown='ignore'), categorical_features)
)

In [78]:
# Model v3: preprocessing followed by a decision-tree regressor.
# The tree is seeded (same seed as the train/test split) so that refitting on
# Restart & Run All yields the same model; scikit-learn permutes features at
# every split, so an unseeded tree may differ between runs when splits tie.
model_pipeline = make_pipeline(
    preprocessor,
    DecisionTreeRegressor(random_state=20130810),
)

In [79]:
model_pipeline

In [80]:
model_pipeline.fit(X_train, y_train)

In [81]:
joblib.dump(model_pipeline, 'models/model-v3.joblib')

['models/model-v3.joblib']

In [82]:
saved_model = joblib.load('models/model-v3.joblib')
saved_model

In [85]:
# File names (inside the models/ folder) of the pipelines to compare, oldest first.
saved_model_files = [
    'models-v1.joblib',
    'model-v2.joblib',
    'model-v3.joblib',
]

In [87]:
# Score every persisted pipeline on the held-out test split.
# mean_absolute_error is imported by name in the notebook's import cell:
# `import sklearn` alone does not guarantee the `sklearn.metrics` submodule is
# loaded, so `sklearn.metrics.<fn>` only worked via a transitive import.
print("Mean Absolute Error (MAE) on test set")
for saved_model_file in saved_model_files:
    saved_model = joblib.load('models/'+saved_model_file)
    print(saved_model_file)
    print(mean_absolute_error(y_test, saved_model.predict(X_test)))

Mean Absolute Error (MAE) on test set
models-v1.joblib
1373.360447353515
model-v2.joblib
843.5439501254651
model-v3.joblib
295.1115418505655


In [88]:
# Persist the held-out split through `data_folder`, the same directory object
# used for the training CSVs, so both exports land in one place even if the
# kernel's working directory has changed since `data_folder` was created.
X_test.to_csv(data_folder/'20240525_test_features.csv', index=False)
y_test.to_csv(data_folder/'20240525_test_target.csv', index=False)