# 🧵 GSM Prediction from Fabric Microscopy Images

## Hybrid Deep Learning Approach (No Augmentation)

**Research Objective:** Develop an accurate GSM prediction model using:
- Pre-trained CNN features (EfficientNet-B3)
- 64 engineered fabric-specific features

**Target Accuracy:** ±5 GSM prediction error

**Dataset:** 177 original microscopy images with extracted features (70% train, 15% val, 15% test)

---

### Quick Start (Google Colab)
1. Upload `split_feature_dataset` folder to Google Drive
2. Mount Drive and set `DATASET_PATH` below
3. Run all cells sequentially
4. Model will be saved to Drive after training

## 1. Environment Setup & GPU Configuration

In [None]:
# Check GPU availability
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms

# Choose the compute device: CUDA when a GPU is visible, otherwise the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f"Using device: {device}")

if _cuda_ok:
    # Report the first GPU's name and its total memory in gigabytes.
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {_gpu_props.total_memory / 1e9:.2f} GB")

# Seed every random number generator used by the notebook.
import random
import numpy as np

SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Ask cuDNN for deterministic kernels and turn off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print("\n‚úÖ Environment configured with seed:", SEED)

## 2. Mount Google Drive & Load Dataset

In [None]:
# Mount Google Drive when running inside Colab; otherwise read from a local
# "data" folder.
# Only ImportError (google.colab not installed) selects the local fallback.
# A bare `except:` would also swallow KeyboardInterrupt and genuine mount
# failures, silently pointing the rest of the notebook at the wrong directory.
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
    BASE_PATH = 'data'
    print("Running locally")
else:
    drive.mount('/content/drive')
    IN_COLAB = True
    BASE_PATH = '/content/drive/MyDrive/fabric_gsm_pipeline'

# Dataset layout: <DATASET_PATH>/<split>/images holds the pictures and
# <DATASET_PATH>/<split>/dataset_<split>.csv holds the per-image table.
DATASET_PATH = f"{BASE_PATH}/split_feature_dataset"
TRAIN_IMAGES, VAL_IMAGES, TEST_IMAGES = (
    f"{DATASET_PATH}/{split}/images" for split in ('train', 'val', 'test')
)
TRAIN_CSV, VAL_CSV, TEST_CSV = (
    f"{DATASET_PATH}/{split}/dataset_{split}.csv" for split in ('train', 'val', 'test')
)

print(f"Dataset path: {DATASET_PATH}")

## 3. Import Libraries & Visualization Setup

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm
import warnings
# Silence every warning category for the rest of the session. This also hides
# deprecation notices from the libraries used below, so remove it when debugging.
warnings.filterwarnings('ignore')

# Sklearn imports
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Plotting configuration
# Apply a matplotlib style sheet and the seaborn "husl" colour palette to all
# figures created later in the notebook.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
# IPython magic (not plain Python): render figures inline in the notebook.
%matplotlib inline

print("‚úÖ All libraries imported successfully")

## 4. Load and Explore Dataset

In [None]:
# Read the three pre-split tables (train / val / test).
df_train, df_val, df_test = (
    pd.read_csv(csv_path) for csv_path in (TRAIN_CSV, VAL_CSV, TEST_CSV)
)

n_train, n_val, n_test = len(df_train), len(df_val), len(df_test)

print("=" * 80)
print("üìä DATASET STATISTICS")
print("=" * 80)
print(f"Train samples: {n_train}")
print(f"Val samples:   {n_val}")
print(f"Test samples:  {n_test}")
print(f"Total:         {n_train + n_val + n_test}")

# Every column that is not metadata is treated as an engineered feature.
meta_cols = ['image_name', 'gsm', 'source']
feature_cols = [name for name in df_train.columns if name not in meta_cols]

print(f"\nüî¨ Extracted features: {len(feature_cols)}")
print(f"Feature names: {feature_cols[:5]}... (showing first 5)")

# Summary statistics of the target for each split; labels are pre-padded so
# the printed columns line up.
print("\nüìä GSM Distribution:")
for split_label, split_frame in (('Train', df_train), ('Val  ', df_val), ('Test ', df_test)):
    gsm_col = split_frame['gsm']
    print(f"{split_label} - Mean: {gsm_col.mean():.2f}, Std: {gsm_col.std():.2f}, Range: [{gsm_col.min():.0f}, {gsm_col.max():.0f}]")

