In [None]:
!pip install plotly pandas numpy scikit-learn matplotlib seaborn wandb

In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import seaborn as sns
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from sklearn.model_selection import train_test_split
import wandb
import random
import plotly.express as px

In [None]:
# Authenticate with Weights & Biases without embedding a secret in the notebook.
# wandb.login() uses the WANDB_API_KEY environment variable or cached credentials
# when available and otherwise prompts for a key interactively.
# NOTE: the API key that was previously hard-coded here is exposed and should be revoked.
wandb.login()

In [None]:
# Reference friction dataset, written one material pairing per line in the order
#   (Material_1, E_1, nu_1, rho_1, Material_2, E_2, nu_2, rho_2, mu_static)
# and transposed below into the column-wise dict used to build df_reference.
# NOTE(review): units are not stated; the magnitudes suggest E in GPa and rho in
# kg/m^3 — confirm against the data source.
_reference_pairs = [
    ('Steel', 200, 0.25, 8000, 'Steel', 200, 0.25, 8000, 0.78),
    ('Copper', 133, 0.35, 8940, 'Copper', 133, 0.35, 8940, 1.21),
    ('Aluminum', 69, 0.33, 2700, 'Aluminum', 69, 0.33, 2700, 1.1),
    ('Aluminum', 69, 0.33, 2700, 'Steel', 200, 0.25, 8000, 0.61),
    ('Brass', 115, 0.34, 8730, 'Cast Iron', 170, 0.26, 7200, 0.4),
    ('Cadmium', 64, 0.31, 8650, 'Cadmium', 64, 0.31, 8650, 0.5),
    ('Cadmium', 64, 0.31, 8650, 'Steel', 200, 0.25, 8000, 0.6),
    ('Cast Iron', 170, 0.26, 7200, 'Cast Iron', 170, 0.26, 7200, 1.1),
    ('Chromium', 248, 0.31, 7190, 'Chromium', 248, 0.31, 7190, 0.41),
    ('Copper', 133, 0.35, 8940, 'Cast Iron', 170, 0.26, 7200, 1.05),
    ('Copper', 133, 0.35, 8940, 'Copper', 133, 0.35, 8940, 1.0),
    ('Copper', 133, 0.35, 8940, 'Steel', 200, 0.25, 8000, 0.53),
    ('Glass', 60, 0.25, 2400, 'Glass', 60, 0.25, 2400, 0.95),
    ('Glass', 60, 0.25, 2400, 'Steel', 200, 0.25, 8000, 0.6),
    ('Glass', 60, 0.25, 2400, 'Nickel', 170, 0.31, 8900, 0.78),
    ('Graphite', 20, 0.2, 2050, 'Graphite', 20, 0.2, 2050, 0.1),
    ('Graphite', 20, 0.2, 2050, 'Steel', 200, 0.25, 8000, 0.1),
    ('Nickel', 170, 0.31, 8900, 'Nickel', 170, 0.31, 8900, 0.8),
    ('Nickel', 170, 0.31, 8900, 'Steel', 200, 0.25, 8000, 0.7),
    ('Nylon', 4, 0.39, 1130, 'Nylon', 4, 0.39, 1130, 0.2),
    ('Plexiglass', 3.3, 0.37, 1190, 'Plexiglass', 3.3, 0.37, 1190, 0.8),
    ('Plexiglass', 3.3, 0.37, 1190, 'Steel', 200, 0.25, 8000, 0.45),
    ('Polystyrene', 2.5, 0.4, 1040, 'Polystyrene', 2.5, 0.4, 1040, 0.5),
    ('Polystyrene', 2.5, 0.4, 1040, 'Steel', 200, 0.25, 8000, 0.3),
    ('Rubber', 0.01, 0.47, 2300, 'Asphalt', 3, 0.35, 2500, 0.6),
    ('Rubber', 0.01, 0.47, 2300, 'Concrete', 17, 0.17, 2400, 0.9),
    ('Brass', 115, 0.34, 8730, 'Steel', 200, 0.25, 8000, 0.35),
    ('Steel', 200, 0.25, 8000, 'Cast Iron', 170, 0.26, 7200, 0.4),
    ('Teflon', 0.5, 0.47, 2200, 'Steel', 200, 0.25, 8000, 0.04),
    ('Teflon', 0.5, 0.47, 2200, 'Teflon', 0.5, 0.47, 2200, 0.04),
    ('Wood', 10, 0.35, 750, 'Wood', 10, 0.35, 750, 0.3),
    ('Wood', 10, 0.35, 750, 'Copper', 133, 0.35, 8940, 0.4),
    ('Wood', 10, 0.35, 750, 'Steel', 200, 0.25, 8000, 0.45),
    ('Wood', 10, 0.35, 750, 'Concrete', 17, 0.17, 2400, 0.62),
    ('Zinc', 82.7, 0.25, 7120, 'Zinc', 82.7, 0.25, 7120, 0.6),
    ('Zinc', 82.7, 0.25, 7120, 'Cast Iron', 170, 0.26, 7200, 0.85),
]

