In [None]:
!pip install torch torchvision torchaudio
!pip install plotly pandas numpy scikit-learn matplotlib seaborn wandb

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import wandb


In [None]:
# Reference dataset: one row per contact pair (Material_1 on Material_2) with each
# body's Young's modulus E, Poisson's ratio nu and density rho, plus the static
# friction coefficient mu_static (values as provided by the user).
# NOTE(review): units are not stated; E looks like GPa and rho like kg/m^3 — confirm.
# NOTE(review): this dataset spells 'Aluminum' while data2_incline_plane spells
# 'Aluminium'; unless normalised, the label encoders treat them as two materials.
data = {
    'Material_1': [
        'Steel', 'Copper', 'Aluminum', 'Aluminum', 'Brass', 'Cadmium', 'Cadmium',
        'Cast Iron', 'Chromium', 'Copper', 'Copper', 'Copper', 'Glass', 'Glass',
        'Glass', 'Graphite', 'Graphite', 'Nickel', 'Nickel', 'Nylon', 'Plexiglass',
        'Plexiglass', 'Polystyrene', 'Polystyrene', 'Rubber', 'Rubber', 'Brass',
        'Steel', 'Teflon', 'Teflon', 'Wood', 'Wood', 'Wood', 'Wood', 'Zinc', 'Zinc'
    ],
    'E_1': [
        200, 133, 69, 69, 115, 64, 64, 170, 248, 133, 133, 133, 60, 60, 60, 20,
        20, 170, 170, 4, 3.3, 3.3, 2.5, 2.5, 0.01, 0.01, 115, 200, 0.5, 0.5, 10,
        10, 10, 10, 82.7, 82.7
    ],
    'nu_1': [
        0.25, 0.35, 0.33, 0.33, 0.34, 0.31, 0.31, 0.26, 0.31, 0.35, 0.35, 0.35,
        0.25, 0.25, 0.25, 0.2, 0.2, 0.31, 0.31, 0.39, 0.37, 0.37, 0.4, 0.4, 0.47,
        0.47, 0.34, 0.25, 0.47, 0.47, 0.35, 0.35, 0.35, 0.35, 0.25, 0.25
    ],
    'rho_1': [
        8000, 8940, 2700, 2700, 8730, 8650, 8650, 7200, 7190, 8940, 8940, 8940,
        2400, 2400, 2400, 2050, 2050, 8900, 8900, 1130, 1190, 1190, 1040, 1040,
        2300, 2300, 8730, 8000, 2200, 2200, 750, 750, 750, 750, 7120, 7120
    ],
    'Material_2': [
        'Steel', 'Copper', 'Aluminum', 'Steel', 'Cast Iron', 'Cadmium', 'Steel',
        'Cast Iron', 'Chromium', 'Cast Iron', 'Copper', 'Steel', 'Glass', 'Steel',
        'Nickel', 'Graphite', 'Steel', 'Nickel', 'Steel', 'Nylon', 'Plexiglass',
        'Steel', 'Polystyrene', 'Steel', 'Asphalt', 'Concrete', 'Steel', 'Cast Iron',
        'Steel', 'Teflon', 'Wood', 'Copper', 'Steel', 'Concrete', 'Zinc', 'Cast Iron'
    ],
    'E_2': [
        200, 133, 69, 200, 170, 64, 200, 170, 248, 170, 133, 200, 60, 200, 170, 20,
        200, 170, 200, 4, 3.3, 200, 2.5, 200, 3, 17, 200, 170, 200, 0.5, 10, 133,
        200, 17, 82.7, 170
    ],
    'nu_2': [
        0.25, 0.35, 0.33, 0.25, 0.26, 0.31, 0.25, 0.26, 0.31, 0.26, 0.35, 0.25,
        0.25, 0.25, 0.31, 0.2, 0.25, 0.31, 0.25, 0.39, 0.37, 0.25, 0.4, 0.25, 0.35,
        0.17, 0.25, 0.26, 0.25, 0.47, 0.35, 0.35, 0.25, 0.17, 0.25, 0.26
    ],
    'rho_2': [
        8000, 8940, 2700, 8000, 7200, 8650, 8000, 7200, 7190, 7200, 8940, 8000,
        2400, 8000, 8900, 2050, 8000, 8900, 8000, 1130, 1190, 8000, 1040, 8000,
        2500, 2400, 8000, 7200, 8000, 2200, 750, 8940, 8000, 2400, 7120, 7200
    ],
    # FIXME(review): mu_static has 54 entries while every other column has 36, so
    # pd.DataFrame(data) raises ValueError (arrays must all be the same length).
    # These 54 values are identical to data2_incline_plane['mu_static'] and look
    # copy-pasted; replace them with the 36 reference coefficients for these pairs.
    'mu_static': [
        0.5,
        0.53,
        0.48,
        0.49,
        0.57,
        0.54,
        0.53,
        0.49,
        0.79,
        0.52,
        0.56,
        0.41,
        0.41,
        0.46,
        0.44,
        0.46,
        0.55,
        0.50,
        0.49,
        0.56,
        0.41,
        0.48,
        0.51,
        0.51,
        0.46,
        0.49,
        0.47,
        0.55,
        1.46,
        0.69,
        0.46,
        0.44,
        0.46,
        0.59,
        0.50,
        0.64,
        0.73,
        0.61,
        0.55,
        0.53,
        0.54,
        0.90,
        0.64,
        0.55,
        0.55,
        0.74,
        0.54,
        0.55,
        0.41,
        0.47,
        0.51,
        0.56,
        0.55,
        0.75
    ]
}