# One histogram of the GSM target per split, with the split mean marked.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (name, df) in zip(axes, (('Train', df_train), ('Val', df_val), ('Test', df_test))):
    gsm_values = df['gsm']
    gsm_mean = gsm_values.mean()
    ax.hist(gsm_values, bins=20, alpha=0.7, edgecolor='black')
    ax.set_title(f'{name} GSM Distribution')
    ax.set_xlabel('GSM (g/m¬≤)')
    ax.set_ylabel('Frequency')
    ax.axvline(gsm_mean, color='red', linestyle='--', label=f"Mean: {gsm_mean:.1f}")
    ax.legend()
plt.tight_layout()
plt.show()

print("\n‚úÖ Dataset loaded and explored")

## 5. Feature Engineering & Preprocessing

In [None]:
# Impute, prune and scale the engineered feature columns.
print("üîß Preprocessing extracted features...")

# Impute every split with medians computed on the TRAIN split only, so no
# statistics leak from val/test. Assigning the filled frame back replaces the
# chained `df[col].fillna(..., inplace=True)` form, which pandas deprecates
# and which does nothing under Copy-on-Write. Every feature column is filled,
# so a NaN that appears only in val/test is handled as well.
train_medians = df_train[feature_cols].median()
for split_df in (df_train, df_val, df_test):
    split_df[feature_cols] = split_df[feature_cols].fillna(train_medians)

# Remove features with zero variance on the train split (they carry no signal).
zero_var_cols = [col for col in feature_cols if df_train[col].std() == 0]

if zero_var_cols:
    print(f"Removing {len(zero_var_cols)} zero-variance features: {zero_var_cols}")
    feature_cols = [col for col in feature_cols if col not in zero_var_cols]

# Scale with RobustScaler: fit on train only, then apply the same transform
# to val and test.
scaler = RobustScaler()
X_train_scaled = scaler.fit_transform(df_train[feature_cols])
X_val_scaled = scaler.transform(df_val[feature_cols])
X_test_scaled = scaler.transform(df_test[feature_cols])

print(f"\n‚úÖ Features preprocessed: {len(feature_cols)} features")
print(f"Scaled shapes - Train: {X_train_scaled.shape}, Val: {X_val_scaled.shape}, Test: {X_test_scaled.shape}")

## 6. Custom Dataset Class (Hybrid: Images + Features)

In [None]:
class FabricGSMDataset(Dataset):
    """Pairs each fabric image with its engineered feature row and GSM target.

    Args:
        dataframe: Table with at least ``image_name`` and ``gsm`` columns; its
            index is reset so rows are addressed by position.
        features_array: Indexable container whose ``idx``-th entry holds the
            feature values for the ``idx``-th row of ``dataframe``.
        images_dir: Directory containing the files named in ``image_name``.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, dataframe, features_array, images_dir, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.features = features_array
        self.images_dir = images_dir
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, features, gsm)`` for the row at position ``idx``."""
        row = self.df.iloc[idx]

        # Load the picture as RGB and run the optional transform.
        image = Image.open(os.path.join(self.images_dir, row['image_name'])).convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Engineered features and regression target as float32 tensors.
        features = torch.tensor(self.features[idx], dtype=torch.float32)
        gsm = torch.tensor(row['gsm'], dtype=torch.float32)

        return image, features, gsm

# Image pipelines. Both resize to 224x224, convert to a tensor and normalise
# with the same per-channel mean/std; only the train pipeline adds random
# flips, rotation and brightness/contrast jitter.
# NOTE(review): the notebook title says "No Augmentation" while the train
# pipeline below applies random transforms; confirm which is intended.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    *_tensor_and_normalize,
])

val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    *_tensor_and_normalize,
])

# Wrap each split (table, scaled features, image folder) in a dataset.
train_dataset = FabricGSMDataset(df_train, X_train_scaled, TRAIN_IMAGES, transform=train_transform)
val_dataset = FabricGSMDataset(df_val, X_val_scaled, VAL_IMAGES, transform=val_test_transform)
test_dataset = FabricGSMDataset(df_test, X_test_scaled, TEST_IMAGES, transform=val_test_transform)

# Loaders share every setting except shuffling, which is enabled for train only.
BATCH_SIZE = 16  # Smaller batch size for smaller dataset
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print("‚úÖ Datasets created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches:   {len(val_loader)}")
print(f"  Test batches:  {len(test_loader)}")