# Transpose rows -> columns: each key maps to a list holding that field for every pairing.
data = {
    column: list(values)
    for column, values in zip(
        ('Material_1', 'E_1', 'nu_1', 'rho_1', 'Material_2', 'E_2', 'nu_2', 'rho_2', 'mu_static'),
        zip(*_reference_pairs),
    )
}


# Incline-plane measurements, stored column-wise: entry i of every list describes
# material pairing i, so all lists must stay index-aligned. Values are laid out
# six per line; the trailing comments name the materials on each line.
# Fix: E_1[11] is the second Cobalt row (nu_1/rho_1 hold Cobalt values there) but
# previously carried Chromium's 279.0; it now uses Cobalt's 211.0 like E_1[10].
# NOTE(review): units are not stated; the magnitudes suggest E in GPa and rho in
# kg/m^3 — confirm against the data source.
# NOTE(review): some materials have different E on the two sides of a pairing
# (Silver 83.4 vs 83.0, Cadmium 28.0 vs 50.0, Cobalt 211.0 vs 209.0,
# Indium 82.0 vs 11.0, Manganese 130.0 vs 200.0) — confirm which is intended.
data2_incline_plane = {
    "Material_1": [
        "Silver", "Silver", "Silver", "Silver", "Aluminium", "Aluminium",
        "Gold", "Gold", "Cadmium", "Cadmium", "Cobalt", "Cobalt",
        "Chromium", "Chromium", "Copper", "Copper", "Copper", "Copper",
        "Copper", "Copper", "Iron", "Iron", "Iron", "Iron",
        "Iron", "Iron", "Iron", "Iron", "Indium", "Manganese",
        "Molybdenum", "Molybdenum", "Niobium", "Nickel", "Nickel", "Nickel",
        "Lead", "Lead", "Lead", "Lead", "Lead", "Lead",
        "Platinum", "Platinum", "Tin", "Tin", "Titanium", "Titanium",
        "Tungsten", "Tungsten", "Tungsten", "Zinc", "Zinc", "Zinc",
    ],
    "E_1": [
        83.4, 83.4, 83.4, 83.4, 69.0, 69.0,        # Silver x4, Aluminium x2
        79.0, 79.0, 28.0, 28.0, 211.0, 211.0,      # Gold x2, Cadmium x2, Cobalt x2
        279.0, 279.0, 110.0, 110.0, 110.0, 110.0,  # Chromium x2, Copper x4
        110.0, 110.0, 211.0, 211.0, 211.0, 211.0,  # Copper x2, Iron x4
        211.0, 211.0, 211.0, 211.0, 82.0, 130.0,   # Iron x4, Indium, Manganese
        329.0, 329.0, 105.0, 207.0, 207.0, 207.0,  # Molybdenum x2, Niobium, Nickel x3
        16.0, 16.0, 16.0, 16.0, 16.0, 16.0,        # Lead x6
        168.0, 168.0, 50.0, 50.0, 116.0, 116.0,    # Platinum x2, Tin x2, Titanium x2
        411.0, 411.0, 411.0, 108.0, 108.0, 108.0,  # Tungsten x3, Zinc x3
    ],
    "nu_1": [
        0.37, 0.37, 0.37, 0.37, 0.33, 0.33,  # Silver x4, Aluminium x2
        0.42, 0.42, 0.30, 0.30, 0.31, 0.31,  # Gold x2, Cadmium x2, Cobalt x2
        0.21, 0.21, 0.34, 0.34, 0.34, 0.34,  # Chromium x2, Copper x4
        0.34, 0.34, 0.29, 0.29, 0.29, 0.29,  # Copper x2, Iron x4
        0.29, 0.29, 0.29, 0.29, 0.45, 0.21,  # Iron x4, Indium, Manganese
        0.31, 0.31, 0.40, 0.31, 0.31, 0.31,  # Molybdenum x2, Niobium, Nickel x3
        0.44, 0.44, 0.44, 0.44, 0.44, 0.44,  # Lead x6
        0.38, 0.38, 0.36, 0.36, 0.34, 0.34,  # Platinum x2, Tin x2, Titanium x2
        0.28, 0.28, 0.28, 0.25, 0.25, 0.25,  # Tungsten x3, Zinc x3
    ],
    "rho_1": [
        10490, 10490, 10490, 10490, 2700, 2700,   # Silver x4, Aluminium x2
        19320, 19320, 8650, 8650, 8900, 8900,     # Gold x2, Cadmium x2, Cobalt x2
        7190, 7190, 8960, 8960, 8960, 8960,       # Chromium x2, Copper x4
        8960, 8960, 7870, 7870, 7870, 7870,       # Copper x2, Iron x4
        7870, 7870, 7870, 7870, 7310, 7440,       # Iron x4, Indium, Manganese
        10220, 10220, 8570, 8908, 8908, 8908,     # Molybdenum x2, Niobium, Nickel x3
        11340, 11340, 11340, 11340, 11340, 11340, # Lead x6
        21450, 21450, 7310, 7310, 4507, 4507,     # Platinum x2, Tin x2, Titanium x2
        19300, 19300, 19300, 7140, 7140, 7140,    # Tungsten x3, Zinc x3
    ],
    "Material_2": [
        "Silver", "Gold", "Copper", "Iron", "Aluminium", "Titanium",
        "Silver", "Gold", "Cadmium", "Iron", "Cobalt", "Chromium",
        "Cobalt", "Chromium", "Cobalt", "Chromium", "Copper", "Iron",
        "Nickel", "Zinc", "Cobalt", "Chromium", "Iron", "Manganese",
        "Molybdenum", "Titanium", "Tungsten", "Zinc", "Indium", "Manganese",
        "Iron", "Molybdenum", "Niobium", "Chromium", "Nickel", "Platinum",
        "Silver", "Gold", "Copper", "Chromium", "Iron", "Lead",
        "Nickel", "Platinum", "Iron", "Tin", "Aluminium", "Titanium",
        "Copper", "Iron", "Tungsten", "Copper", "Iron", "Zinc",
    ],
    "E_2": [
        83.0, 79.0, 110.0, 211.0, 69.0, 116.0,     # Silver, Gold, Copper, Iron, Aluminium, Titanium
        83.0, 79.0, 50.0, 211.0, 209.0, 279.0,     # Silver, Gold, Cadmium, Iron, Cobalt, Chromium
        209.0, 279.0, 209.0, 279.0, 110.0, 211.0,  # Cobalt, Chromium, Cobalt, Chromium, Copper, Iron
        207.0, 108.0, 209.0, 279.0, 211.0, 200.0,  # Nickel, Zinc, Cobalt, Chromium, Iron, Manganese
        329.0, 116.0, 411.0, 108.0, 11.0, 200.0,   # Molybdenum, Titanium, Tungsten, Zinc, Indium, Manganese
        211.0, 329.0, 105.0, 279.0, 207.0, 168.0,  # Iron, Molybdenum, Niobium, Chromium, Nickel, Platinum
        83.0, 79.0, 110.0, 279.0, 211.0, 16.0,     # Silver, Gold, Copper, Chromium, Iron, Lead
        207.0, 168.0, 211.0, 50.0, 69.0, 116.0,    # Nickel, Platinum, Iron, Tin, Aluminium, Titanium
        110.0, 211.0, 411.0, 110.0, 211.0, 108.0,  # Copper, Iron, Tungsten, Copper, Iron, Zinc
    ],
    "nu_2": [
        0.37, 0.42, 0.34, 0.29, 0.33, 0.34,  # pairings 0-5
        0.37, 0.42, 0.30, 0.29, 0.31, 0.21,  # pairings 6-11
        0.31, 0.21, 0.31, 0.21, 0.34, 0.29,  # pairings 12-17
        0.31, 0.25, 0.31, 0.21, 0.29, 0.21,  # pairings 18-23
        0.31, 0.34, 0.28, 0.25, 0.45, 0.21,  # pairings 24-29
        0.29, 0.31, 0.40, 0.21, 0.31, 0.38,  # pairings 30-35
        0.37, 0.42, 0.34, 0.21, 0.29, 0.44,  # pairings 36-41
        0.31, 0.38, 0.29, 0.36, 0.33, 0.34,  # pairings 42-47
        0.34, 0.29, 0.28, 0.34, 0.29, 0.25,  # pairings 48-53
    ],
    "rho_2": [
        10490, 19320, 8960, 7870, 2700, 4507,    # pairings 0-5
        10490, 19320, 8650, 7870, 8900, 7190,    # pairings 6-11
        8900, 7190, 8900, 7190, 8960, 7870,      # pairings 12-17
        8908, 7140, 8900, 7190, 7870, 7440,      # pairings 18-23
        10220, 4507, 19300, 7140, 7310, 7440,    # pairings 24-29
        7870, 10220, 8570, 7190, 8908, 21450,    # pairings 30-35
        10490, 19320, 8960, 7190, 7870, 11340,   # pairings 36-41
        8908, 21450, 7870, 7310, 2700, 4507,     # pairings 42-47
        8960, 7870, 19300, 8960, 7870, 7140,     # pairings 48-53
    ],
    "mu_static": [
        0.5, 0.53, 0.48, 0.49, 0.57, 0.54,   # pairings 0-5
        0.53, 0.49, 0.79, 0.52, 0.56, 0.41,  # pairings 6-11
        0.41, 0.46, 0.44, 0.46, 0.55, 0.50,  # pairings 12-17
        0.49, 0.56, 0.41, 0.48, 0.51, 0.51,  # pairings 18-23
        0.46, 0.49, 0.47, 0.55, 1.46, 0.69,  # pairings 24-29
        0.46, 0.44, 0.46, 0.59, 0.50, 0.64,  # pairings 30-35
        0.73, 0.61, 0.55, 0.53, 0.54, 0.90,  # pairings 36-41
        0.64, 0.55, 0.55, 0.74, 0.54, 0.55,  # pairings 42-47
        0.41, 0.47, 0.51, 0.56, 0.55, 0.75,  # pairings 48-53
    ],
}

