In [6]:
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import IsolationForest
import matplotlib.pyplot as plt
import seaborn as sns

# Load the cleaned Kaggle wine dataset (the file name indicates prices capped at 50).
# `file_path` is kept as a plain string ending in '/' because file names are
# appended to it by string concatenation.
# NOTE(review): hard-coded absolute path — only resolves on the author's machine.
file_path = '/Users/annelise/Documents/GitHub/Wine_tasting_KG/data_kaggle/'
csv_name = 'cleaned_wine_data_price_max=50.csv'
wine_data = pd.read_csv(f'{file_path}{csv_name}')

# Sanity check: (rows, columns) of the loaded frame
print(wine_data.shape)

# Segment the data into $10 price categories: (0, 10], (10, 20], ..., (40, 50].
# pd.cut bins are right-inclusive by default, so a price of exactly 0, above 50,
# or missing falls outside every bin and gets a NaN category.
price_bins = [0, 10, 20, 30, 40, 50]
# Labels derived from the edges so the two lists cannot drift apart: '0-10', '10-20', ...
price_labels = [f'{lo}-{hi}' for lo, hi in zip(price_bins[:-1], price_bins[1:])]
wine_data['price_category'] = pd.cut(wine_data['price'], bins=price_bins, labels=price_labels)

# Z-score of each wine's rating relative to the other wines in its price category,
# using pandas' default sample standard deviation (ddof=1).
# observed=True groups only categories that actually occur. This matches the future
# pandas default and silences the FutureWarning for categorical groupers; per-row
# transform results are unchanged. Rows with a NaN category belong to no group and
# receive a NaN z-score.
wine_data['rating_zscore'] = (
    wine_data.groupby('price_category', observed=True)['points']
    .transform(lambda x: (x - x.mean()) / x.std())
)

# Fit a one-feature linear regression of rating ('points') on 'price'.
X = wine_data[['price']]  # double brackets keep a 2-D feature matrix for scikit-learn
y = wine_data['points']
model = LinearRegression().fit(X, y)  # fit() returns the fitted estimator itself

# Residual = actual rating - predicted rating. A positive value means the wine
# scored higher than its price alone would predict.
predicted = model.predict(X)
wine_data['predicted_points'] = predicted
wine_data['residuals'] = y - predicted

# Isolation Forest over the (price, points) pairs. fit_predict labels every row
# +1 (inlier) or -1 (outlier). contamination=0.05 is the expected share of outliers,
# and random_state fixes the result across runs.
isolation_forest = IsolationForest(contamination=0.05, random_state=42)
wine_data['anomaly'] = isolation_forest.fit_predict(wine_data[['price', 'points']])

# Keep a wine only if all three conditions hold:
#   - it beats its price bucket (rating z-score > 1),
#   - it beats the price-based regression by more than one point (residual > 1),
#   - Isolation Forest considers it an inlier (anomaly == 1).
# NOTE(review): the last condition discards the flagged anomalies rather than
# selecting them — confirm that is the intended use of the forest.
beats_category = wine_data['rating_zscore'] > 1
beats_regression = wine_data['residuals'] > 1
is_inlier = wine_data['anomaly'] == 1
interesting_wines = wine_data[beats_category & beats_regression & is_inlier]

# Report the (rows, columns) size of the filtered set
print(f"Interesting wines: {interesting_wines.shape}")

def _label_and_show(title, xlabel, ylabel, xtick_rotation=None):
    """Title and label the current matplotlib figure, optionally rotate x ticks, then show it."""
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    if xtick_rotation is not None:
        plt.xticks(rotation=xtick_rotation)
    plt.show()


# Plot 1: box plot of the rating spread inside each price category (interesting wines only)
plt.figure(figsize=(12, 6))
sns.boxplot(data=interesting_wines, x='price_category', y='points')
_label_and_show(
    'Distribution of Wine Ratings by Price Category (Interesting Wines)',
    xlabel='Price Category',
    ylabel='Wine Rating',
    xtick_rotation=45,
)

# Plot 2: residual (rating above the price-based prediction) against price,
# coloured by price category
plt.figure(figsize=(12, 6))
sns.scatterplot(
    data=interesting_wines,
    x='price',
    y='residuals',
    hue='price_category',
    palette='tab10',
)
_label_and_show(
    'Residuals (Overperformance) vs Price by Price Category',
    xlabel='Price ($)',
    ylabel='Residuals (Rating - Predicted)',
)

# Optionally, save the filtered data into the same directory as the source CSV.
# The full path must be a single argument: to_csv's second positional parameter is
# `sep`, so passing the file name there makes pandas reject it as a delimiter, and
# `file_path` on its own is a directory, not a file.
interesting_wines.to_csv(file_path + 'interesting_wines.csv', index=False)


(73637, 9)
Interesting wines: (10116, 14)


NameError: name 'plt' is not defined