## 7. Hybrid Deep Learning Model Architecture

In [None]:
class HybridGSMPredictor(nn.Module):
    """Hybrid regressor fusing EfficientNet-B3 image features with engineered features.

    Two branches are concatenated and passed to an MLP head that emits one
    GSM value per sample:

    * ``cnn_features``: the torchvision EfficientNet-B3 with its last child
      module removed; its output is flattened per image.
    * ``feature_branch``: an MLP mapping ``num_features`` inputs to 128 values.

    Args:
        num_features: Number of engineered features per sample.
        dropout: Dropout probability after the first layer of each MLP; the
            second layer of each MLP uses ``dropout / 2``.
    """

    def __init__(self, num_features, dropout=0.5):
        super().__init__()

        # Pre-trained EfficientNet-B3 backbone
        efficientnet = models.efficientnet_b3(weights='IMAGENET1K_V1')

        # Freeze everything except the last 30 parameter tensors.
        # NOTE(review): that tail presumably still counts the classifier's own
        # tensors, which are discarded just below; confirm how many backbone
        # tensors actually remain trainable.
        for param in list(efficientnet.parameters())[:-30]:
            param.requires_grad = False

        # Remove classifier head
        self.cnn_features = nn.Sequential(*list(efficientnet.children())[:-1])
        cnn_feature_size = 1536  # EfficientNet-B3 output

        # Feature processing branch
        self.feature_branch = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout/2)
        )

        # Fusion and prediction head
        combined_size = cnn_feature_size + 128
        self.fusion = nn.Sequential(
            nn.Linear(combined_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout/2),
            nn.Linear(256, 1)
        )

    def forward(self, images, features):
        """Predict GSM for a batch.

        Args:
            images: Batch of image tensors fed to the CNN backbone.
            features: Batch of engineered feature vectors.

        Returns:
            1-D tensor with one prediction per sample, including when the
            batch holds a single sample.
        """
        # Extract CNN features and flatten everything after the batch dim.
        cnn_out = torch.flatten(self.cnn_features(images), 1)

        # Process engineered features
        feat_out = self.feature_branch(features)

        # Concatenate and predict
        combined = torch.cat([cnn_out, feat_out], dim=1)
        output = self.fusion(combined)

        # Drop only the trailing size-1 dimension. A bare squeeze() would also
        # collapse a batch of one into a 0-d tensor, which cannot be iterated
        # when predictions are collected with list.extend().
        return output.squeeze(-1)

# Build the hybrid model and move it to the selected device.
model = HybridGSMPredictor(num_features=len(feature_cols), dropout=0.5).to(device)

# Tally all parameters and the subset that will receive gradients.
total_params = 0
trainable_params = 0
for _param in model.parameters():
    total_params += _param.numel()
    if _param.requires_grad:
        trainable_params += _param.numel()

print("=" * 80)
print("üß† MODEL ARCHITECTURE")
print("=" * 80)
print("Backbone: EfficientNet-B3 (ImageNet pretrained)")
print(f"Input features: {len(feature_cols)} fabric-specific features")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print("=" * 80)

## 8. Training Configuration & Loss Functions

In [None]:
# Hyperparameters
EPOCHS, PATIENCE = 150, 20  # epoch cap / early-stopping patience (epochs without a better val MAE)
INITIAL_LR = 1e-3           # starting learning rate handed to the optimizer
WEIGHT_DECAY = 1e-4         # weight decay handed to the optimizer

# Custom loss function
class HuberLoss(nn.Module):
    """Mean Huber loss: quadratic while |error| <= delta, linear beyond it.

    Args:
        delta: Absolute-error threshold at which the penalty turns linear.
    """

    def __init__(self, delta=1.0):
        super().__init__()
        self.delta = delta

    def forward(self, pred, target):
        """Return the mean Huber penalty between ``pred`` and ``target``."""
        abs_error = (pred - target).abs()
        # Split |error| into the part inside the delta band and the overshoot.
        inside = abs_error.clamp(max=self.delta)
        overshoot = abs_error - inside
        return (0.5 * inside.pow(2) + self.delta * overshoot).mean()

# Loss and optimizer
criterion = HuberLoss(delta=5.0)
optimizer = optim.AdamW(model.parameters(), lr=INITIAL_LR, weight_decay=WEIGHT_DECAY)