# Turn each column-wise dict into its own DataFrame, then stack them row-wise into
# one modelling table with a fresh 0..n-1 index.
df_reference, df_incline_plane = map(pd.DataFrame, (data, data2_incline_plane))
df_combined = pd.concat((df_reference, df_incline_plane), axis=0, ignore_index=True)

In [None]:
# Start the W&B run. The config records the model hyperparameters (read back from
# wandb.config when the forest is built) plus a short description of each source
# dataset.
_dataset_feature_names = ['E_1', 'E_2', 'nu_1', 'nu_2', 'rho_1', 'rho_2']

run_config = dict(
    model="RandomForestRegressor",
    n_estimators=100,
    max_depth=10,
    random_state=42,
    feature_set="all_features",
    dataset={
        "reference": {
            "num_samples": len(df_reference['mu_static']),
            "features": list(_dataset_feature_names),
            "description": "Reference dataset with material properties and friction coefficients.",
        },
        "incline_plane": {
            "num_samples": len(df_incline_plane['mu_static']),
            "features": list(_dataset_feature_names),
            "description": "Incline plane dataset with additional material interactions.",
        },
    },
)

wandb.init(project="FrictionAI", config=run_config)

In [None]:
# Engineered features for each material pairing: side-1 minus side-2 differences,
# side-1 over side-2 ratios, and per-side products of property pairs. Columns are
# added to df_combined in the order listed.
_engineered = {
    'E_diff': df_combined['E_1'] - df_combined['E_2'],
    'nu_diff': df_combined['nu_1'] - df_combined['nu_2'],
    'rho_diff': df_combined['rho_1'] - df_combined['rho_2'],
    'E_ratio': df_combined['E_1'] / df_combined['E_2'],
    'nu_ratio': df_combined['nu_1'] / df_combined['nu_2'],
    'rho_ratio': df_combined['rho_1'] / df_combined['rho_2'],
    # Interaction features, material 1
    'E_nu_interaction_1': df_combined['E_1'] * df_combined['nu_1'],
    'E_rho_interaction_1': df_combined['E_1'] * df_combined['rho_1'],
    'nu_rho_interaction_1': df_combined['nu_1'] * df_combined['rho_1'],
    # Interaction features, material 2
    'E_nu_interaction_2': df_combined['E_2'] * df_combined['nu_2'],
    'E_rho_interaction_2': df_combined['E_2'] * df_combined['rho_2'],
    'nu_rho_interaction_2': df_combined['nu_2'] * df_combined['rho_2'],
}
for _column_name, _column_values in _engineered.items():
    df_combined[_column_name] = _column_values