# Incline-plane dataset, built from two tables instead of nine hand-aligned lists:
#   * _INCLINE_MATERIAL_PROPS: one (E [GPa?], nu, rho [kg/m^3?]) tuple per material.
#     NOTE(review): units are not stated in the source data — confirm.
#   * _INCLINE_PLANE_ROWS: one (Material_1, Material_2, mu_static) tuple per measurement.
# Looking properties up by name guarantees that a material carries the same E/nu/rho
# on either side of the pair and that each row's properties belong to its materials.
# Where the hand-written E_1 and E_2 lists disagreed (Silver 83.4/83, Cadmium 28/50,
# Cobalt 211/209, Indium 82/11, Manganese 130/200), the E_2 value is used; it agrees
# with commonly tabulated handbook moduli — TODO confirm against the original source.
_INCLINE_MATERIAL_PROPS = {
    #               E       nu     rho
    "Silver":     (83.0,  0.37, 10490),
    "Aluminium":  (69.0,  0.33,  2700),
    "Gold":       (79.0,  0.42, 19320),
    "Cadmium":    (50.0,  0.30,  8650),
    "Cobalt":     (209.0, 0.31,  8900),
    "Chromium":   (279.0, 0.21,  7190),
    "Copper":     (110.0, 0.34,  8960),
    "Iron":       (211.0, 0.29,  7870),
    "Indium":     (11.0,  0.45,  7310),
    "Manganese":  (200.0, 0.21,  7440),
    "Molybdenum": (329.0, 0.31, 10220),
    "Niobium":    (105.0, 0.40,  8570),
    "Nickel":     (207.0, 0.31,  8908),
    "Lead":       (16.0,  0.44, 11340),
    "Platinum":   (168.0, 0.38, 21450),
    "Tin":        (50.0,  0.36,  7310),
    "Titanium":   (116.0, 0.34,  4507),
    "Tungsten":   (411.0, 0.28, 19300),
    "Zinc":       (108.0, 0.25,  7140),
}

