In [None]:
import torch

# Query the MPS (Apple-silicon GPU) backend once and reuse the answers below.
mps_available = torch.backends.mps.is_available()
mps_built = torch.backends.mps.is_built()

# Report whether MPS can be used and whether this PyTorch build includes it.
print("MPS available:", mps_available)
print("MPS built:", mps_built)

# Select MPS when it is usable; otherwise fall back to the CPU.
device = torch.device("mps") if mps_available else torch.device("cpu")
print("Current device:", device)

<h4> Import all modules </h4>

In [None]:
from data_loader import load_images, split_data
from preprocessing import preprocess_all_images,expand_channels_for_split,get_augmentation_transform,show_random_clahe_images_per_label,preprocess_image_cv2
from dataset import MergedImagesDataset
from model import get_resnet18_model,get_resnet18_model_layer_added, get_mobilenetv2_model,get_mobilenetv2_model_layer_added,get_efficientnetb0_model_layer_added,get_efficientnetv2_s_model_layer_added
from train import train_model_with_val
from evaluate import evaluate_model
from utils import EarlyStoppingWithLR, save_best_model_state
from grad_cam import show_grad_cam_for_random_images_per_label
import mlflow
import torch
from mlflow_log import log_after_evaluation, log_after_training, log_before_training


In [None]:
# MLflow destination for this notebook: the tracking server URI and the
# experiment under which runs are recorded.
MLFLOW_TRACKING_URI = "http://127.0.0.1:8080"
MLFLOW_EXPERIMENT_NAME = "Blade_Surface_Defect_Detection"

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

<h4> 1. Load images and labels </h4>

In [None]:
# Load the images and their labels through the project's data_loader helper.
images, labels = load_images()
# Optional: uncomment the next line to collapse the labels into a binary Good/Defective task.
#labels = ['Good' if l == 'Good' else 'Defective' for l in labels] ## Binary mapping


<h4> 2. Preprocess images </h4>

In [None]:
# Preprocessing settings (presumably CLAHE clip limit and tile grid size — confirm
# in preprocessing.preprocess_all_images). The same two values are logged to MLflow
# before training, so change them here rather than at the call sites.
clip_limit = 1.5
tile_grid_size = (3,3)

# Run the preprocessing helper over all loaded images with the settings above.
clahe_images = preprocess_all_images(images,clip_limit,tile_grid_size)

In [None]:
# Visual sanity check of the preprocessing output: 2 images per label
# (chosen at random, as the helper's name suggests).
show_random_clahe_images_per_label(clahe_images, labels, n_per_label=2)

<h4> 3. Split data </h4>

In [None]:
# Partition the preprocessed images and labels into train / validation / test subsets.
# list(...) hands split_data a plain list built from clahe_images.
# NOTE(review): split ratios, stratification and seeding are decided inside split_data — confirm there.
x_train, x_val, x_test, y_train, y_val, y_test = split_data(list(clahe_images),labels)

<h4> 4. Expand channels for compatibility with pretrained models </h4>

In [None]:
# Apply the channel-expansion helper to all three splits so the images match what the
# pretrained backbones accept (presumably single-channel -> 3 channels; confirm in
# preprocessing.expand_channels_for_split).
x_train_exp, x_val_exp, x_test_exp = expand_channels_for_split(x_train, x_val, x_test)

<h4> 5. Prepare label mapping </h4>

In [None]:
# Deterministic label -> class-index mapping: distinct labels are sorted, then
# numbered 0..N-1 in that order.
unique_labels = sorted(set(labels))
label_to_idx = dict(zip(unique_labels, range(len(unique_labels))))

<h4> 6. Data augmentation </h4>

In [None]:
# Build the augmentation transform; further down it is passed only to the training dataset.
transform = get_augmentation_transform()

<h4> Parameters </h4>

In [None]:
# Training hyperparameters; all three are recorded in MLflow before training starts.
num_epochs = 50      # epoch count handed to the training loop
lr_rate = 0.0001     # learning rate handed to the optimizer
batch_size = 32      # mini-batch size for this run

<h4> 7. Create datasets and dataloaders </h4>

In [None]:
# Wrap each split in the project dataset class. Only the training split receives the
# augmentation transform; validation and test data are left un-augmented.
train_dataset = MergedImagesDataset(x_train_exp, y_train, label_to_idx, transform=transform)
val_dataset = MergedImagesDataset(x_val_exp, y_val, label_to_idx)
test_dataset = MergedImagesDataset(x_test_exp, y_test, label_to_idx)

# Use the shared `batch_size` hyperparameter instead of a hard-coded 32, so the value
# logged to MLflow is the value the loaders actually use.
# Only the training loader shuffles; validation and test keep a fixed order so that
# repeated evaluation passes see the same batches.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
)


<h4> 8. Model, loss, optimizer, scheduler </h4>

In [None]:
# Prefer the MPS backend but fall back to the CPU when it is unavailable, instead of
# hard-coding 'mps' (which fails on machines without MPS support).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# One output class per distinct label.
num_classes = len(unique_labels)

# Build the network with nothing frozen (freeze=False) and move it to the device.
model = get_efficientnetv2_s_model_layer_added(num_classes, freeze=False).to(device)
criterion = torch.nn.CrossEntropyLoss()

# Pass the learning rate by keyword so the call does not depend on argument position.
optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate, weight_decay=1e-5)

# Early stopping helper that is also given the optimizer (the name and the
# lr_patience/factor arguments suggest it lowers the learning rate on plateaus —
# confirm in utils.EarlyStoppingWithLR).
early_stopper = EarlyStoppingWithLR(optimizer, patience=4, lr_patience=2, factor=0.7)


<h4> 9. Train and Evaluate </h4>


In [None]:
# One MLflow run covering parameter logging, training, checkpointing and test-set evaluation.
with mlflow.start_run(run_name="Blade Surface Defect Detection_clahe_test5") as run:

    # Record the hyperparameters, loss, model and preprocessing settings before training.
    log_before_training(num_epochs, lr_rate, batch_size, criterion, model, clip_limit, tile_grid_size)

    # Train with validation; the early stopper is handed to the training loop.
    loss_history, acc_history, val_loss_history, val_acc_history = train_model_with_val(
        model, train_loader, val_loader, criterion, optimizer, device, num_epochs, early_stopper=early_stopper
    )
    # Persist the best state held by the early stopper.
    # NOTE(review): assumes the models/ directory exists or that the helper creates it — confirm.
    save_best_model_state(early_stopper.best_state, "models/best_model_4.pth")
    # NOTE(review): `model` is logged and evaluated as training left it. Unless
    # train_model_with_val restores early_stopper.best_state into `model`, the metrics
    # below may not correspond to the checkpoint saved above — confirm.
    log_after_training(model, test_loader, device, run.info.run_id)
    report = evaluate_model(model, test_loader, device, unique_labels)
    log_after_evaluation(report, unique_labels)


<h4> 10. Grad-CAM </h4>

In [None]:
# Grad-CAM visualization for 2 random images per label.
# Rebuild the same EfficientNetV2-S architecture used for training and restore the
# saved best checkpoint into it.
best_model_path = "models/best_model_4.pth"
num_classes = len(unique_labels)
model = get_efficientnetv2_s_model_layer_added(num_classes, freeze=False)
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.to(device)
model.eval()  # evaluation mode before producing the visualizations

# No target layer is passed explicitly, so layer selection is left to the helper.
show_grad_cam_for_random_images_per_label(
    model,
    test_dataset,
    y_test,
    label_to_idx,
    device,
    n_per_label=2,
)