In [None]:
# Model inputs: the six raw material properties, the pairwise difference/ratio
# columns, then the per-side interaction columns. Target: static friction coefficient.
_raw_properties = ['E_1', 'E_2', 'nu_1', 'nu_2', 'rho_1', 'rho_2']
_pairwise = ['E_diff', 'nu_diff', 'rho_diff', 'E_ratio', 'nu_ratio', 'rho_ratio']
_interactions = [
    'E_nu_interaction_1', 'E_rho_interaction_1', 'nu_rho_interaction_1',
    'E_nu_interaction_2', 'E_rho_interaction_2', 'nu_rho_interaction_2',
]
feature_columns = _raw_properties + _pairwise + _interactions

X, y = df_combined[feature_columns], df_combined['mu_static']


In [None]:
# Hold out a test split, then standardize the features.
# The split was previously missing, so X_train / X_test / y_train / y_test (used by
# the training, evaluation and plotting cells) were never defined. Splitting BEFORE
# fitting the scaler keeps test-set statistics out of the training transform.
TEST_SIZE = 0.2  # fraction of rows held out for evaluation

X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=wandb.config.random_state
)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)  # mean/std estimated from training rows only
X_test = scaler.transform(X_test_raw)

# Full feature matrix under the training-fitted scaler; keeps the `X_scaled` name defined.
X_scaled = scaler.transform(X)

