## GSM Fine-Tuning with Augmented Features Dataset

This notebook fine-tunes a pre-trained GSM regressor using the augmented features dataset with:
- **Small learning rate** to preserve previously learned features
- **All extracted features** (weft/warp counts, texture, color, etc.)
- **Error tolerance of ±5 GSM** based on weave pattern closeness
- **Transfer learning** from pre-trained model mounted in Google Drive

### Quick Start:
1. Runtime → Change runtime type → GPU (T4 preferred)
2. Section 1: Mount Google Drive and configure paths
3. Section 2: Load pre-trained model
4. Section 3: Prepare augmented features dataset
5. Section 4-6: Fine-tune model with small learning rate
6. Section 7: Evaluate and save results

# Section 1: Setup and Mount Google Drive

In [None]:
# Check GPU and Set Reproducibility
import os
import random
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

def set_seed(seed: int = 42):
    """Seed every RNG the notebook uses and put cuDNN into deterministic mode.

    Args:
        seed: Value applied to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and possibly select different kernels from run
    # to run, which works against the deterministic flag above, so it is switched off.
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    print(f"‚úÖ CUDA available: {gpu_name}")
else:
    print("‚ö†Ô∏è  CUDA not available. Running on CPU.")

print(f"PyTorch version: {torch.__version__}")

In [None]:
# Mount Google Drive when running in Colab; any failure (including the import
# itself outside Colab) switches the notebook to local paths.
try:
    from google.colab import drive
    drive.mount("/content/drive")
except Exception as e:
    IN_COLAB = False
    print(f"‚ö†Ô∏è  Not in Colab environment: {e}")
    print("Using local paths instead")
else:
    IN_COLAB = True
    print("‚úÖ Google Drive mounted")

In [None]:
# ============================================
# Configuration and Paths
# ============================================

# All environment-dependent locations are resolved in one branch.
if IN_COLAB:
    # FOR COLAB: Set these paths to your Google Drive locations
    DRIVE_ROOT = "/content/drive/MyDrive"
    PRETRAINED_MODEL_PATH = os.path.join(DRIVE_ROOT, "path_to_pretrained_model.pt")  # TODO: update this
    OUTPUT_DIR = os.path.join(DRIVE_ROOT, "GSM_Finetuned_Model")
    FEATURES_DATASET_PATH = os.path.join(DRIVE_ROOT, "augmented_features_dataset")
else:
    # LOCAL: Update these to your local paths
    # NOTE(review): machine-specific absolute path; adjust before running elsewhere.
    LOCAL_ROOT = r"c:\Users\I769816\Desktop\GSM_fabric\fabric_gsm_pipeline"
    PRETRAINED_MODEL_PATH = os.path.join(LOCAL_ROOT, "Model", "best_model (1).pt")
    OUTPUT_DIR = os.path.join(LOCAL_ROOT, "train", "finetuned_model")
    FEATURES_DATASET_PATH = os.path.join(LOCAL_ROOT, "augmented_features_dataset")

# The output directory is created up front so later cells can write into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# CSV files and image folder inside the features dataset directory
TRAIN_CSV = os.path.join(FEATURES_DATASET_PATH, "dataset_train.csv")
VAL_CSV = os.path.join(FEATURES_DATASET_PATH, "dataset_val.csv")
TEST_CSV = os.path.join(FEATURES_DATASET_PATH, "dataset_test.csv")
IMAGES_DIR = os.path.join(FEATURES_DATASET_PATH, "images")

print("üìÅ Configuration:")
for _label, _value in (
    ("Pretrained Model", PRETRAINED_MODEL_PATH),
    ("Train CSV", TRAIN_CSV),
    ("Output Dir", OUTPUT_DIR),
    ("Images Dir", IMAGES_DIR),
):
    print(f"  {_label}: {_value}")

# Verify paths exist (IMAGES_DIR is printed above but not checked here)
for _label, _value in (
    ("Model", PRETRAINED_MODEL_PATH),
    ("Train CSV", TRAIN_CSV),
    ("Val CSV", VAL_CSV),
    ("Test CSV", TEST_CSV),
):
    assert os.path.exists(_value), f"{_label} not found: {_value}"
print("‚úÖ All paths verified")

# Section 2: Load Pre-trained Model

In [None]:
# Define CNN Regressor Architecture (same as original)
import timm

class Regressor(nn.Module):
    """Scalar regressor: a timm backbone followed by a dropout + linear head.

    NOTE(review): later cells call this model on the scaled feature vectors
    produced by ``FeaturesDataset``; presumably the timm backbone expects image
    batches instead — confirm the intended model input.

    Args:
        backbone_name: Model name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
    """

    def __init__(self, backbone_name='efficientnet_b3', pretrained=True):
        super().__init__()
        backbone_kwargs = {
            'pretrained': pretrained,
            'num_classes': 0,
            'global_pool': 'avg',
        }
        self.backbone = timm.create_model(backbone_name, **backbone_kwargs)

        # The head maps the backbone's feature width to a single output value.
        feature_dim = self.backbone.num_features
        regression_layers = [nn.Dropout(0.2), nn.Linear(feature_dim, 1)]
        self.head = nn.Sequential(*regression_layers)

    def forward(self, x):
        """Return one prediction per sample (dimension 1 of the head output is squeezed)."""
        pooled = self.backbone(x)
        return self.head(pooled).squeeze(1)

# Load the pre-trained model
print("Loading pre-trained model...")
# pretrained=False: the weights come from the checkpoint loaded below.
model = Regressor(backbone_name='efficientnet_b3', pretrained=False)
model = model.to(device)

checkpoint = torch.load(PRETRAINED_MODEL_PATH, map_location=device)

# Two layouts are accepted: a dict holding the weights under 'model_state'
# (with optional 'epoch' / 'val_mae' metadata), or a bare state_dict.
if not (isinstance(checkpoint, dict) and 'model_state' in checkpoint):
    model.load_state_dict(checkpoint)
    print("‚úÖ Loaded model weights")
else:
    model.load_state_dict(checkpoint['model_state'])
    print(f"‚úÖ Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
    print(f"   Previous Val MAE: {checkpoint.get('val_mae', 'N/A')}")

total_params = sum(p.numel() for p in model.parameters())
print(f"üìä Model: {model.__class__.__name__}")
print(f"   Parameters: {total_params:,}")

# Section 3: Prepare Augmented Features Dataset

In [None]:
# Load the three dataset splits from their CSV files
print("Loading dataset CSVs...")
train_df, val_df, test_df = (pd.read_csv(csv_path) for csv_path in (TRAIN_CSV, VAL_CSV, TEST_CSV))

for _split_label, _split_df in (("Train:", train_df), ("Val:  ", val_df), ("Test: ", test_df)):
    print(f"‚úÖ {_split_label} {len(_split_df)} samples")

# Column overview of the training split
train_columns = train_df.columns
print(f"\nüìä Dataset Columns ({len(train_columns)}):")
print(train_columns.tolist())

# Summary statistics of the regression target
print(f"\nüìà GSM Statistics:")
print(train_df['gsm'].describe())

In [None]:
# Every column except the identifier / label / bookkeeping ones is treated as a feature
exclude_cols = {'image_name', 'gsm', 'source', 'augmentation', 'original_image', 'split'}
feature_cols = list(filter(lambda name: name not in exclude_cols, train_df.columns))

print(f"üìä Feature Columns ({len(feature_cols)}):")
for position, column_name in enumerate(feature_cols, start=1):
    print(f"  {position:2d}. {column_name}")

# Report (but do not fix) missing values in the training features
print(f"\n‚ö†Ô∏è  Missing values in train:")
missing = train_df[feature_cols].isna().sum()
columns_with_gaps = missing[missing > 0]
if columns_with_gaps.empty:
    print("  None!")
else:
    print(columns_with_gaps)

In [None]:
# Standardize the features: statistics are fitted on the training split only
# and then applied unchanged to the validation and test splits.
import pickle
from sklearn.preprocessing import StandardScaler

print("Normalizing features...")
scaler = StandardScaler()
scaler.fit(train_df[feature_cols])
train_features = scaler.transform(train_df[feature_cols])
val_features = scaler.transform(val_df[feature_cols])
test_features = scaler.transform(test_df[feature_cols])

# Persist the fitted scaler next to the other outputs for later use
scaler_path = os.path.join(OUTPUT_DIR, 'feature_scaler.pkl')
with open(scaler_path, 'wb') as scaler_file:
    pickle.dump(scaler, scaler_file)
print(f"‚úÖ Scaler saved to {scaler_path}")

print(f"\nüìä Normalized feature shapes:")
for _split_label, _split_features in (
    ("Train:", train_features),
    ("Val:  ", val_features),
    ("Test: ", test_features),
):
    print(f"  {_split_label} {_split_features.shape}")

# Section 4: Create Feature Dataset Classes

In [None]:
# Feature-based Dataset (no images, pure features)
class FeaturesDataset(Dataset):
    """In-memory dataset yielding ``(features, target, name)`` triples.

    Args:
        features: array-like of shape (N, num_features); stored as float32.
        targets: array-like of shape (N,) with GSM values; stored as float32.
        names: optional sequence of image names; when omitted every sample
            is named ``'unknown'``.
    """

    def __init__(self, features, targets, names=None):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)
        if names is None:
            names = ['unknown'] * len(features)
        self.names = names

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        sample = (self.features[idx], self.targets[idx], self.names[idx])
        return sample

# Build one dataset per split: scaled features, 'gsm' targets, image names
train_dataset = FeaturesDataset(
    features=train_features,
    targets=train_df['gsm'].values,
    names=train_df['image_name'].tolist(),
)
val_dataset = FeaturesDataset(
    features=val_features,
    targets=val_df['gsm'].values,
    names=val_df['image_name'].tolist(),
)
test_dataset = FeaturesDataset(
    features=test_features,
    targets=test_df['gsm'].values,
    names=test_df['image_name'].tolist(),
)

print(f"‚úÖ Datasets created")
for _split_label, _split_dataset in (
    ("Train:", train_dataset),
    ("Val:  ", val_dataset),
    ("Test: ", test_dataset),
):
    print(f"  {_split_label} {len(_split_dataset)} samples")

# Quick sanity check on the first training sample
feat_sample, gsm_sample, name_sample = train_dataset[0]
print(f"\nüìä Sample feature shape: {feat_sample.shape}, GSM: {gsm_sample.item():.2f}, Name: {name_sample}")

In [None]:
# Create DataLoaders
BATCH_SIZE = 32  # Can be larger with features-only
NUM_WORKERS = 0  # Features don't need multiprocessing

# Settings shared by all three loaders; only the training loader shuffles.
_loader_kwargs = {'batch_size': BATCH_SIZE, 'num_workers': NUM_WORKERS}

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print(f"‚úÖ DataLoaders created")
for _split_label, _split_loader in (
    ("Train batches:", train_loader),
    ("Val batches:  ", val_loader),
    ("Test batches: ", test_loader),
):
    print(f"  {_split_label} {len(_split_loader)}")

# Section 5: Fine-tuning Configuration (Small Learning Rate)

In [None]:
print("‚öôÔ∏è  Fine-tuning Configuration:")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Weight Decay: {WEIGHT_DECAY}")
print(f"  Epochs: {EPOCHS}")
print(f"  Early Stopping: DISABLED (will train full {EPOCHS} epochs)")
print(f"  Best Model: Will be saved automatically")
print(f"  Error Tolerance: ¬±{ERROR_TOLERANCE} GSM")

# Section 6: Fine-tuning Training Loop

In [None]:
# ============================================
# Training Loop
# ============================================
# Each epoch: set the learning rate (linear warmup first), train on train_loader,
# evaluate on val_loader, record metrics in `history`, and save a checkpoint whenever
# validation MAE improves. There is no early stopping.

best_val_mae = float('inf')
history = defaultdict(list)

best_ckpt_path = os.path.join(OUTPUT_DIR, 'best_finetuned_model.pt')

print(f"üöÄ Starting fine-tuning ({EPOCHS} epochs, no early stopping)...\n")

for epoch in range(1, EPOCHS + 1):
    # -------- Learning Rate Warmup --------
    # The warmup LR is applied BEFORE the epoch is trained, so epoch 1 runs at
    # LEARNING_RATE / WARMUP_EPOCHS instead of at the optimizer's initial LR.
    if epoch <= WARMUP_EPOCHS:
        warmup_lr = LEARNING_RATE * epoch / max(1, WARMUP_EPOCHS)
        for g in optimizer.param_groups:
            g['lr'] = warmup_lr

    # LR actually used for this epoch (this is the value that gets logged).
    lr = optimizer.param_groups[0]['lr']

    # -------- Training --------
    model.train()
    train_losses = []

    for features, targets, _ in tqdm(train_loader, desc=f"Epoch {epoch:3d}/{EPOCHS} [Train]", leave=False):
        features = features.to(device)
        targets = targets.to(device)

        optimizer.zero_grad(set_to_none=True)

        # Forward pass
        # NOTE(review): `features` are the scaled feature vectors from FeaturesDataset,
        # while Regressor.forward hands its input to a timm backbone; presumably that
        # backbone expects image batches — confirm the intended model input.
        preds = model(features)
        loss = criterion(preds, targets)

        # Backward pass (gradient norm is clipped to 1.0 before the update)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        train_losses.append(loss.item())

    # -------- Validation --------
    model.eval()
    val_preds = []
    val_trues = []

    with torch.no_grad():
        for features, targets, _ in val_loader:
            features = features.to(device)
            targets = targets.to(device)

            preds = model(features)
            val_preds.append(preds.cpu())
            val_trues.append(targets.cpu())

    val_preds = torch.cat(val_preds)
    val_trues = torch.cat(val_trues)

    # Metrics
    train_loss = float(np.mean(train_losses))
    val_mae = float(torch.mean(torch.abs(val_preds - val_trues)))
    val_rmse = float(torch.sqrt(torch.mean((val_preds - val_trues) ** 2)))

    history['train_loss'].append(train_loss)
    history['val_mae'].append(val_mae)
    history['val_rmse'].append(val_rmse)

    # -------- Scheduler --------
    # From the last warmup epoch onwards the scheduler sets the LR for the NEXT epoch.
    if epoch >= WARMUP_EPOCHS:
        scheduler.step()

    # -------- Logging --------
    print(
        f"Epoch {epoch:3d}/{EPOCHS} | "
        f"Train Loss: {train_loss:.6f} | "
        f"Val MAE: {val_mae:.4f} | "
        f"Val RMSE: {val_rmse:.4f} | "
        f"LR: {lr:.2e}"
    )

    # -------- Save best checkpoint (No early stopping) --------
    # An improvement must exceed 1e-4 to count, which avoids re-saving on noise.
    if val_mae < best_val_mae - 1e-4:
        best_val_mae = val_mae
        torch.save(
            {
                'model_state': model.state_dict(),
                'epoch': epoch,
                'val_mae': best_val_mae,
                'val_rmse': val_rmse,
                'train_loss': train_loss
            },
            best_ckpt_path
        )
        print(f"  ‚úÖ Saved best model (Val MAE = {best_val_mae:.4f})")

print(f"\n‚úÖ Training complete!")
print(f"   Best Val MAE achieved: {best_val_mae:.4f}")
print(f"   Total epochs trained: {EPOCHS}")
print(f"   Model saved to: {best_ckpt_path}")

In [None]:
# Plot training curves: training loss on the left, validation errors on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
loss_ax, metric_ax = axes[0], axes[1]

loss_ax.plot(history['train_loss'], label='Train Loss', linewidth=2)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss')
loss_ax.grid(True, alpha=0.3)
loss_ax.legend()

for _history_key, _curve_label, _curve_color in (
    ('val_mae', 'Val MAE', 'orange'),
    ('val_rmse', 'Val RMSE', 'red'),
):
    metric_ax.plot(history[_history_key], label=_curve_label, linewidth=2, color=_curve_color)
# Horizontal reference line at the accepted error tolerance
metric_ax.axhline(y=ERROR_TOLERANCE, color='green', linestyle='--', label=f'Error Tolerance (¬±{ERROR_TOLERANCE})')
metric_ax.set_xlabel('Epoch')
metric_ax.set_ylabel('Error')
metric_ax.set_title('Validation Metrics')
metric_ax.grid(True, alpha=0.3)
metric_ax.legend()

plt.tight_layout()
curves_path = os.path.join(OUTPUT_DIR, 'training_curves.png')
plt.savefig(curves_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úÖ Training curves saved")

# Section 7: Evaluation on Test Set

In [None]:
# Restore the checkpoint written by the training loop at its best validation MAE
print("Loading best fine-tuned model...")
checkpoint = torch.load(best_ckpt_path, map_location=device)
best_weights = checkpoint['model_state']
model.load_state_dict(best_weights)
print(f"‚úÖ Loaded best model from epoch {checkpoint['epoch']}")
for _ckpt_key, _metric_label in (('val_mae', 'MAE'), ('val_rmse', 'RMSE')):
    print(f"   Val {_metric_label}: {checkpoint[_ckpt_key]:.4f}")

In [None]:
# Evaluate on test set: collect predictions, then report MAE / RMSE / R2 and the
# share of samples whose absolute error is within ERROR_TOLERANCE.
print("Evaluating on test set...\n")
model.eval()

test_preds = []
test_trues = []
test_names = []

with torch.no_grad():
    for features, targets, names in tqdm(test_loader, desc="Evaluating"):
        features = features.to(device)

        preds = model(features)
        test_preds.append(preds.cpu())
        # Targets are only compared on the CPU, so they are not moved to the device.
        test_trues.append(targets.cpu())
        test_names.extend(names)

test_preds = torch.cat(test_preds).numpy()
test_trues = torch.cat(test_trues).numpy()

# Metrics
test_mae = mean_absolute_error(test_trues, test_preds)
test_rmse = np.sqrt(mean_squared_error(test_trues, test_preds))
test_r2 = r2_score(test_trues, test_preds)

# Error tolerance analysis
test_errors = np.abs(test_preds - test_trues)
within_mask = test_errors <= ERROR_TOLERANCE
within_tolerance = within_mask.mean() * 100
# Count straight from the mask: rebuilding the count from the float percentage and
# truncating with int() can come out one lower than the true number.
within_count = int(within_mask.sum())

print("\nüìä Test Set Metrics:")
print(f"  MAE:  {test_mae:.4f} (Mean Absolute Error)")
print(f"  RMSE: {test_rmse:.4f} (Root Mean Squared Error)")
print(f"  R¬≤:   {test_r2:.4f}")
print(f"\n‚úÖ Within ¬±{ERROR_TOLERANCE} tolerance: {within_tolerance:.1f}% ({within_count} / {len(test_trues)})")

In [None]:
# Error distribution by weave pattern (if available in data)
# Attach per-sample prediction, absolute error and a 0/1 tolerance flag to a copy of the test frame.
test_df_copy = test_df.assign(
    pred=test_preds,
    error=test_errors,
    within_tolerance=(test_errors <= ERROR_TOLERANCE).astype(int),
)

print("\nüìà Error Statistics by Source (Weave Pattern):")
if 'source' in test_df_copy.columns:
    for source in test_df_copy['source'].unique():
        source_rows = test_df_copy.loc[test_df_copy['source'] == source]
        src_errors = source_rows['error'].values
        src_tolerance = source_rows['within_tolerance'].mean() * 100
        print(f"  {source:15s}: MAE={src_errors.mean():.4f}, Within tolerance={src_tolerance:.1f}%")

# Save detailed predictions
output_csv = os.path.join(OUTPUT_DIR, 'test_predictions.csv')
test_df_copy.to_csv(output_csv, index=False)
print(f"\n‚úÖ Predictions saved to {output_csv}")

In [None]:
# Visualization: 2x2 grid of diagnostic plots for the test-set predictions
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
ax_scatter, ax_hist = axes[0, 0], axes[0, 1]
ax_resid, ax_cdf = axes[1, 0], axes[1, 1]

# 1. Predictions vs Ground Truth, with the identity line and a tolerance band around it
ax_scatter.scatter(test_trues, test_preds, alpha=0.6, s=50)
min_val = min(test_trues.min(), test_preds.min())
max_val = max(test_trues.max(), test_preds.max())
diagonal = [min_val, max_val]
ax_scatter.plot(diagonal, diagonal, 'r--', linewidth=2, label='Perfect Prediction')
ax_scatter.fill_between(
    diagonal,
    [min_val - ERROR_TOLERANCE, max_val - ERROR_TOLERANCE],
    [min_val + ERROR_TOLERANCE, max_val + ERROR_TOLERANCE],
    alpha=0.2, color='green', label=f'¬±{ERROR_TOLERANCE} Tolerance'
)
ax_scatter.set_xlabel('Ground Truth GSM')
ax_scatter.set_ylabel('Predicted GSM')
ax_scatter.set_title(f'Predictions vs Ground Truth (R¬≤={test_r2:.4f})')
ax_scatter.legend()
ax_scatter.grid(True, alpha=0.3)

# 2. Histogram of absolute errors, with the tolerance and the MAE marked
ax_hist.hist(test_errors, bins=30, alpha=0.7, color='blue', edgecolor='black')
ax_hist.axvline(ERROR_TOLERANCE, color='green', linestyle='--', linewidth=2, label=f'Tolerance={ERROR_TOLERANCE}')
ax_hist.axvline(test_mae, color='red', linestyle='--', linewidth=2, label=f'MAE={test_mae:.4f}')
ax_hist.set_xlabel('Absolute Error (GSM)')
ax_hist.set_ylabel('Frequency')
ax_hist.set_title('Error Distribution')
ax_hist.legend()
ax_hist.grid(True, alpha=0.3)

# 3. Residuals (true minus predicted) against the prediction
residuals = test_trues - test_preds
pred_span = [test_preds.min(), test_preds.max()]
ax_resid.scatter(test_preds, residuals, alpha=0.6, s=50)
ax_resid.axhline(y=0, color='r', linestyle='--', linewidth=2)
ax_resid.fill_between(pred_span, -ERROR_TOLERANCE, ERROR_TOLERANCE, alpha=0.2, color='green')
ax_resid.set_xlabel('Predicted GSM')
ax_resid.set_ylabel('Residuals (True - Pred)')
ax_resid.set_title('Residual Plot')
ax_resid.grid(True, alpha=0.3)

# 4. Cumulative share of samples at or below each absolute error
sorted_errors = np.sort(test_errors)
n_errors = len(sorted_errors)
cumulative = np.arange(1, n_errors + 1) / n_errors * 100
ax_cdf.plot(sorted_errors, cumulative, linewidth=2)
ax_cdf.axvline(ERROR_TOLERANCE, color='green', linestyle='--', linewidth=2, label=f'¬±{ERROR_TOLERANCE} GSM')
ax_cdf.set_xlabel('Absolute Error (GSM)')
ax_cdf.set_ylabel('Cumulative %')
ax_cdf.set_title('Cumulative Error Distribution')
ax_cdf.legend()
ax_cdf.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'evaluation_results.png'), dpi=150, bbox_inches='tight')
plt.show()

print("‚úÖ Evaluation plots saved")

# Section 8: Export and Summary

In [None]:
# Export final model and metrics
import json

# Weights of the model currently in memory, saved as a bare state_dict
final_model_path = os.path.join(OUTPUT_DIR, 'gsm_regressor_finetuned.pt')
torch.save(model.state_dict(), final_model_path)
print(f"‚úÖ Model saved to {final_model_path}")

# Metrics summary, assembled section by section
train_section = {
    'samples': len(train_df),
    'final_loss': float(history['train_loss'][-1]),
}
val_section = {
    'samples': len(val_df),
    'mae': float(history['val_mae'][-1]),  # last epoch, not necessarily the best
    'best_mae': best_val_mae,
}
test_section = {
    'samples': len(test_df),
    'mae': float(test_mae),
    'rmse': float(test_rmse),
    'r2': float(test_r2),
    'within_tolerance': float(within_tolerance),
}
config_section = {
    'learning_rate': LEARNING_RATE,
    'epochs': len(history['train_loss']),
    'batch_size': BATCH_SIZE,
    'error_tolerance': ERROR_TOLERANCE,
    'num_features': len(feature_cols),
}
metrics_summary = {
    'train': train_section,
    'val': val_section,
    'test': test_section,
    'config': config_section,
}

summary_path = os.path.join(OUTPUT_DIR, 'metrics_summary.json')
with open(summary_path, 'w') as summary_file:
    json.dump(metrics_summary, summary_file, indent=2)
print(f"‚úÖ Metrics saved to {summary_path}")

In [None]:
# Final human-readable recap of the dataset, configuration, results and output files.
# Line breaks use "\n"; the previous "\\n" printed a literal backslash-n.
print("\n" + "="*70)
print("üéØ FINE-TUNING SUMMARY")
print("="*70)
print(f"\nüìä Dataset:")
# The feature count comes from feature_cols rather than assuming exactly six
# non-feature columns are present in the CSV.
print(f"  Train samples:  {len(train_df)} (with {len(feature_cols)} extracted features)")
print(f"  Val samples:    {len(val_df)}")
print(f"  Test samples:   {len(test_df)}")

print(f"\n‚öôÔ∏è  Configuration:")
print(f"  Learning Rate:  {LEARNING_RATE} (Small for transfer learning)")
print(f"  Total Epochs:   {EPOCHS} (No early stopping)")
print(f"  Batch Size:     {BATCH_SIZE}")
print(f"  Error tolerance: ¬±{ERROR_TOLERANCE} GSM")

print(f"\nüìà Results:")
print(f"  Best Val MAE:   {best_val_mae:.4f}")
print(f"  Test MAE:       {test_mae:.4f}")
print(f"  Test RMSE:      {test_rmse:.4f}")
print(f"  Test R¬≤:        {test_r2:.4f}")
print(f"  ‚úÖ Within tolerance: {within_tolerance:.1f}%")

print(f"\nüíæ Output files:")
print(f"  Best model:     {best_ckpt_path}")
print(f"  Final model:    {final_model_path}")
print(f"  Predictions:    {output_csv}")
print(f"  Metrics:        {summary_path}")
print(f"  Scaler:         {scaler_path}")
print(f"  Plots:          {os.path.join(OUTPUT_DIR, 'training_curves.png')}")
print(f"                  {os.path.join(OUTPUT_DIR, 'evaluation_results.png')}")
# The epoch count is taken from EPOCHS instead of a hard-coded 200.
print(f"\n‚úÖ Fine-tuning complete! ({EPOCHS} epochs trained, best model saved)")
print("="*70)