_INCLINE_PLANE_ROWS = [
    # (Material_1, Material_2, mu_static)
    ("Silver", "Silver", 0.50),
    ("Silver", "Gold", 0.53),
    ("Silver", "Copper", 0.48),
    ("Silver", "Iron", 0.49),
    ("Aluminium", "Aluminium", 0.57),
    ("Aluminium", "Titanium", 0.54),
    ("Gold", "Silver", 0.53),
    ("Gold", "Gold", 0.49),
    ("Cadmium", "Cadmium", 0.79),
    ("Cadmium", "Iron", 0.52),
    ("Cobalt", "Cobalt", 0.56),
    ("Cobalt", "Chromium", 0.41),
    ("Chromium", "Cobalt", 0.41),
    ("Chromium", "Chromium", 0.46),
    ("Copper", "Cobalt", 0.44),
    ("Copper", "Chromium", 0.46),
    ("Copper", "Copper", 0.55),
    ("Copper", "Iron", 0.50),
    ("Copper", "Nickel", 0.49),
    ("Copper", "Zinc", 0.56),
    ("Iron", "Cobalt", 0.41),
    ("Iron", "Chromium", 0.48),
    ("Iron", "Iron", 0.51),
    ("Iron", "Manganese", 0.51),
    ("Iron", "Molybdenum", 0.46),
    ("Iron", "Titanium", 0.49),
    ("Iron", "Tungsten", 0.47),
    ("Iron", "Zinc", 0.55),
    ("Indium", "Indium", 1.46),
    ("Manganese", "Manganese", 0.69),
    ("Molybdenum", "Iron", 0.46),
    ("Molybdenum", "Molybdenum", 0.44),
    ("Niobium", "Niobium", 0.46),
    ("Nickel", "Chromium", 0.59),
    ("Nickel", "Nickel", 0.50),
    ("Nickel", "Platinum", 0.64),
    ("Lead", "Silver", 0.73),
    ("Lead", "Gold", 0.61),
    ("Lead", "Copper", 0.55),
    ("Lead", "Chromium", 0.53),
    ("Lead", "Iron", 0.54),
    ("Lead", "Lead", 0.90),
    ("Platinum", "Nickel", 0.64),
    ("Platinum", "Platinum", 0.55),
    ("Tin", "Iron", 0.55),
    ("Tin", "Tin", 0.74),
    ("Titanium", "Aluminium", 0.54),
    ("Titanium", "Titanium", 0.55),
    ("Tungsten", "Copper", 0.41),
    ("Tungsten", "Iron", 0.47),
    ("Tungsten", "Tungsten", 0.51),
    ("Zinc", "Copper", 0.56),
    ("Zinc", "Iron", 0.55),
    ("Zinc", "Zinc", 0.75),
]


def _incline_side_columns(materials, suffix):
    """Return the Material_/E_/nu_/rho_ columns (in that order) for one side of the pair.

    Raises KeyError if a material has no entry in _INCLINE_MATERIAL_PROPS.
    """
    props = [_INCLINE_MATERIAL_PROPS[name] for name in materials]
    return {
        f"Material_{suffix}": list(materials),
        f"E_{suffix}": [p[0] for p in props],
        f"nu_{suffix}": [p[1] for p in props],
        f"rho_{suffix}": [p[2] for p in props],
    }


# Same keys, key order and row order as a column-oriented dict for pd.DataFrame.
data2_incline_plane = {
    **_incline_side_columns([row[0] for row in _INCLINE_PLANE_ROWS], "1"),
    **_incline_side_columns([row[1] for row in _INCLINE_PLANE_ROWS], "2"),
    "mu_static": [row[2] for row in _INCLINE_PLANE_ROWS],
}

# Create DataFrames, validating each source first so a ragged column is reported
# by name rather than through pandas' generic "same length" error.
def _frame_from_columns(columns, source_name):
    """Return ``pd.DataFrame(columns)`` after checking all columns have equal length.

    Args:
        columns: dict mapping column name -> list of values.
        source_name: label used in the error message.

    Raises:
        ValueError: listing every column's length when they differ.
    """
    lengths = {col: len(values) for col, values in columns.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"{source_name}: columns have different lengths: {lengths}")
    return pd.DataFrame(columns)


# The two sources spell the same metal differently; map both to one spelling so the
# label encoders assign a single code per material.
_MATERIAL_ALIASES = {'Aluminum': 'Aluminium'}
_MATERIAL_REPLACEMENTS = {'Material_1': _MATERIAL_ALIASES, 'Material_2': _MATERIAL_ALIASES}

df_reference = _frame_from_columns(data, 'data').replace(_MATERIAL_REPLACEMENTS)
df_incline_plane = _frame_from_columns(data2_incline_plane, 'data2_incline_plane').replace(_MATERIAL_REPLACEMENTS)

# Combine both datasets
df_combined = pd.concat([df_reference, df_incline_plane], ignore_index=True)


In [None]:
# Feature engineering: derive pairwise contrast, ratio, interaction, polynomial and
# summary features from both bodies' properties (E, nu, rho). This must produce every
# name listed in `feature_columns`, and must match the features rebuilt for the
# incline-plane predictions further down.