# Learning rate scheduler: halve the LR after 7 epochs without improvement of
# the monitored metric, never going below 1e-6.
# The `verbose` argument is omitted on purpose: PyTorch deprecated it and
# newer releases may reject it. The learning rate is read back from the
# optimizer and printed every epoch by the training loop instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6
)

print("‚úÖ Training configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Initial LR: {INITIAL_LR}")
print(f"  Loss: Huber (delta=5.0)")
print(f"  Optimizer: AdamW with weight decay")
print(f"  Scheduler: ReduceLROnPlateau")
print(f"  Early stopping patience: {PATIENCE}")

## 9. Training Loop with Early Stopping

In [None]:
def evaluate_model(model, dataloader, criterion, device):
    """Run ``model`` over ``dataloader`` without gradients and score it.

    Args:
        model: Module called as ``model(images, features)``.
        dataloader: Iterable yielding ``(images, features, targets)`` batches.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_batch_loss, mae, rmse, r2, predictions, actuals)`` where
        the last two are 1-D NumPy arrays in dataloader order.

    Raises:
        ValueError: If the dataloader yields no batches.

    Note:
        The model is left in eval mode; callers that continue training must
        call ``model.train()`` again.
    """
    model.eval()
    total_loss = 0.0
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for images, features, targets in dataloader:
            images = images.to(device)
            features = features.to(device)
            targets = targets.to(device)

            # Flatten both sides to 1-D. This keeps a single-sample batch
            # (which a model may return as a 0-d tensor) iterable, and stops
            # the loss from silently broadcasting if the shapes differ.
            outputs = model(images, features).reshape(-1)
            targets = targets.reshape(-1)
            loss = criterion(outputs, targets)

            total_loss += loss.item()
            pred_chunks.append(outputs.cpu().numpy())
            target_chunks.append(targets.cpu().numpy())

    if not pred_chunks:
        raise ValueError("evaluate_model received a dataloader with no batches")

    predictions = np.concatenate(pred_chunks)
    actuals = np.concatenate(target_chunks)

    mae = mean_absolute_error(actuals, predictions)
    rmse = np.sqrt(mean_squared_error(actuals, predictions))
    r2 = r2_score(actuals, predictions)

    # Loss is the mean of per-batch means, matching the training loop.
    return total_loss / len(dataloader), mae, rmse, r2, predictions, actuals

# Per-epoch history: one list per (split, metric) pair plus the learning rate.
history = {
    f'{split}_{metric}': []
    for metric in ('loss', 'mae', 'rmse', 'r2')
    for split in ('train', 'val')
}
history['lr'] = []

# Early-stopping bookkeeping.
best_val_mae = float('inf')
epochs_no_improve = 0
best_model_state = None

print("\n" + "=" * 80)
print("üöÄ TRAINING STARTED")
print("=" * 80)

# Train for up to EPOCHS epochs, validating after each one. The weights with
# the lowest validation MAE are kept, and training stops early once PATIENCE
# epochs pass without improvement.
import copy