# Record the split size alongside the other run settings.
wandb.config.update({"test_size": TEST_SIZE}, allow_val_change=True)

In [None]:
# Build the forest from the hyperparameters stored in the W&B run config, so the
# logged run and the fitted model use the same settings.
cfg = wandb.config
rf_model = RandomForestRegressor(
    n_estimators=cfg.n_estimators,
    max_depth=cfg.max_depth,
    random_state=cfg.random_state,
)

# Fit on the training split; fit() returns the estimator, which is the cell's output.
rf_model.fit(X_train, y_train)


In [None]:
# Score the fitted forest on the held-out rows, log the metrics to W&B under the
# model name, and print a short summary.
y_pred = rf_model.predict(X_test)

test_metrics = {
    "RMSE": np.sqrt(mean_squared_error(y_test, y_pred)),
    "MAE": mean_absolute_error(y_test, y_pred),
    "R2_Score": r2_score(y_test, y_pred),
}
rmse, mae, r2 = (test_metrics[key] for key in ("RMSE", "MAE", "R2_Score"))

wandb.log({"RandomForestRegressor": test_metrics})

print("RandomForestRegressor Performance:")
print(f"RMSE: {rmse:.4f}")
print(f"MAE: {mae:.4f}")
print(f"R2 Score: {r2:.4f}")

In [None]:
# Rank features by the forest's feature_importances_, highest first
# (`indices` and `feature_names` are reused by the next cell).
importances = rf_model.feature_importances_
indices = np.argsort(importances)[::-1]
feature_names = X.columns

# Draw on an explicit Figure so the same object can be shown and logged.
fig_importances, ax_importances = plt.subplots(figsize=(12, 8))
sns.barplot(x=importances[indices], y=feature_names[indices], palette='viridis', ax=ax_importances)
ax_importances.set_title('Feature Importances from RandomForestRegressor')
ax_importances.set_xlabel('Importance')
ax_importances.set_ylabel('Feature')
fig_importances.tight_layout()