def _add_engineered_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with all numeric engineered feature columns added.

    Expects columns E_1, E_2, nu_1, nu_2, rho_1, rho_2. The ``*_ratio`` columns divide
    by the ``*_2`` column, so a zero there yields inf/NaN (cleaned when X is built).
    """
    out = df.copy()
    for prop in ('E', 'nu', 'rho'):
        first, second = out[f'{prop}_1'], out[f'{prop}_2']
        # Contrast between the two bodies.
        out[f'{prop}_diff'] = first - second
        out[f'{prop}_ratio'] = first / second
        # Polynomial and pairwise product terms.
        out[f'{prop}_1_squared'] = first ** 2
        out[f'{prop}_2_squared'] = second ** 2
        out[f'{prop}_1_{prop}_2'] = first * second
        # Row-wise statistics over the pair (pandas std uses ddof=1).
        pair = out[[f'{prop}_1', f'{prop}_2']]
        out[f'{prop}_mean'] = pair.mean(axis=1)
        out[f'{prop}_std'] = pair.std(axis=1)
    # Within-body interactions between properties.
    for side in ('1', '2'):
        e, nu, rho = out[f'E_{side}'], out[f'nu_{side}'], out[f'rho_{side}']
        out[f'E_nu_interaction_{side}'] = e * nu
        out[f'E_rho_interaction_{side}'] = e * rho
        out[f'nu_rho_interaction_{side}'] = nu * rho
    return out


df_combined = _add_engineered_features(df_combined)

# Encode categorical variables. Label encoding assigns arbitrary integer codes to a
# nominal variable (it implies an ordering that does not exist); one-hot encoding
# would avoid that.
le_material_1 = LabelEncoder()
le_material_2 = LabelEncoder()

df_combined['Material_1_encoded'] = le_material_1.fit_transform(df_combined['Material_1'])
df_combined['Material_2_encoded'] = le_material_2.fit_transform(df_combined['Material_2'])

# Flag self-mated pairs (both bodies are the same material).
df_combined['Same_Material'] = (df_combined['Material_1'] == df_combined['Material_2']).astype(int)


In [None]:
# Model inputs. Every name here must exist on df_combined (built in the
# feature-engineering cell); a missing one makes the selection below raise KeyError.
feature_columns = [
    'E_1', 'E_2', 'nu_1', 'nu_2', 'rho_1', 'rho_2',
    'E_diff', 'nu_diff', 'rho_diff',
    'E_ratio', 'nu_ratio', 'rho_ratio',
    'E_nu_interaction_1', 'E_rho_interaction_1', 'nu_rho_interaction_1',
    'E_nu_interaction_2', 'E_rho_interaction_2', 'nu_rho_interaction_2',
    'E_1_squared', 'E_2_squared', 'nu_1_squared', 'nu_2_squared',
    'rho_1_squared', 'rho_2_squared',
    'E_1_E_2', 'nu_1_nu_2', 'rho_1_rho_2',
    'E_mean', 'E_std', 'nu_mean', 'nu_std',
    'rho_mean', 'rho_std',
    'Material_1_encoded', 'Material_2_encoded',
    'Same_Material'
]

# Select features and target. Ratio features are +/-inf when a *_2 property is 0, so
# map inf -> NaN -> 0. This builds a new frame rather than editing the column
# selection in place, which can trigger pandas' SettingWithCopyWarning.
X = (
    df_combined[feature_columns]
    .replace([np.inf, -np.inf], np.nan)
    .fillna(0)
)
y = df_combined['mu_static']

# Split the data into training and testing sets (fixed seed for reproducibility).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Fit the scaler on the training split only, so test statistics do not leak in.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Convert to float32 tensors; targets become column vectors (n, 1) to match the
# model's single-output head and nn.MSELoss.
X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train.to_numpy(), dtype=torch.float32).view(-1, 1)
X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test.to_numpy(), dtype=torch.float32).view(-1, 1)

In [None]:
class FrictionDataset(Dataset):
    """Map-style dataset over aligned feature/target tensors.

    Item ``idx`` is the pair ``(features[idx], targets[idx])``. The length is the
    size of the first dimension of ``features``.
    """

    def __init__(self, features, targets):
        # References only; the inputs are neither copied nor validated.
        self.X, self.y = features, targets

    def __len__(self):
        n_rows = self.X.shape[0]
        return n_rows

    def __getitem__(self, idx):
        x_row, y_row = self.X[idx], self.y[idx]
        return x_row, y_row

# Wrap the train/test tensors as datasets and batch them. Only the training loader
# shuffles; the test loader keeps a stable order for evaluation.
batch_size = 16
train_dataset, test_dataset = (
    FrictionDataset(X_train_tensor, y_train_tensor),
    FrictionDataset(X_test_tensor, y_test_tensor),
)
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
)

In [None]:
class RegressionNN(nn.Module):
    """Small MLP regressor: input_dim -> 64 -> 32 -> 1 with ReLU and dropout.

    Expects inputs whose last dimension is ``input_dim`` and returns one value per
    row (last dimension 1). The layers live in ``self.network`` at Sequential
    indices 0..6, which fixes the state_dict keys used when saving/loading.
    """

    def __init__(self, input_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 64), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(32, 1),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        prediction = self.network(x)
        return prediction

# Initialize the model.
# Seed torch's global RNG first, so weight initialisation, dropout masks and the
# training DataLoader's shuffle order are reproducible across kernel restarts. The
# shuffle draws from this generator because no `generator=` was passed. CUDA
# kernels may still add some nondeterminism.
SEED = 42
torch.manual_seed(SEED)

input_dim = X_train_tensor.shape[1]
model = RegressionNN(input_dim)

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

In [None]:
# Initialize Weights & Biases (W&B). Hyperparameters are read back from
# wandb.config below, so the logged config and the values actually used agree.
wandb.init(
    project="FrictionAI_PyTorch",
    config=dict(
        model="RegressionNN",
        input_dim=input_dim,
        hidden_layers=[64, 32],
        dropout=[0.2, 0.1],
        learning_rate=0.001,
        epochs=500,
        batch_size=batch_size,
    ),
)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
epochs = wandb.config.epochs


def _run_epoch(loader, training):
    """Make one pass over ``loader`` and return the per-sample mean MSE.

    With ``training=True`` the model is in train mode and updated batch by batch
    (zero_grad -> forward -> backward -> step). Otherwise it runs in eval mode
    under ``torch.no_grad()`` with no updates.
    """
    model.train(training)
    total_loss = 0.0
    grad_mode = torch.enable_grad() if training else torch.no_grad()
    with grad_mode:
        for batch_X, batch_y in loader:
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)
            if training:
                optimizer.zero_grad()
            batch_loss = criterion(model(batch_X), batch_y)
            if training:
                batch_loss.backward()
                optimizer.step()
            # Weight by batch size so a short final batch is not over-counted.
            total_loss += batch_loss.item() * batch_X.size(0)
    return total_loss / len(loader.dataset)


train_losses, test_losses = [], []
for epoch_num in range(1, epochs + 1):
    epoch_loss = _run_epoch(train_loader, training=True)
    test_epoch_loss = _run_epoch(test_loader, training=False)
    train_losses.append(epoch_loss)
    test_losses.append(test_epoch_loss)

    wandb.log({"Epoch": epoch_num, "Train Loss": epoch_loss, "Test Loss": test_epoch_loss})

    if epoch_num % 50 == 0:
        print(f"Epoch [{epoch_num}/{epochs}], Train Loss: {epoch_loss:.4f}, Test Loss: {test_epoch_loss:.4f}")

print("Training Completed.")


In [None]:
# Final evaluation on the held-out test split. Gather every prediction/target pair
# in loader order, then compute regression metrics on the flattened arrays.
model.eval()
pred_batches, target_batches = [], []
with torch.no_grad():
    for batch_X, batch_y in test_loader:
        pred_batches.append(model(batch_X.to(device)).cpu())
        target_batches.append(batch_y)

predictions = torch.cat(pred_batches).numpy().ravel()
actuals = torch.cat(target_batches).numpy().ravel()

rmse = np.sqrt(mean_squared_error(actuals, predictions))
mae = mean_absolute_error(actuals, predictions)
r2 = r2_score(actuals, predictions)

wandb.log({"Final Evaluation": {"RMSE": rmse, "MAE": mae, "R2_Score": r2}})

print("Model Performance on Test Set:")
for metric_name, metric_value in (("RMSE", rmse), ("MAE", mae), ("R2 Score", r2)):
    print(f"{metric_name}: {metric_value:.4f}")


In [None]:
# Plot training and test loss per epoch and log the same figure to W&B.
fig, ax = plt.subplots(figsize=(10, 6))
epoch_axis = range(1, len(train_losses) + 1)
ax.plot(epoch_axis, train_losses, label='Train Loss')
ax.plot(epoch_axis, test_losses, label='Test Loss')
ax.set(xlabel='Epoch', ylabel='MSE Loss', title='Training and Testing Loss Over Epochs')
ax.legend()
ax.grid(True)

# Log BEFORE plt.show(). With default inline-backend settings, show() closes the
# figure, so logging `plt` afterwards captures an empty figure.
wandb.log({"Loss Curve": wandb.Image(fig)})
plt.show()


In [None]:
# Scatter plot of actual vs predicted mu_static; the red dashed y = x line marks
# perfect predictions.
fig, ax = plt.subplots(figsize=(8, 8))
sns.scatterplot(x=actuals, y=predictions, ax=ax)
lo, hi = actuals.min(), actuals.max()
ax.plot([lo, hi], [lo, hi], 'r--')
ax.set(xlabel='Actual mu_static', ylabel='Predicted mu_static', title='Actual vs. Predicted mu_static')
ax.grid(True)

# Log BEFORE plt.show(). With default inline-backend settings, show() closes the
# figure, so logging `plt` afterwards captures an empty figure.
wandb.log({"Actual vs Predicted": wandb.Image(fig)})
plt.show()


In [None]:
# Residual distribution (predicted - actual) on the test set, with a KDE overlay.
residuals = predictions - actuals
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(residuals, kde=True, bins=20, ax=ax)
ax.set(xlabel='Residual (predicted - actual)', title='Distribution of Residuals')

# Log BEFORE plt.show(). With default inline-backend settings, show() closes the
# figure, so logging `plt` afterwards captures an empty figure.
wandb.log({"Residuals Distribution": wandb.Image(fig)})
plt.show()


In [None]:
# Predict mu_static for the incline-plane measurements with the trained model.
# NOTE(review): df_incline_plane is also concatenated into df_combined, and
# train_test_split holds out only 20% of df_combined. Most of these rows were
# therefore in the training split, so the errors below are largely in-sample,
# not an out-of-sample estimate.
df_new = df_incline_plane.copy()

# Rebuild exactly the engineered features used for training. These must stay in
# sync with the feature-engineering cell and `feature_columns`.
for prop in ('E', 'nu', 'rho'):
    first, second = df_new[f'{prop}_1'], df_new[f'{prop}_2']
    df_new[f'{prop}_diff'] = first - second
    df_new[f'{prop}_ratio'] = first / second
    df_new[f'{prop}_1_squared'] = first ** 2
    df_new[f'{prop}_2_squared'] = second ** 2
    df_new[f'{prop}_1_{prop}_2'] = first * second
    pair = df_new[[f'{prop}_1', f'{prop}_2']]
    df_new[f'{prop}_mean'] = pair.mean(axis=1)
    df_new[f'{prop}_std'] = pair.std(axis=1)

for side in ('1', '2'):
    e, nu, rho = df_new[f'E_{side}'], df_new[f'nu_{side}'], df_new[f'rho_{side}']
    df_new[f'E_nu_interaction_{side}'] = e * nu
    df_new[f'E_rho_interaction_{side}'] = e * rho
    df_new[f'nu_rho_interaction_{side}'] = nu * rho

# Reuse the encoders fitted on df_combined; transform() raises ValueError on a
# material name it did not see during fitting.
df_new['Material_1_encoded'] = le_material_1.transform(df_new['Material_1'])
df_new['Material_2_encoded'] = le_material_2.transform(df_new['Material_2'])
df_new['Same_Material'] = (df_new['Material_1'] == df_new['Material_2']).astype(int)

# Same inf/NaN handling as the training features, built as a new frame rather
# than edited in place.
X_new = df_new[feature_columns].replace([np.inf, -np.inf], np.nan).fillna(0)

# Scale with the scaler fitted on the training split, then predict.
X_new_scaled = scaler.transform(X_new)
X_new_tensor = torch.tensor(X_new_scaled, dtype=torch.float32).to(device)

model.eval()
with torch.no_grad():
    mu_pred_new = model(X_new_tensor).cpu().numpy().flatten()

df_new['mu_pred'] = mu_pred_new
df_new['error'] = (df_new['mu_pred'] - df_new['mu_static']).abs()

# Log all predictions as one W&B table instead of one history step per row.
wandb.log({
    "Incline Plane Predictions": wandb.Table(
        dataframe=df_new[['Material_1', 'Material_2', 'mu_static', 'mu_pred', 'error']]
    )
})

# Display a few predictions
display(df_new[['Material_1', 'mu_static', 'mu_pred', 'error']].head())

In [None]:
# Persist only the learned parameters (state_dict), not the whole module object.
MODEL_PATH = 'friction_model.pth'
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved to {MODEL_PATH}")

# To reload later with the same architecture:
# model = RegressionNN(input_dim)
# model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
# model.to(device)
# model.eval()

In [None]:
# Mark the W&B run opened by wandb.init(...) (in the training cell) as finished.
wandb.finish()