for epoch in range(EPOCHS):
    # Training phase
    model.train()
    train_loss = 0
    train_preds = []
    train_actuals = []

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}')
    for images, features, targets in pbar:
        images = images.to(device)
        features = features.to(device)
        targets = targets.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images, features)
        loss = criterion(outputs, targets)

        # Backward pass, with gradient-norm clipping before the update.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        train_loss += loss.item()
        # reshape(-1) keeps a single-sample batch (possibly a 0-d tensor)
        # iterable for list.extend().
        train_preds.extend(outputs.detach().reshape(-1).cpu().numpy())
        train_actuals.extend(targets.reshape(-1).cpu().numpy())

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # Calculate training metrics
    train_preds = np.array(train_preds)
    train_actuals = np.array(train_actuals)
    train_mae = mean_absolute_error(train_actuals, train_preds)
    train_rmse = np.sqrt(mean_squared_error(train_actuals, train_preds))
    train_r2 = r2_score(train_actuals, train_preds)

    # Validation phase
    val_loss, val_mae, val_rmse, val_r2, val_preds, val_actuals = evaluate_model(
        model, val_loader, criterion, device
    )

    # Update learning rate (the scheduler monitors validation MAE).
    scheduler.step(val_mae)
    current_lr = optimizer.param_groups[0]['lr']

    # Save history
    history['train_loss'].append(train_loss / len(train_loader))
    history['val_loss'].append(val_loss)
    history['train_mae'].append(train_mae)
    history['val_mae'].append(val_mae)
    history['train_rmse'].append(train_rmse)
    history['val_rmse'].append(val_rmse)
    history['train_r2'].append(train_r2)
    history['val_r2'].append(val_r2)
    history['lr'].append(current_lr)

    # Print epoch results
    print(f"\nEpoch {epoch+1}/{EPOCHS}:")
    print(f"  Train - Loss: {train_loss/len(train_loader):.4f}, MAE: {train_mae:.3f}, RMSE: {train_rmse:.3f}, R¬≤: {train_r2:.4f}")
    print(f"  Val   - Loss: {val_loss:.4f}, MAE: {val_mae:.3f}, RMSE: {val_rmse:.3f}, R¬≤: {val_r2:.4f}")
    print(f"  LR: {current_lr:.6f}")

    # Early stopping and best model saving
    if val_mae < best_val_mae:
        best_val_mae = val_mae
        epochs_no_improve = 0
        # Deep copy: state_dict().copy() is shallow, so its tensors would keep
        # aliasing the live weights and be overwritten by later epochs.
        best_model_state = copy.deepcopy(model.state_dict())
        print(f"  ‚úÖ New best model! Val MAE: {val_mae:.3f}")
    else:
        epochs_no_improve += 1
        print(f"  ‚è≥ No improvement for {epochs_no_improve} epochs")

    if epochs_no_improve >= PATIENCE:
        print(f"\n‚èπÔ∏è Early stopping triggered after {epoch+1} epochs")
        break

    print("-" * 80)

# Load best model. No snapshot exists when no epoch beat the initial MAE
# (for example when every validation MAE was NaN).
if best_model_state is not None:
    model.load_state_dict(best_model_state)
else:
    print("\nNo epoch improved on the initial validation MAE; keeping the current weights")
print(f"\n‚úÖ Training complete! Best Val MAE: {best_val_mae:.3f}")

## 10. Training History Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# One panel per metric: (axis, history key suffix, legend name, title, y-label).
_panels = [
    (axes[0, 0], 'loss', 'Loss', 'Loss over Epochs', 'Loss'),
    (axes[0, 1], 'mae', 'MAE', 'Mean Absolute Error', 'MAE (GSM)'),
    (axes[1, 0], 'rmse', 'RMSE', 'Root Mean Squared Error', 'RMSE (GSM)'),
    (axes[1, 1], 'r2', 'R¬≤', 'R¬≤ Score', 'R¬≤'),
]
for _ax, _key, _short, _title, _ylabel in _panels:
    _ax.plot(history[f'train_{_key}'], label=f'Train {_short}', linewidth=2)
    _ax.plot(history[f'val_{_key}'], label=f'Val {_short}', linewidth=2)
    if _key == 'mae':
        # Horizontal reference line at the 5 GSM accuracy target.
        _ax.axhline(y=5, color='r', linestyle='--', label='Target: ¬±5 GSM', linewidth=2)
    _ax.set_title(_title, fontsize=14, fontweight='bold')
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

# Write the figure next to the dataset, then display it.
plt.tight_layout()
plt.savefig(f'{DATASET_PATH}/training_history.png', dpi=300, bbox_inches='tight')
plt.show()

print("‚úÖ Training history saved")

## 11. Final Evaluation on Test Set

In [None]:
# Score the model on the held-out test split.
test_loss, test_mae, test_rmse, test_r2, test_preds, test_actuals = evaluate_model(
    model, test_loader, criterion, device
)

_rule = "=" * 80
print(_rule)
print("üìä FINAL TEST SET RESULTS")
print(_rule)
print(f"Test Loss:      {test_loss:.4f}")
print(f"Test MAE:       {test_mae:.3f} GSM")
print(f"Test RMSE:      {test_rmse:.3f} GSM")
print(f"Test R¬≤:        {test_r2:.4f}")
print("\nüéØ Target: ¬±5 GSM prediction error")
print(f"‚úÖ Achieved: ¬±{test_mae:.2f} GSM (MAE)")

# Verdict against the 5 GSM mean-absolute-error target.
if test_mae <= 5.0:
    _verdict = "\nüéâ SUCCESS! Model meets the ¬±5 GSM accuracy target!"
else:
    _verdict = f"\n‚ö†Ô∏è Model is {test_mae - 5:.2f} GSM away from target"
print(_verdict)

print(_rule)

