In [3]:
import os
import json
from typing import Dict, Any, Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 6  # 0: compliant, 1–5: non-compliant by rule
CLASS_TO_RULE = {
    0: ("compliant", None),
    1: ("non_compliant", "scaling"),
    2: ("non_compliant", "titles"),
    3: ("non_compliant", "color_misuse"),
    4: ("non_compliant", "axis_check"),
    5: ("non_compliant", "clutter_detection"),
}


## MobileNet
Create a MobileNetV3-Small model with a custom classifier for 'num_classes' outputs.

Create model and optionally load fine-tuned weights (.pth).


In [5]:
def create_mobilenet_rule_model(num_classes: int = NUM_CLASSES) -> nn.Module:
    model = models.mobilenet_v3_small(pretrained=True)
    # Default classifier is: [Linear, Hardswish, Dropout, Linear]
    in_features = model.classifier[3].in_features
    model.classifier[3] = nn.Linear(in_features, num_classes)
    return model

def load_mobilenet_model(checkpoint_path: Optional[str] = None) -> nn.Module:
    model = create_mobilenet_rule_model()
    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        state = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state)
        print(f"Loaded model weights from: {checkpoint_path}")
    else:
        if checkpoint_path is not None:
            print(f"Checkpoint not found at {checkpoint_path}, using ImageNet weights only.")
    model.to(DEVICE)
    model.eval()
    return model



## Image Preprocessing

In [6]:
IMG_SIZE = 224

preprocess = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean
        std=[0.229, 0.224, 0.225],   # ImageNet std
    ),
])


def load_and_preprocess_image(image_path: str) -> torch.Tensor:
    img = Image.open(image_path).convert("RGB")
    tensor = preprocess(img)
    return tensor  # shape: (3, H, W)



## Prediction & Mapping

In [7]:
def predict_dashboard(
    model: nn.Module,
    image_tensor: torch.Tensor
) -> Dict[str, Any]:
    """
    Run inference on a preprocessed image tensor.
    Returns a dict with class_id, label, rule, confidence.
    """
    model.eval()
    image_tensor = image_tensor.to(DEVICE).unsqueeze(0)  # (1, 3, H, W)

    with torch.no_grad():
        logits = model(image_tensor)               # (1, NUM_CLASSES)
        probs = F.softmax(logits, dim=1)[0]       # (NUM_CLASSES,)
        pred_idx = int(torch.argmax(probs).item())
        confidence = float(probs[pred_idx].item())

    label, rule = CLASS_TO_RULE[pred_idx]

    return {
        "class_id": pred_idx,
        "label": label,
        "rule": rule,
        "confidence": confidence
    }



## Rule-based Feedback

In [8]:

