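# Building the final model

This notebook assembles the ResNet-50 damage classifier, the two YOLOv8 detectors (damage location and damage severity) and the ViT-GPT2 captioner into a single inference model, then exports it to `models/inference_model`.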

# Import Libraries

In [1]:
import os
import shutil
import torch
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random
from pathlib import Path
from torchvision import transforms
import numpy as np

import torch
import torchvision.models as models
from ultralytics import YOLO
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, GPT2Tokenizer

import warnings
warnings.filterwarnings(action='ignore')

  from .autonotebook import tqdm as notebook_tqdm


# Define the configuration

In [2]:
# Central settings for this notebook: compute device, model artefact paths,
# class-name lists handed to the model, and input/output directories.
# Relative paths are resolved against the notebook's working directory.
CONFIG = dict(
    DEVICE='cuda' if torch.cuda.is_available() else 'cpu',
    YOLO_LOCATION_PATH='../../models/final_model/damage_location_model/yolov8_damage_location_best.pt',
    YOLO_SEVERITY_PATH='../../models/final_model/damage_severity_model/yolov8_damage_severity_best.pt',
    RESNET_CLASSIFIER_PATH='../../models/final_model/classification_model/resnet50_classifier.pth',
    CAPTIONING_MODEL='../../models/final_model/caption_model',
    SEVERITY_NAMES=['low', 'medium', 'high'],
    LOCATION_NAMES=['front', 'back', 'rear-left', 'rear-right'],
    TEST_DIR='../../data/random_images',
    MODEL_FINAL_SAVE_DIR='../../models/inference_model',
)

In [3]:
# Preprocessing for the ResNet-50 damage classifier: resize to 224x224,
# convert to a tensor, then normalise each channel with the standard
# ImageNet mean/std values.
# NOTE(review): presumably this must match the transforms used when the
# classifier was trained — confirm.
# NOTE(review): Resize((224, 224)) already yields 224x224, so CenterCrop(224)
# is a no-op here; it is kept so the pipeline is unchanged.
img_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Define the final model

In [4]:
class Resnet50CustomModel(torch.nn.Module):
    """ResNet-50 feature extractor followed by a three-stage MLP and a linear head.

    The backbone's final ``fc`` layer is replaced by ``Identity`` so it outputs
    its pooled feature vector. The MLP (Linear -> ReLU -> Dropout, three times)
    reduces that vector to ``out3`` features. ``classifier_head`` then maps
    them to ``num_classes`` raw logits.

    Args:
        dropout1, dropout2, dropout3: Dropout probabilities after each MLP stage.
        out1, out2, out3: Output widths of the three MLP ``Linear`` layers.
        num_classes: Number of output logits. Defaults to 2, the binary
            damaged / not-damaged setup used elsewhere in this notebook.
        pretrained: If True (default, same as before), initialise the backbone
            with torchvision's ``ResNet50_Weights.IMAGENET1K_V1`` weights.
            Pass False when a full state dict will be loaded right afterwards,
            so no ImageNet weights need to be fetched.
    """

    def __init__(self, dropout1=0.3, dropout2=0.4, dropout3=0.5, out1=1024, out2=512, out3=256,
                 num_classes=2, pretrained=True):
        super().__init__()

        # `resnet50(pretrained=True)` is deprecated in torchvision >= 0.13.
        # IMAGENET1K_V1 is the weight set that flag selected, so the default
        # behaviour is unchanged.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.extractor = models.resnet50(weights=weights)
        # Width of the pooled feature vector that fed the original fc layer.
        in_features = self.extractor.fc.in_features
        # Drop the ImageNet head so the backbone returns features, not 1000 logits.
        self.extractor.fc = torch.nn.Identity()

        self.mlp_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=in_features, out_features=out1),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout1),

            torch.nn.Linear(in_features=out1, out_features=out2),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout2),

            torch.nn.Linear(in_features=out2, out_features=out3),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout3),
        )

        self.classifier_head = torch.nn.Linear(in_features=out3, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised class logits for the image batch ``x``.

        Assumes ``x`` is a (batch, 3, H, W) float tensor already normalised for
        the backbone — TODO confirm against callers.
        """
        img_features = self.extractor(x)
        mlp_features = self.mlp_layer(img_features)
        prediction = self.classifier_head(mlp_features)

        return prediction

In [5]:
import torch
import numpy as np
from PIL import Image
from torchvision import models
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    GPT2Tokenizer
)
from ultralytics import YOLO


class DamageCaptioningModel(torch.nn.Module):
    """End-to-end car-damage inference pipeline.

    Components:
      * ``Resnet50CustomModel`` classifier. The softmax probability of class
        index 1 is treated as the probability that the car is damaged.
      * Two YOLO detectors (damage location and damage severity). They run only
        when the image is classified as damaged, and only the highest-confidence
        box of each is kept.
      * A ViT-GPT2 ``VisionEncoderDecoderModel``, which captions every image.

    ``forward`` returns a dict containing:
      * the image as PIL and as a numpy array;
      * the damage decision and its probability;
      * location/severity names and boxes (``None`` when not damaged);
      * a caption prefixed with a templated damage summary.

    Args:
        yolo_location_path: Path to the YOLO location-detector weights.
        resnet_pth_path: Path to the ResNet checkpoint. It must be a dict with
            ``'state_dict'`` and ``'config'`` (dropout1-3, out1-3) entries.
        yolo_severity_path: Path to the YOLO severity-detector weights.
        captioning_path: Directory loaded with ``from_pretrained`` for the
            caption model, its image processor and its GPT-2 tokenizer.
        resnet_transform: Callable turning a PIL image into a tensor; a batch
            dimension is added before the ResNet is called.
        severity_names, location_names: Stored on the instance. Reported class
            names come from each YOLO model's own ``names`` mapping.
        num_resnet_classes, num_severity_classes: Currently unused; kept for
            interface compatibility.
        damage_threshold: The image counts as damaged when P(damaged) is
            strictly greater than this value.
        device: Torch device string. Defaults to ``CONFIG['DEVICE']``.
    """

    def __init__(
        self,
        yolo_location_path: str,
        resnet_pth_path: str,
        yolo_severity_path: str,
        captioning_path: str,
        resnet_transform,
        severity_names,
        location_names,
        num_resnet_classes: int = 2,
        num_severity_classes: int = 3,
        damage_threshold: float = 0.3,
        device=None,
    ):
        super().__init__()

        # Fall back to the notebook-wide device so existing callers behave as before.
        self.device = CONFIG['DEVICE'] if device is None else device
        self.damage_threshold = damage_threshold
        self.resnet_transform = resnet_transform
        # NOTE(review): these two lists are stored but not read by this class.
        self.severity_names = severity_names
        self.location_names = location_names

        # --------------------------------------------------
        # RESNET — DAMAGE / NO DAMAGE
        # --------------------------------------------------
        # Load from the path the caller passed in, not from a global CONFIG entry.
        self.resnet = self._load_resnet(resnet_pth_path)

        # --------------------------------------------------
        # YOLO MODELS
        # --------------------------------------------------
        self.yolo_location = YOLO(yolo_location_path)
        self.yolo_severity = YOLO(yolo_severity_path)

        # --------------------------------------------------
        # ViT + GPT2 Captioning
        # --------------------------------------------------
        self.vitgpt_model = VisionEncoderDecoderModel.from_pretrained(
            captioning_path
        ).to(self.device).eval()

        self.vitgpt_processor = ViTImageProcessor.from_pretrained(
            captioning_path
        )

        self.vitgpt_tokenizer = GPT2Tokenizer.from_pretrained(
            captioning_path
        )

    # --------------------------------------------------
    # LOADING
    # --------------------------------------------------
    def _load_resnet(self, checkpoint_path: str):
        """Rebuild the ResNet classifier from ``checkpoint_path``.

        Returns the model on ``self.device`` in eval mode.
        """
        # NOTE(review): torch.load unpickles data; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        cfg = checkpoint['config']

        resnet = Resnet50CustomModel(
            dropout1=cfg['dropout1'],
            dropout2=cfg['dropout2'],
            dropout3=cfg['dropout3'],
            out1=cfg['out1'],
            out2=cfg['out2'],
            out3=cfg['out3'],
        )
        resnet.load_state_dict(checkpoint['state_dict'])
        return resnet.to(self.device).eval()

    # --------------------------------------------------
    # DAMAGE CLASSIFICATION
    # --------------------------------------------------
    @torch.no_grad()
    def _predict_damage(self, image: Image.Image):
        """Classify ``image`` and return ``(is_damaged, damaged_prob)``.

        ``damaged_prob`` is the softmax probability of class index 1.
        ``is_damaged`` is ``damaged_prob > self.damage_threshold``.
        """
        x = self.resnet_transform(image).unsqueeze(0).to(self.device)  # add batch dim
        logits = self.resnet(x)
        probs = torch.softmax(logits, dim=1)
        damaged_prob = probs[0, 1].item()
        is_damaged = damaged_prob > self.damage_threshold
        return is_damaged, damaged_prob

    # --------------------------------------------------
    # CAPTIONING
    # --------------------------------------------------
    @torch.no_grad()
    def _generate_caption(self, image: Image.Image) -> str:
        """Generate a caption with beam search (5 beams, at most 50 tokens)."""
        pixel_values = self.vitgpt_processor(
            images=image,
            return_tensors="pt"
        ).pixel_values.to(self.device)

        output_ids = self.vitgpt_model.generate(
            pixel_values,
            max_length=50,
            num_beams=5
        )

        return self.vitgpt_tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True
        )

    # --------------------------------------------------
    # YOLO HELPERS
    # --------------------------------------------------
    @staticmethod
    def _extract_best_detection(yolo_result):
        """Return ``(box_xyxy, class_id, confidence)`` of the most confident box.

        Returns ``(None, None, None)`` when there are no detections.
        """
        if yolo_result.boxes is None or len(yolo_result.boxes) == 0:
            return None, None, None

        idx = yolo_result.boxes.conf.argmax()
        box = yolo_result.boxes.xyxy[idx].cpu().numpy()
        cls = int(yolo_result.boxes.cls[idx])
        conf = float(yolo_result.boxes.conf[idx].item())
        return box, cls, conf

    @torch.no_grad()
    def _predict_location(self, image: Image.Image):
        """Return ``(location_name, box)``, or ``("unknown", None)`` if nothing is detected."""
        result = self.yolo_location(image)[0]
        box, cls, _ = self._extract_best_detection(result)
        if cls is None:
            return "unknown", None
        return self.yolo_location.names[cls], box

    @torch.no_grad()
    def _predict_severity(self, image: Image.Image):
        """Return ``(severity_name, box)``, or ``("unknown", None)`` if nothing is detected."""
        result = self.yolo_severity(image)[0]
        box, cls, _ = self._extract_best_detection(result)
        if cls is None:
            return "unknown", None
        return self.yolo_severity.names[cls], box

    # --------------------------------------------------
    # FORWARD
    # --------------------------------------------------
    @torch.no_grad()
    def forward(self, image: Image.Image):
        """Run the full pipeline on one PIL image and return the result dict."""
        if image.mode != "RGB":
            image = image.convert("RGB")

        np_image = np.array(image)

        # DAMAGE CLASSIFICATION
        is_damaged, damage_prob = self._predict_damage(image)

        # IMAGE CAPTIONING (ALWAYS)
        raw_caption = self._generate_caption(image)

        # NO DAMAGE CASE: the YOLO detectors are skipped entirely.
        if not is_damaged:
            return {
                "image_pil": image,
                "image_np": np_image,

                "damaged": False,
                "damage_probability": damage_prob,

                "caption": f"A car with no visible damage. {raw_caption}",

                "location": None,
                "severity": None,
                "location_box": None,
                "severity_box": None,
            }

        # YOLO LOCATION + SEVERITY
        location_name, loc_box = self._predict_location(image)
        severity_name, severity_box = self._predict_severity(image)

        # FINAL CAPTION
        final_caption = (
            f"A car with {location_name} damage of {severity_name} severity. "
            f"{raw_caption}"
        )

        return {
            "image_pil": image,
            "image_np": np_image,

            "damaged": True,
            "damage_probability": damage_prob,

            "location": location_name,
            "severity": severity_name,
            "location_box": loc_box,
            "severity_box": severity_box,

            "caption": final_caption,
        }


# Define helper functions

In [6]:
def save_model(model, save_dir, meta_root="models/inference_model"):
    """Export every component of a ``DamageCaptioningModel`` into ``save_dir``.

    Files written:
      * ``yolo_location.pt`` / ``yolo_severity.pt``: copies of the YOLO
        checkpoints the model was loaded from (their ``ckpt_path``).
      * ``caption_model/``: caption model, image processor and tokenizer, each
        written with ``save_pretrained``.
      * ``resnet_damage.pth``: the ResNet ``state_dict`` plus the MLP
        hyper-parameters (dropout1-3, out1-3) needed to rebuild it.
      * ``meta.pt``: artefact paths prefixed with ``meta_root``, plus the
        damage threshold.

    Args:
        model: A loaded ``DamageCaptioningModel``.
        save_dir: Output directory; created if missing.
        meta_root: Prefix for the paths recorded in ``meta.pt``. Defaults to
            the previously hard-coded value. Presumably the consumer resolves
            these paths from the project root — confirm.

    Raises:
        ValueError: If a YOLO model has no checkpoint path, or the ResNet MLP
            does not contain exactly three Linear and three Dropout layers.
    """
    os.makedirs(save_dir, exist_ok=True)
    meta_root = meta_root.rstrip("/")

    # -------------------------
    # SAVE YOLO CHECKPOINTS LOCALLY
    # -------------------------
    for attr, filename in (("yolo_location", "yolo_location.pt"),
                           ("yolo_severity", "yolo_severity.pt")):
        src = getattr(getattr(model, attr), "ckpt_path", None)
        if not src:
            raise ValueError(f"model.{attr} has no checkpoint path to copy")
        shutil.copy(src, os.path.join(save_dir, filename))

    # -------------------------
    # SAVE CAPTION MODEL
    # -------------------------
    caption_dir = os.path.join(save_dir, "caption_model")
    model.vitgpt_model.save_pretrained(caption_dir)
    model.vitgpt_processor.save_pretrained(caption_dir)
    model.vitgpt_tokenizer.save_pretrained(caption_dir)

    # -------------------------
    # SAVE RESNET (FULL CHECKPOINT)
    # -------------------------
    # Read the hyper-parameters by layer type rather than by Sequential index,
    # so a change in the layer order fails loudly instead of saving wrong values.
    mlp_layers = list(model.resnet.mlp_layer)
    linears = [m for m in mlp_layers if isinstance(m, torch.nn.Linear)]
    dropouts = [m for m in mlp_layers if isinstance(m, torch.nn.Dropout)]
    if len(linears) != 3 or len(dropouts) != 3:
        raise ValueError(
            "Expected 3 Linear and 3 Dropout layers in model.resnet.mlp_layer, "
            f"found {len(linears)} and {len(dropouts)}"
        )

    resnet_config = {}
    for i, (linear, dropout) in enumerate(zip(linears, dropouts), start=1):
        resnet_config[f"dropout{i}"] = dropout.p
    for i, linear in enumerate(linears, start=1):
        resnet_config[f"out{i}"] = linear.out_features

    resnet_payload = {
        "state_dict": model.resnet.state_dict(),
        "config": resnet_config,
    }

    torch.save(
        resnet_payload,
        os.path.join(save_dir, "resnet_damage.pth")
    )

    # -------------------------
    # SAVE META
    # -------------------------
    meta = {
        "yolo_location_ckpt": f"{meta_root}/yolo_location.pt",
        "yolo_severity_ckpt": f"{meta_root}/yolo_severity.pt",
        "caption_model_dir": f"{meta_root}/caption_model",
        # Record the ResNet checkpoint as well, so meta.pt lists every artefact.
        "resnet_ckpt": f"{meta_root}/resnet_damage.pth",
        "damage_threshold": model.damage_threshold,
    }

    torch.save(meta, os.path.join(save_dir, "meta.pt"))

    print(f"✅ Model saved to: {save_dir}")

In [7]:
def visualize_result(image=None, result=None):
    """Plot an image with the model's location/severity boxes and caption.

    Args:
        image: PIL image to draw on. If None, ``result["image_pil"]`` is used.
        result: Dict returned by ``DamageCaptioningModel``. None is treated as
            an empty result (no boxes, empty title) instead of crashing.

    Raises:
        ValueError: If no image is passed and ``result`` carries none.
    """
    # The default `result=None` used to fail with AttributeError on `.get`.
    result = {} if result is None else result

    # -------------------------
    # IMAGE SOURCE
    # -------------------------
    if image is None:
        image = result.get("image_pil", None)

    if image is None:
        raise ValueError("No image provided to visualize.")

    # -------------------------
    # CREATE FIGURE
    # -------------------------
    fig, ax = plt.subplots(1, figsize=(8, 8))
    ax.imshow(image)

    damaged = result.get("damaged", False)

    # -------------------------
    # DRAW BOX HELPER
    # -------------------------
    def draw(box, label, color):
        """Draw an (x1, y1, x2, y2) box with a filled label just above it."""
        if box is None:
            return

        x1, y1, x2, y2 = box
        rect = patches.Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=2,
            edgecolor=color,
            facecolor="none"
        )
        ax.add_patch(rect)

        # Clamp to 0 so labels for boxes touching the top edge stay inside the axes.
        ax.text(
            x1,
            max(y1 - 10, 0),
            label,
            color="white",
            fontsize=11,
            weight="bold",
            bbox=dict(facecolor=color, alpha=0.7, pad=2)
        )

    # -------------------------
    # DRAW YOLO RESULTS
    # -------------------------
    if damaged:
        # LOCATION BOX
        draw(
            result.get("location_box"),
            f"Location: {result.get('location', 'unknown')}",
            "red"
        )

        # SEVERITY BOX
        draw(
            result.get("severity_box"),
            f"Severity: {result.get('severity', 'unknown')}",
            "blue"
        )

    # -------------------------
    # TITLE & CLEANUP
    # -------------------------
    caption = result.get("caption", "")
    ax.set_title(caption, fontsize=12, wrap=True)

    ax.axis("off")
    plt.tight_layout()
    plt.show()


# Loading, saving, and testing the final model

## Loading the model

In [8]:
# Assemble the full inference pipeline from the artefacts listed in CONFIG.
# The damage threshold is raised from the class default (0.3) to 0.5.
model = DamageCaptioningModel(
    captioning_path=CONFIG['CAPTIONING_MODEL'],
    damage_threshold=0.5,
    location_names=CONFIG['LOCATION_NAMES'],
    resnet_pth_path=CONFIG['RESNET_CLASSIFIER_PATH'],
    resnet_transform=img_transform,
    severity_names=CONFIG['SEVERITY_NAMES'],
    yolo_location_path=CONFIG['YOLO_LOCATION_PATH'],
    yolo_severity_path=CONFIG['YOLO_SEVERITY_PATH'],
)

## Saving the model

In [10]:
# Export every component of the assembled pipeline to the final inference directory.
save_model(model=model, save_dir=CONFIG['MODEL_FINAL_SAVE_DIR'])

✅ Model saved to: ../../models/inference_model