# Signed errors, plus the share of predictions inside each tolerance band (%).
errors = test_preds - test_actuals
_abs_err = np.abs(errors)
within_5 = np.sum(_abs_err <= 5) / len(errors) * 100
within_10 = np.sum(_abs_err <= 10) / len(errors) * 100

print("\nüìà Error Analysis:")
print(f"  Predictions within ¬±5 GSM:  {within_5:.1f}%")
print(f"  Predictions within ¬±10 GSM: {within_10:.1f}%")
print(f"  Max error: {_abs_err.max():.2f} GSM")
print(f"  Min error: {_abs_err.min():.2f} GSM")

## 12. Test Set Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
(_ax_pred, _ax_resid), (_ax_hist, _ax_abs) = axes

# Range of the true GSM values, used for the identity line and tolerance band.
_gsm_lo, _gsm_hi = test_actuals.min(), test_actuals.max()
_gsm_span = [_gsm_lo, _gsm_hi]

# 1. Predicted vs Actual, with the identity line and a band of 5 GSM around it
_ax_pred.scatter(test_actuals, test_preds, alpha=0.6, s=80)
_ax_pred.plot(_gsm_span, _gsm_span, 'r--', linewidth=2, label='Perfect Prediction')
_ax_pred.fill_between(
    _gsm_span,
    [_gsm_lo - 5, _gsm_hi - 5],
    [_gsm_lo + 5, _gsm_hi + 5],
    alpha=0.2, color='green', label='¬±5 GSM',
)
_ax_pred.set_xlabel('Actual GSM (g/m¬≤)', fontsize=12)
_ax_pred.set_ylabel('Predicted GSM (g/m¬≤)', fontsize=12)
_ax_pred.set_title(f'Predicted vs Actual GSM\n(R¬≤ = {test_r2:.4f})', fontsize=14, fontweight='bold')
_ax_pred.legend()
_ax_pred.grid(True, alpha=0.3)

# 2. Residual plot (predicted minus actual) with zero and +/-5 guide lines
residuals = test_preds - test_actuals
_ax_resid.scatter(test_actuals, residuals, alpha=0.6, s=80)
_ax_resid.axhline(y=0, color='r', linestyle='--', linewidth=2)
for _band in (5, -5):
    _ax_resid.axhline(y=_band, color='g', linestyle=':', linewidth=2, alpha=0.5)
_ax_resid.set_xlabel('Actual GSM (g/m¬≤)', fontsize=12)
_ax_resid.set_ylabel('Residual (Predicted - Actual)', fontsize=12)
_ax_resid.set_title('Residual Plot', fontsize=14, fontweight='bold')
_ax_resid.grid(True, alpha=0.3)

# 3. Error distribution with zero and mean-residual markers
_mean_residual = residuals.mean()
_ax_hist.hist(residuals, bins=20, edgecolor='black', alpha=0.7)
_ax_hist.axvline(x=0, color='r', linestyle='--', linewidth=2, label='Zero Error')
_ax_hist.axvline(x=_mean_residual, color='g', linestyle='-', linewidth=2,
                 label=f'Mean: {_mean_residual:.2f}')
_ax_hist.set_xlabel('Prediction Error (GSM)', fontsize=12)
_ax_hist.set_ylabel('Frequency', fontsize=12)
_ax_hist.set_title(f'Error Distribution\n(MAE = {test_mae:.3f} GSM)', fontsize=14, fontweight='bold')
_ax_hist.legend()
_ax_hist.grid(True, alpha=0.3)

# 4. Absolute error vs actual GSM, with the 5 GSM target line
abs_errors = np.abs(residuals)
_ax_abs.scatter(test_actuals, abs_errors, alpha=0.6, s=80)
_ax_abs.axhline(y=5, color='r', linestyle='--', linewidth=2, label='¬±5 GSM Target')
_ax_abs.set_xlabel('Actual GSM (g/m¬≤)', fontsize=12)
_ax_abs.set_ylabel('Absolute Error (GSM)', fontsize=12)
_ax_abs.set_title('Absolute Error vs Actual GSM', fontsize=14, fontweight='bold')
_ax_abs.legend()
_ax_abs.grid(True, alpha=0.3)

