# Fabric GSM Prediction — Feature-Based Regression (CatBoost)

This notebook trains a regression model to predict fabric GSM (grams per square metre) from engineered features extracted from images. Images are used only for visualization and sanity checks; the model input is purely tabular features. The design targets an MAE within ±5–10 GSM, provided the features are sufficiently informative.

## 1️⃣ Imports & Configuration
- Structured, publication-ready code with seeded randomness
- Uses CatBoost with optional GPU (devices='0'), MAE loss
- Paths configurable for local or Google Drive (Colab)

In [None]:
# Environment & Base Path: local vs Colab (Google Drive)
import os
from pathlib import Path

# google.colab is only importable inside Colab; anywhere else this raises ImportError.
try:
    from google.colab import drive  # type: ignore
except ImportError:
    drive = None

if drive is not None:
    # Inside Colab a failed mount is a real error: let it surface instead of
    # silently falling back to a local path that will not contain the dataset.
    drive.mount('/content/drive')
    IN_COLAB = True
    BASE_PATH = Path('/content/drive/MyDrive/fabric_gsm_pipeline')
    print('Running in Colab; mounted Google Drive.')
else:
    IN_COLAB = False
    # Use the current working directory as the workspace root.
    BASE_PATH = Path.cwd()  # Adjust if you want a specific folder
    print('Running locally; using workspace path:', BASE_PATH)

# Optional: override via env var GSM_BASE_PATH
env_base = os.environ.get('GSM_BASE_PATH')
if env_base:
    BASE_PATH = Path(env_base)
    print('BASE_PATH overridden by GSM_BASE_PATH:', BASE_PATH)

In [None]:
# Imports & Core Configuration
import json
import math
import time
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.feature_selection import VarianceThreshold
from sklearn.metrics import mean_absolute_error, mean_squared_error

try:
    from catboost import CatBoostRegressor, Pool
    CATBOOST_AVAILABLE = True
except Exception:
    CATBOOST_AVAILABLE = False
    print('CatBoost import failed; please install catboost.')

# Reproducibility: seed both RNGs this notebook draws from.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Global seaborn/matplotlib look applied to every figure below.
sns.set(context='notebook', style='whitegrid')

# Prefer datasets that include a 'split' column (train/val/test).
def bp(*parts):
    return BASE_PATH.joinpath(*parts)

DATASET_CANDIDATES = [
    bp('split_feature_dataset', 'dataset_all.csv'),
    bp('feature_extracted_dataset', 'dataset_with_features.csv'),
    bp('data', 'augmented_features_dataset', 'dataset_with_features.csv'),
    bp('Dataset', 'dataset.csv'),
    bp('data', 'combined_dataset', 'dataset.csv')
]

def pick_dataset_path(candidates):
    """Return the first candidate that exists on disk, as a Path.

    If no candidate exists, fall back to the first one; the caller is expected
    to check existence before reading. An empty list raises IndexError.
    """
    found = next((Path(c) for c in candidates if Path(c).exists()), None)
    if found is None:
        found = Path(candidates[0])  # default fallback
    return found

DATASET_PATH = pick_dataset_path(DATASET_CANDIDATES)
print(f'Using dataset: {DATASET_PATH}')

# Image search roots for sanity checks
IMAGE_DIR_CANDIDATES = [
    bp('split_feature_dataset', 'train', 'images'),
    bp('split_feature_dataset', 'val', 'images'),
    bp('split_feature_dataset', 'test', 'images'),
    bp('feature_extracted_dataset', 'images'),
    bp('preprocessed_dataset', 'images'),
    bp('Dataset', 'images'),
    bp('data', 'augmented_dataset', 'images'),
    bp('data', 'augmented_features_dataset', 'images'),
    bp('data', 'feature_extracted_dataset', 'images'),
    bp('data', 'preprocessed_dataset', 'images'),
]

# Columns to explicitly exclude from features (metadata plus the target).
EXCLUDE_COLS = {
    'image_name', 'source', 'augmentation', 'original_image', 'split',
    'gsm',  # target kept separately
}

# Output directories
OUTPUT_DIR = bp('train')
VIS_DIR = bp('feature_extracted_dataset', 'visualizations') if bp('feature_extracted_dataset').exists() else bp('train', 'visualizations')
MODEL_DIR = bp('Model')
PREDICTIONS_PATH = bp('train', 'predictions_gsm_feature_catboost.csv')
MODEL_PATH = MODEL_DIR / 'gsm_feature_catboost.cbm'
# Create every output folder up front. OUTPUT_DIR must exist even when VIS_DIR
# lives under feature_extracted_dataset/, because the predictions CSV and the
# training-metadata JSON are written into train/.
for out_dir in (OUTPUT_DIR, VIS_DIR, MODEL_DIR):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

def has_gpu_for_catboost() -> bool:
    """Best-effort check of whether CatBoost can train on a GPU here.

    Returns False if CatBoost failed to import, or if CUDA_VISIBLE_DEVICES
    explicitly hides all devices ('' or '-1'). Otherwise it asks CatBoost for
    its GPU device count and falls back to querying ``nvidia-smi`` if that
    probe is unavailable. Probe failures yield False rather than an exception.
    """
    if not CATBOOST_AVAILABLE:
        return False

    # An explicitly empty or '-1' CUDA_VISIBLE_DEVICES hides every GPU from CUDA,
    # even though nvidia-smi would still list the hardware.
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is not None and visible.strip() in ('', '-1'):
        return False

    # Preferred probe: CatBoost's own view of usable GPU devices.
    try:
        from catboost.utils import get_gpu_device_count
    except ImportError:
        get_gpu_device_count = None
    if get_gpu_device_count is not None:
        try:
            return get_gpu_device_count() > 0
        except Exception:  # best-effort probe; fall back to nvidia-smi below
            pass

    import subprocess
    try:
        res = subprocess.run(
            ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=3, shell=False,
        )
    except (OSError, subprocess.SubprocessError):
        # nvidia-smi missing / not executable, or the query timed out.
        return False
    return res.returncode == 0 and bool(res.stdout.decode(errors='replace').strip())

GPU_AVAILABLE = has_gpu_for_catboost()
print(f'GPU available for CatBoost: {GPU_AVAILABLE}')

## 2️⃣ Data Loading & Inspection
- Load CSV and inspect shape/dtypes
- Identify numeric vs non-numeric columns
- Explicitly exclude meta columns from features

In [None]:
# Load dataset
assert Path(DATASET_PATH).exists(), f'Dataset not found: {DATASET_PATH}'
df = pd.read_csv(DATASET_PATH)
print('Shape:', df.shape)
print('Columns:', list(df.columns))

# Basic checks
assert 'gsm' in df.columns, 'Target column gsm is missing.'

# Dtypes summary
print('\nDtypes:')
print(df.dtypes)

# Identify numeric vs non-numeric columns (features are later drawn from numeric_cols).
numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
non_numeric_cols = [c for c in df.columns if c not in numeric_cols]
print('\nNumeric columns (count={}):'.format(len(numeric_cols)))
print(numeric_cols[:20], '...')
print('\nNon-numeric columns (count={}):'.format(len(non_numeric_cols)))
print(non_numeric_cols)

# Exclusions
print('\nExplicitly excluding columns from features:', sorted(EXCLUDE_COLS))

# Preview (display() is not imported here; it relies on the IPython/Jupyter kernel)
display(df.head(3))

## 3️⃣ Train / Validation Split (CRITICAL)
- Use the existing `split` column
- Ensure augmented samples remain with their original image's split
- Prevent leakage between train and val

In [None]:
# Ensure 'split' column exists
assert 'split' in df.columns, 'The dataset must contain a split column.'

# Group by original identifier to enforce consistent split. Prefer 'original_image' then fallback to 'image_name'.
# NOTE: with the 'image_name' fallback every row is presumably its own group, so the
# consistency/leakage checks below cannot catch augmentation leakage in that case.
group_key = 'original_image' if 'original_image' in df.columns else ('image_name' if 'image_name' in df.columns else None)
assert group_key is not None, 'Dataset should have either original_image or image_name for grouping.'


def majority_split(s):
    """Return the most frequent label in the Series ``s``."""
    return s.value_counts().idxmax()


# Normalise split labels (trim + lowercase) before comparing them. Otherwise
# 'Train' vs 'train' counts as a conflict, and labels such as 'val ' never match
# the 'val' filter below, so those rows are silently dropped.
df['split_fixed'] = df['split'].str.strip().str.lower()

# Detect groups whose rows disagree on the split and fix them by majority vote.
inconsistent_groups = []
for gid, labels in df.groupby(group_key)['split_fixed']:
    if labels.nunique() > 1:
        inconsistent_groups.append(gid)
        df.loc[labels.index, 'split_fixed'] = majority_split(labels)

if inconsistent_groups:
    print(f'Found {len(inconsistent_groups)} groups with inconsistent splits; fixed via majority assignment.')
else:
    print('No inconsistent group splits detected.')

# Define train/val from the normalised, fixed splits; test is unused for training
train_df = df[df['split_fixed'] == 'train'].copy()
val_df = df[df['split_fixed'] == 'val'].copy()
print(f'Train: {train_df.shape}, Val: {val_df.shape}')
assert len(train_df) > 0 and len(val_df) > 0, "Empty train or val split; expected 'train'/'val' labels in the split column."

# Sanity: ensure groups do not cross splits (missing group ids are not real groups).
train_groups = set(train_df[group_key].dropna().unique())
val_groups = set(val_df[group_key].dropna().unique())
intersection = train_groups & val_groups
assert len(intersection) == 0, f'Leakage detected! {len(intersection)} groups appear in both train and val.'
print('No group leakage between train and val.')

## 4️⃣ Feature Cleaning
- Keep only numeric features
- Drop near-constant features
- Handle missing values safely
- Optional: prune highly correlated features

In [None]:
# Candidate features: every numeric column that is neither metadata nor the target.
feature_cols = [c for c in numeric_cols if c != 'gsm' and c not in EXCLUDE_COLS]
print(f'Initial numeric feature count: {len(feature_cols)}')

# Split features / target for each partition.
X_train, X_val = (frame[feature_cols].copy() for frame in (train_df, val_df))
y_train_raw, y_val_raw = (frame['gsm'].astype(float).copy() for frame in (train_df, val_df))

# Impute NaNs with TRAIN medians only, so validation statistics never leak in.
# CatBoost tolerates NaNs, but imputing keeps downstream plots/stats predictable.
train_medians = X_train.median()
X_train, X_val = X_train.fillna(train_medians), X_val.fillna(train_medians)

# Remove (near-)constant columns; the filter is fitted on the training split only.
vt = VarianceThreshold(threshold=1e-8)
vt.fit(X_train)
keep_mask = vt.get_support()
feature_cols_vt = list(X_train.columns[keep_mask])
n_dropped_vt = len(feature_cols) - len(feature_cols_vt)
print(f'After variance filter: {len(feature_cols_vt)} features (dropped {n_dropped_vt})')
X_train, X_val = X_train[feature_cols_vt], X_val[feature_cols_vt]

# Optional: drop every column whose |correlation| with some earlier column
# exceeds the threshold, to reduce multicollinearity.
corr_threshold = 0.98
abs_corr = X_train.corr(numeric_only=True).abs()
# Keep only the strict upper triangle so each column pair is inspected once.
upper_tri = abs_corr.where(np.triu(np.ones(abs_corr.shape), k=1).astype(bool))
to_drop = [col for col in upper_tri.columns if (upper_tri[col] > corr_threshold).any()]
dropped = set(to_drop)
feature_cols_pruned = [c for c in feature_cols_vt if c not in dropped]
print(f'Correlation pruning: dropped {len(to_drop)} features above {corr_threshold}')

X_train, X_val = X_train[feature_cols_pruned], X_val[feature_cols_pruned]
print(f'Final feature count: {len(feature_cols_pruned)}')

## 5️⃣ Target Engineering (IMPORTANT)
We transform the target using $y=\log(1+\text{GSM})$ to reduce skew and stabilize the training dynamics. Predictions are inverse-transformed via $\exp(y)-1$.

In [None]:
# Apply log1p transformation to target; predictions are mapped back with np.expm1.
y_train = np.log1p(y_train_raw.values)
y_val = np.log1p(y_val_raw.values)

# log1p is only finite for GSM > -1. A missing or invalid label would otherwise
# silently become NaN/-inf and poison both training and the metrics.
assert np.isfinite(y_train).all() and np.isfinite(y_val).all(), 'Non-finite log1p(GSM) target; check for missing or invalid GSM labels.'

# Summary statistics below are for the validation split.
print('Target stats (raw GSM):')
print(pd.Series(y_val_raw).describe())
print('\nTarget stats (log1p GSM):')
print(pd.Series(y_val).describe())

## 6️⃣ Model Selection
We use `CatBoostRegressor` because it performs strongly on structured/tabular data, models feature interactions well, and is stable on smaller datasets.
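- Loss: MAE, optimized on the log1p-transformed target (so it roughly tracks relative error); the reported validation MAE is computed in GSM units after inverting the transform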
- GPU support: `task_type='GPU'`, `devices='0'` when available
- CPU fallback when GPU is unavailable

In [None]:
# Fail fast with a clear message: Pool/CatBoostRegressor are undefined when the
# catboost import failed, so this check must come before they are used.
assert CATBOOST_AVAILABLE, 'CatBoost is not installed; pip install catboost'

# Prepare CatBoost Pools
feature_names = list(X_train.columns)
train_pool = Pool(X_train, label=y_train, feature_names=feature_names)
val_pool = Pool(X_val, label=y_val, feature_names=feature_names)

# Model parameters. NOTE: labels are log1p(GSM), so the 'MAE' loss/metric here
# is measured on the log scale, not in GSM units.
cat_params = {
    'loss_function': 'MAE',
    'eval_metric': 'MAE',
    'random_seed': SEED,
    'depth': 6,
    'learning_rate': 0.05,
    'iterations': 2000,
    'early_stopping_rounds': 100,
    'use_best_model': True,
    'verbose': 100
}

if GPU_AVAILABLE:
    cat_params.update({'task_type': 'GPU', 'devices': '0'})
else:
    cat_params.update({'task_type': 'CPU'})

print('CatBoost params:', cat_params)

# Initialize model
model = CatBoostRegressor(**cat_params)

## 7️⃣ Training Strategy
- Train on train split and validate on val split
- Use early stopping and best-iteration selection
- No cross-validation: naive folds could place augmented copies of the same image in both training and validation, leaking information

In [None]:
# Train with early stopping against the validation pool (cat_params sets
# early_stopping_rounds and use_best_model).
# perf_counter is monotonic, so it is the right clock for measuring durations.
start_time = time.perf_counter()
model.fit(train_pool, eval_set=val_pool)
train_time = time.perf_counter() - start_time
print(f'Training completed in {train_time:.2f} seconds')
print('Best iteration:', model.get_best_iteration())
# This eval metric is computed on the log1p target, not in GSM units.
print('Best score (val MAE, log1p scale):', model.get_best_score().get('validation', {}).get('MAE'))

## 8️⃣ Evaluation
- Inverse-transform predictions back to GSM
- Report MAE, RMSE, and mean GSM
- Plot Actual vs Predicted and error distribution

In [None]:
def _save_and_show(filename):
    """Tight-layout the current figure, save it under VIS_DIR, then show it."""
    plt.tight_layout()
    plt.savefig(VIS_DIR / filename, dpi=120)
    plt.show()


# Model outputs are on the log1p scale; expm1 maps them back to GSM.
val_pred_log = model.predict(val_pool)
val_pred = np.expm1(val_pred_log)

# Metrics, all in GSM units
mae = mean_absolute_error(y_val_raw, val_pred)
mse = mean_squared_error(y_val_raw, val_pred)
rmse = math.sqrt(mse)
mean_gsm = float(np.mean(y_val_raw))
print(f'Validation MAE: {mae:.3f} GSM')
print(f'Validation RMSE: {rmse:.3f} GSM')
print(f'Mean GSM (val): {mean_gsm:.3f}')

# Scatter of actual vs predicted, with a dashed y = x reference line.
lo = float(min(y_val_raw.min(), val_pred.min()))
hi = float(max(y_val_raw.max(), val_pred.max()))
plt.figure(figsize=(7, 6))
plt.scatter(y_val_raw, val_pred, s=18, alpha=0.7, edgecolor='none')
plt.plot([lo, hi], [lo, hi], 'r--', lw=1)
plt.title('Actual vs Predicted GSM (Validation)')
plt.xlabel('Actual GSM')
plt.ylabel('Predicted GSM')
_save_and_show('actual_vs_pred_val.png')

# Histogram of absolute errors in GSM units.
errors = np.abs(y_val_raw.values - val_pred)
plt.figure(figsize=(7, 5))
sns.histplot(errors, bins=30, kde=True, color='steelblue')
plt.title('Absolute Error Distribution (Validation)')
plt.xlabel('Absolute Error (GSM)')
_save_and_show('error_distribution_val.png')

## 9️⃣ Feature Importance
We inspect CatBoost's feature importances to understand which engineered features drive GSM predictions.

In [None]:
# Feature importance as reported by CatBoost for the trained model.
topn = 20  # how many top features to display, plot and name the figure after
importances = model.get_feature_importance(train_pool)
imp_df = (
    pd.DataFrame({'feature': feature_names, 'importance': importances})
    .sort_values('importance', ascending=False)
)
top_imp = imp_df.head(topn)
display(top_imp)

plt.figure(figsize=(8, 8))
sns.barplot(y=top_imp['feature'], x=top_imp['importance'], color='teal')
plt.title(f'Top {topn} Feature Importances (CatBoost)')
plt.xlabel('Importance')
plt.ylabel('Feature')
plt.tight_layout()
plt.savefig(VIS_DIR / f'feature_importance_top{topn}.png', dpi=120)
plt.show()

Interpretation: Higher importances indicate stronger relationships with GSM. Expect structural features (e.g., weft/warp counts, spacing stats), yarn uniformity, and frequency/texture descriptors to rank highly. If color features dominate, revisit lighting normalization or feature extraction.

## 🔟 Visual Sanity Check
Randomly sample 10 validation images and display each one with its actual GSM, predicted GSM, and absolute error.

In [None]:
# Helper to find an image by name across known directories
def find_image_path(image_name: str, split_hint: str = None, roots=None):
    """Locate an image file by name across the known image directories.

    Args:
        image_name: File name to look for (e.g. a value of the ``image_name`` column).
        split_hint: Optional split name; if given, its
            ``split_feature_dataset/<split_hint>/images`` folder is searched first.
        roots: Optional iterable of directories to search instead of
            ``IMAGE_DIR_CANDIDATES``.

    Returns:
        The first candidate Path that is an existing regular file, else None.
    """
    if not image_name:
        # root / '' is the directory itself, never an image file.
        return None
    candidates = []
    if split_hint:
        candidates.append(bp('split_feature_dataset', split_hint, 'images') / image_name)
    search_roots = IMAGE_DIR_CANDIDATES if roots is None else roots
    candidates.extend(Path(root) / image_name for root in search_roots)
    for p in candidates:
        # exists() would also accept a directory with the same name.
        if p.is_file():
            return p
    return None

# Attach predictions and absolute errors to a copy of the validation rows.
val_display = val_df.copy()
val_display['pred_gsm'] = val_pred
val_display['abs_error'] = (val_display['gsm'] - val_display['pred_gsm']).abs()
sample_n = min(10, len(val_display))
sample = val_display.sample(sample_n, random_state=SEED)

# Five-column grid; trailing cells without a sample are hidden.
ncols = 5
nrows = math.ceil(sample_n / ncols)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3.5, nrows*3.5))
axes = np.array(axes).reshape(-1)
for spare_ax in axes[sample_n:]:
    spare_ax.axis('off')


def _show_placeholder(target_ax, message):
    """Draw a centred text message on a light-grey axis instead of an image."""
    target_ax.text(0.5, 0.5, message, ha='center', va='center')
    target_ax.set_facecolor('#f7f7f7')


for ax, (_, row) in zip(axes, sample.iterrows()):
    img_name = row['image_name'] if 'image_name' in row else None
    img_path = None
    if img_name:
        img_path = find_image_path(str(img_name), split_hint='val')

    if not (img_path and Path(img_path).exists()):
        _show_placeholder(ax, 'Image not found')
    else:
        try:
            ax.imshow(plt.imread(str(img_path)))
        except Exception:
            _show_placeholder(ax, 'Image load failed')

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Actual: {row['gsm']:.1f} | Pred: {row['pred_gsm']:.1f} | |Err|: {row['abs_error']:.1f}", fontsize=9)

plt.suptitle('Validation Samples: Visual Sanity Check', y=0.98)
plt.tight_layout()
plt.savefig(VIS_DIR / 'val_visual_sanity_check.png', dpi=120)
plt.show()

## 1️⃣1️⃣ Error Analysis
Identify the worst 5 predictions and discuss likely causes: structural ambiguity (mixed weave patterns), yarn overlaps, motion blur, illumination variances, or features missing key discriminants.

In [None]:
# Worst 5 predictions. 'image_name' is optional in the dataset schema, so fall
# back to the grouping key (always present) as the row identifier.
id_col = 'image_name' if 'image_name' in val_display.columns else group_key
worst = val_display.sort_values('abs_error', ascending=False).head(5)[[id_col, 'gsm', 'pred_gsm', 'abs_error']]
print('Worst 5 (by absolute error):')
display(worst)

## 1️⃣2️⃣ Model Saving
Save the trained CatBoost model (.cbm) and export validation predictions to CSV for further analysis/plots.

In [None]:
# Save model and predictions
Path(MODEL_DIR).mkdir(parents=True, exist_ok=True)
model.save_model(str(MODEL_PATH))
print(f'Saved CatBoost model to: {MODEL_PATH}')

# Identifier columns first. dict.fromkeys de-duplicates while preserving order,
# so image_name is not emitted twice when it is also the group key.
id_cols = [group_key] + (['image_name'] if 'image_name' in val_df.columns else [])
pred_out_cols = list(dict.fromkeys(id_cols)) + ['gsm', 'pred_gsm', 'abs_error']
pred_out = val_display[pred_out_cols].copy()
Path(PREDICTIONS_PATH).parent.mkdir(parents=True, exist_ok=True)
pred_out.to_csv(PREDICTIONS_PATH, index=False)
print(f'Saved validation predictions to: {PREDICTIONS_PATH}')

# Save feature list and config for reproducibility.
# NOTE(review): 'features_used' and 'final_features' currently hold the same list.
meta = {
    'dataset_path': str(DATASET_PATH),
    'features_used': feature_names,
    'final_features': list(X_train.columns),
    'seed': SEED,
    'gpu_used': GPU_AVAILABLE,
    'catboost_params': cat_params,
    'val_metrics': {'MAE': float(mae), 'RMSE': float(rmse), 'Mean_GSM_val': float(mean_gsm)}
}
meta_path = bp('train', 'training_metadata.json')
meta_path.parent.mkdir(parents=True, exist_ok=True)
with open(meta_path, 'w', encoding='utf-8') as f:
    json.dump(meta, f, indent=2)
print('Saved training metadata to train/training_metadata.json')

### Notes
- This pipeline respects pre-defined splits to avoid leakage from augmentations.
- Target log1p helps stabilize training; always invert before reporting.
- If MAE is above 10 GSM, investigate feature extraction quality, illumination normalization, and split integrity.