def generate_feedback(rule: str, label: str, confidence: float, details: dict = None) -> dict:
    feedback = {
        "label": label,
        "confidence": round(confidence, 2),
        "feedback": []
    }
    
    # If compliant
    if label == 'Compliant':
        feedback["feedback"].append(
            f"Great! It follows the {rule if rule else 'IBCS'} rules. This dashboard meets the required compliance standards. Layout, readability, and structure appear suitable."
            f"Confidence: {confidence:.0%}."
        )
        return feedback
    
    # Non-compliant: Give specific, helpful suggestions
    suggestions: List[str] = []
    
    if rule == 'Scaling':
        suggestions.extend([
            "**Start bar charts at zero.** If your y-axis starts above zero, it can make small differences look huge.",
            "**Use the same scale for similar charts.** If you're comparing sales across regions, all charts should have identical y-axis ranges.",
            "**Show scale breaks clearly.** If you must skip part of the scale, use a visible break symbol (like ~) so readers know.",
            "**Label your units.** Add '(in EUR)', '(%)' or similar to your axis labels.",
            "**Avoid dual axes unless absolutely necessary.** Two different scales on one chart confuse readers."
        ])
        
        if details:
            violations = details.get('violations', [])
            if 'axis_misaligned' in violations:
                suggestions.insert(0, "**Axis alignment issue detected.** Check the areas highlighted in red.")
            if 'non_zero_start' in violations:
                suggestions.insert(0, "**Your y-axis doesn't start at zero.** This can mislead viewers.")
            if 'inconsistent_scale' in violations:
                suggestions.insert(0, "**Different charts use different scales.** Make them uniform for fair comparison.")
    
    elif rule == 'Titles':
        suggestions.extend([
            "\n**Use descriptive titles.** Good example: 'Monthly Revenue (EUR) - Q1 2025'. Bad example: 'Chart 1'.",
            "\n**Include the 5 W's:** What (Revenue), Where (Netherlands), When (January 2025), how much (in thousands).",
            "\n**Put titles at the top** of each chart, not inside it.",
            "\n**Keep it concise but clear.** Aim for one line if possible."
        ])
        
        if details and 'missing_title' in details.get('violations', []):
            suggestions.insert(0, "**Missing title detected.** Every chart needs a clear heading.")
    
    elif rule == 'Color_misuse':
        suggestions.extend([
            "**Use color sparingly.** Only highlight what matters—usually negatives (red) or key data points.",
            "**Stick to IBCS colors:** Blue/grey for normal data, red for negative, green for positive variance.",
            "**Avoid rainbow charts.** Too many colors make it hard to focus.",
            "**Test in grayscale.** If your chart doesn't make sense in black and white, you're relying too much on color."
        ])
        
        if details and 'excessive_colors' in details.get('violations', []):
            suggestions.insert(0, "**Too many colors detected.** Simplify your color palette.")
    
    elif rule == 'Axis_check':
        suggestions.extend([
            "**Label both axes clearly.** Include units like '(thousands)', '(%)' or '(days)'.",
            "**Use readable tick marks.** Not too many (cluttered) or too few (unclear).",
            "**Rotate labels if needed.** Long category names work better at 45° or vertically.",
            "**Remove unnecessary gridlines.** Keep only horizontal lines for bar charts, only vertical for column charts."
        ])
        
        if details and 'missing_units' in details.get('violations', []):
            suggestions.insert(0, "**Missing units on axis.** Add '(EUR)', '(%)' etc. to your labels.")
    
    elif rule == 'Clutter_detection':
        suggestions.extend([
            "**Remove decorative elements.** 3D effects, shadows, and borders don't add value.",
            "**Delete redundant legends.** If you only have one data series, label it directly on the chart.",
            "**Simplify backgrounds.** Use white or light grey—no patterns or gradients.",
            "**Cut unnecessary labels.** If every bar is labeled, you don't need y-axis tick marks."
        ])
        
        if details and 'excessive_elements' in details.get('violations', []):
            suggestions.insert(0, "**Too many visual elements.** Simplify for better readability.")
    
    else:
        suggestions.append(f"Unknown rule: '{rule}'. Please check your configuration.")
    

    feedback["feedback"] = suggestions
    return feedback
 

## Feedbacks Functions

In [9]:
def explain_dashboard(
    model: nn.Module,
    image_path: str,
    details_by_rule: Optional[Dict[str, Dict[str, Any]]] = None
) -> Dict[str, Any]:
    """
    High-level function:
    - loads & preprocesses image
    - gets model prediction
    - calls rule-based feedback
    """
    img_tensor = load_and_preprocess_image(image_path)
    pred = predict_dashboard(model, img_tensor)

    label = pred["label"]
    rule = pred["rule"]
    confidence = pred["confidence"]

    # For compliant: rule may be None -> use generic "IBCS"
    if label == "compliant":
        feedback = generate_feedback(
            rule=rule or "IBCS",
            label=label,
            confidence=confidence,
            details=None
        )
    else:
        rule_details = None
        if details_by_rule is not None and rule in details_by_rule:
            rule_details = details_by_rule[rule]

        feedback = generate_feedback(
            rule=rule,
            label=label,
            confidence=confidence,
            details=rule_details
        )

    result = {
        "prediction": pred,
        "feedback": feedback
    }
    return result


## Example 

In [None]:
if __name__ == "__main__":
    # Load model (optionally with your fine-tuned checkpoint)
    checkpoint = "./Checkpoints/mobilenet.keras"
    model = load_mobilenet_model(checkpoint_path=checkpoint)

    # Example: fake 'details_by_rule'
    fake_details = {
        "scaling": {"violations": ["non_zero_start", "axis_misaligned"]},
        "titles": {"violations": ["missing_title"]},
        "color_misuse": {"violations": ["excessive_colors"]},
        "axis_check": {"violations": ["missing_units"]},
        "clutter_detection": {"violations": ["excessive_elements"]},
    }

    # Run on one image
    test_image_path = "../test.webp"  # change to your image file
    result = explain_dashboard(model, test_image_path, details_by_rule=fake_details)

    
    print(json.dumps(result, indent=2))



RuntimeError: [enforce fail at inline_container.cc:176] . file in archive is not in a subdirectory: metadata.json

In [11]:
save_dir = "../Checkpoints"
os.makedirs(save_dir, exist_ok=True)

checkpoint_path = os.path.join(save_dir, "mobilenet_rules.pth")
torch.save(model.state_dict(), checkpoint_path)
print("Model checkpoint saved at:", checkpoint_path)

Model checkpoint saved at: ../Checkpoints\mobilenet_rules.pth