# Write the figure next to the dataset, then display it.
plt.tight_layout()
plt.savefig(f'{DATASET_PATH}/test_prediction_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("‚úÖ Test visualization saved")

## 13. Save Model & Results

In [None]:
# Bundle the weights with everything needed to reproduce preprocessing and
# reporting, then write the bundle next to the dataset.
# NOTE(review): the checkpoint embeds the fitted scaler object, so reading it
# back presumably needs full unpickling rather than a weights-only load;
# confirm against the installed PyTorch version.
model_save_path = f'{DATASET_PATH}/best_gsm_model.pth'
_checkpoint = {
    'model_state_dict': model.state_dict(),
    'feature_cols': feature_cols,
    'scaler': scaler,
    'best_val_mae': best_val_mae,
    'test_mae': test_mae,
    'test_rmse': test_rmse,
    'test_r2': test_r2,
    'history': history,
}
torch.save(_checkpoint, model_save_path)

print(f"‚úÖ Model saved to: {model_save_path}")

# Per-sample test predictions with signed and absolute errors.
_signed_error = test_preds - test_actuals
results_df = df_test.assign(
    predicted_gsm=test_preds,
    error=_signed_error,
    abs_error=np.abs(_signed_error),
)
results_df.to_csv(f'{DATASET_PATH}/test_predictions.csv', index=False)

print("‚úÖ Predictions saved")

# Save metrics summary
import json

# Every value is converted to a built-in Python type because json cannot
# serialise NumPy scalars. In particular `test_mae <= 5.0` can be a
# numpy.bool_ when the metric is a NumPy float, which makes json.dump raise
# TypeError unless it is wrapped in bool().
metrics_summary = {
    'model': 'HybridGSMPredictor (EfficientNet-B3)',
    'total_params': int(total_params),
    'trainable_params': int(trainable_params),
    'num_features': len(feature_cols),
    'train_samples': len(df_train),
    'val_samples': len(df_val),
    'test_samples': len(df_test),
    'best_val_mae': float(best_val_mae),
    'test_mae': float(test_mae),
    'test_rmse': float(test_rmse),
    'test_r2': float(test_r2),
    'predictions_within_5gsm': float(within_5),
    'predictions_within_10gsm': float(within_10),
    'target_achieved': bool(test_mae <= 5.0)
}

with open(f'{DATASET_PATH}/model_metrics.json', 'w') as f:
    json.dump(metrics_summary, f, indent=2)

print(f"‚úÖ Metrics saved")

print("\n" + "="*80)
print("üéä ALL RESULTS SAVED!")
print("="*80)

## 14. Final Summary

In [None]:
# Console recap of the architecture, test metrics and dataset sizes.
_rule = "=" * 80
_split_sizes = (len(df_train), len(df_val), len(df_test))

print("\n" + _rule)
print("üìä FINAL MODEL SUMMARY")
print(_rule)
print("\nüß† Model Architecture:")
print("  - Backbone: EfficientNet-B3 (ImageNet pretrained)")
print(f"  - Input: 224x224 RGB images + {len(feature_cols)} fabric features")
print(f"  - Total parameters: {total_params:,}")
print(f"  - Trainable parameters: {trainable_params:,}")

print("\nüìà Performance Metrics:")
print(f"  - Test MAE:  {test_mae:.3f} GSM")
print(f"  - Test RMSE: {test_rmse:.3f} GSM")
print(f"  - Test R¬≤:   {test_r2:.4f}")
print(f"  - Within ¬±5 GSM:  {within_5:.1f}%")
print(f"  - Within ¬±10 GSM: {within_10:.1f}%")

print("\nüìÅ Dataset:")
print(f"  - Total samples: {sum(_split_sizes)}")
print(f"  - Train: {_split_sizes[0]} | Val: {_split_sizes[1]} | Test: {_split_sizes[2]}")
# NOTE(review): presumably this refers to the dataset holding only original
# images; the train transform still applies random flips, rotation and colour
# jitter at load time, so confirm the wording.
print("  - No augmentation used")

if test_mae <= 5.0:
    print(f"\nüéâ SUCCESS! Model achieves ¬±{test_mae:.2f} GSM accuracy")
else:
    print(f"\n‚ö†Ô∏è Model is {test_mae - 5:.2f} GSM away from ¬±5 GSM target")
    print("\nüí° Recommendations:")
    for _tip in (
        "Collect more training samples",
        "Try data augmentation",
        "Experiment with different architectures",
        "Fine-tune hyperparameters",
    ):
        print(f"  - {_tip}")

print("\n" + _rule)
print("üèÅ TRAINING COMPLETE")
print(_rule)