# Log BEFORE plt.show(): with Jupyter's default inline backend, show() closes all
# figures, so logging `plt` afterwards (as before) captured an empty canvas.
# wandb.Image renders the Figure to a static image.
wandb.log({"Feature Importances": wandb.Image(fig_importances)})
plt.show()


In [None]:
# Interactive 3-D view: the two most important features on x/y, actual friction on
# z, coloured by the prediction. `px` comes from the imports cell.
top_features = feature_names[indices[:2]]

# np.asarray makes the positional column indexing work whether X_test is a NumPy
# array or a DataFrame, and drops y_test's pandas index so all columns line up by
# position rather than by label.
X_test_values = np.asarray(X_test)
df_vis = pd.DataFrame({
    top_features[0]: X_test_values[:, indices[0]],
    top_features[1]: X_test_values[:, indices[1]],
    'Actual mu_static': np.asarray(y_test),
    'Predicted mu_static': y_pred,
})

fig = px.scatter_3d(
    df_vis,
    x=top_features[0],
    y=top_features[1],
    z='Actual mu_static',
    color='Predicted mu_static',
    title='Actual vs Predicted Coefficient of Friction',
    labels={
        top_features[0]: top_features[0],
        top_features[1]: top_features[1],
        'Actual mu_static': 'Actual Coefficient of Friction',
        'Predicted mu_static': 'Predicted Coefficient of Friction'
    },
    color_continuous_scale='Viridis'
)

fig.show()

# Log the interactive Plotly figure to W&B.
wandb.log({"Actual vs Predicted mu_static": fig})


In [None]:
# Score every incline-plane pairing with the trained model and log the results.
# NOTE(review): df_incline_plane was concatenated into df_combined, the table the
# model's training rows are drawn from, so these rows are NOT held out and the
# errors below are optimistic rather than a generalisation estimate.
df_new = df_incline_plane.copy()

# Feature engineering — must mirror the features built on df_combined.
df_new['E_diff'] = df_new['E_1'] - df_new['E_2']
df_new['nu_diff'] = df_new['nu_1'] - df_new['nu_2']
df_new['rho_diff'] = df_new['rho_1'] - df_new['rho_2']
df_new['E_ratio'] = df_new['E_1'] / df_new['E_2']
df_new['nu_ratio'] = df_new['nu_1'] / df_new['nu_2']
df_new['rho_ratio'] = df_new['rho_1'] / df_new['rho_2']

# Interaction features
df_new['E_nu_interaction_1'] = df_new['E_1'] * df_new['nu_1']
df_new['E_rho_interaction_1'] = df_new['E_1'] * df_new['rho_1']
df_new['nu_rho_interaction_1'] = df_new['nu_1'] * df_new['rho_1']

df_new['E_nu_interaction_2'] = df_new['E_2'] * df_new['nu_2']
df_new['E_rho_interaction_2'] = df_new['E_2'] * df_new['rho_2']
df_new['nu_rho_interaction_2'] = df_new['nu_2'] * df_new['rho_2']

# Same column order and the same fitted scaler as used for training.
X_new = df_new[feature_columns]
X_new_scaled = scaler.transform(X_new)

mu_pred_new = rf_model.predict(X_new_scaled)
df_new['mu_pred'] = mu_pred_new
df_new['error'] = (df_new['mu_pred'] - df_new['mu_static']).abs()

# Log all pairings as ONE W&B table plus a summary MAE. Calling wandb.log once per
# row (as before) advanced the run's step counter for every pairing, turning
# unrelated pairings into a fake time series, and omitted Material_2.
results_columns = ['Material_1', 'Material_2', 'mu_static', 'mu_pred', 'error']
wandb.log({
    "incline_plane_predictions": wandb.Table(dataframe=df_new[results_columns]),
    "incline_plane_MAE": df_new['error'].mean(),
})

# Display a few predictions
df_new[['Material_1', 'mu_static', 'mu_pred', 'error']].head()


In [None]:
# Finish the W&B run: marks the active run complete and uploads any data still
# pending, so this should stay the last cell after all logging is done.
wandb.finish()
