In [None]:
import torch
from torch import nn
import torchvision.models as models

# ImageNet-pretrained ResNet-50 (IMAGENET1K_V1 weights) used as a feature backbone.
backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Keep every top-level child except the last one (presumably the `fc`
# classification head of torchvision's ResNet), so the network emits pooled
# embeddings instead of ImageNet logits.
# NOTE(review): output is expected to be (N, 2048, 1, 1) — confirm against how
# VisionModule flattens it.
# nn.Module instances start in training mode. .eval() makes BatchNorm use its
# stored running statistics (and stop updating them), so extracted features are
# deterministic. It is chained into the assignment so the cell does not echo
# the model repr.
feature_extractor = (
    nn.Sequential(*list(backbone.children())[:-1])
    .to(device)
    .eval()
)

In [None]:
from classes import VisionModule
from config import configuration

# Bundle the truncated ResNet-50 built in the first cell with the project
# configuration. The cells below use this object for preprocessing, CNN
# features and handcrafted features.
vision_obj = VisionModule(
    feature_extraction_model=feature_extractor,
    configuration=configuration,
)

In [None]:
import os

# One folder of raw images per class. The class label is the name of the folder
# that contains it, e.g. "dataset/Keratoconus/images" -> "Keratoconus".
dataset_paths = ["dataset/Keratoconus/images", "dataset/normal/images"]

# class label -> result of run_vision_preprocessing. Later cells read its
# "crops" and "processed_images" entries.
outputs = {}

for dataset_path in dataset_paths:
    # normpath strips a trailing separator. Without it, "dataset/normal/images/"
    # would produce the label "images", and classes could silently overwrite
    # each other in `outputs`.
    class_name = os.path.basename(os.path.dirname(os.path.normpath(dataset_path)))
    if class_name in outputs:
        raise ValueError(
            f"Duplicate class label {class_name!r} derived from {dataset_path!r}"
        )
    outputs[class_name] = vision_obj.run_vision_preprocessing(dataset_path)


In [None]:
import cv2

for class_name, info in outputs.items():

    print(f"\n[INFO] Processing class: {class_name}")

    crops_dir = info["crops"]                    # folder with the per-quadrant crop PNGs
    processed_images = info["processed_images"]  # image file names returned by preprocessing
    class_base = os.path.dirname(crops_dir)      # feature folders are created beside crops_dir

    # Create output dirs
    deep_dir = os.path.join(class_base, "deep_features")
    hand_dir = os.path.join(class_base, "handcraft_features")

    os.makedirs(deep_dir, exist_ok=True)
    os.makedirs(hand_dir, exist_ok=True)

    for img_name in processed_images:

        img_id = os.path.splitext(img_name)[0]

        print(f"[INFO] Extracting features for: {img_id}")

        for q in ["Q1", "Q2", "Q3", "Q4"]:

            crop_path = os.path.join(crops_dir, f"{img_id}_{q}.png")

            if not os.path.exists(crop_path):
                print(f"[WARN] Missing crop: {crop_path}")
                continue

            # -----------------------------
            # Load crop
            # -----------------------------
            # cv2.imread does not raise on a corrupt/unreadable file; it returns
            # None. Skip such crops like missing ones instead of letting
            # cvtColor abort the whole run.
            crop_bgr = cv2.imread(crop_path)
            if crop_bgr is None:
                print(f"[WARN] Unreadable crop: {crop_path}")
                continue
            # OpenCV decodes to BGR channel order; convert to RGB for the CNN path.
            crop = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)

            # -----------------------------
            # Deep features (CNN)
            # -----------------------------
            # The returned vector is not used here.
            # NOTE(review): presumably extract_deep_features persists it under
            # save_dir — confirm.
            cnn_ready = vision_obj.preprocess_for_cnn(crop)
            vision_obj.extract_deep_features(
                tensor=cnn_ready,
                save_dir=deep_dir,
                img_name=img_id,
                quadrant=q
            )

            # -----------------------------
            # Handcrafted features
            # -----------------------------
            # Computed from the crop file path, not from the decoded array above.
            vision_obj.handcrafted_features(
                cropped_img_pth=crop_path,
                save_dir=hand_dir,
                img_name=img_id,
                quadrant=q
            )

    print(f"[INFO] Completed extracting features for class {class_name}")

In [None]:
from classes import TextModule

# Despite its name, this notebook only uses this object to extract ViT
# ("transformer") features from the crop image files in the next cell.
text_obj = TextModule()

In [None]:
import os

for class_name, info in outputs.items():
    print(f"\n[INFO] Processing class: {class_name}")

    # The crops folder and the list of image file names come from preprocessing.
    crops_dir, processed_images = info["crops"], info["processed_images"]
    class_base = os.path.dirname(crops_dir)

    # ViT features are written to a sibling of the crops folder.
    transformer_dir = os.path.join(class_base, "transformer_features")
    os.makedirs(transformer_dir, exist_ok=True)

    for img_name in processed_images:
        img_id = os.path.splitext(img_name)[0]
        print(f"[INFO] Extracting transformer features for: {img_id}")

        for quadrant in ["Q1", "Q2", "Q3", "Q4"]:
            crop_path = os.path.join(crops_dir, f"{img_id}_{quadrant}.png")

            if os.path.exists(crop_path):
                # Transformer (ViT) features, computed directly from the crop file.
                _ = text_obj.extract_transformer_features(
                    img_path=crop_path,
                    save_dir=transformer_dir,
                    img_name=img_id,
                    quadrant=quadrant,
                )
            else:
                print(f"[WARN] Missing crop: {crop_path}")

    print(f"[INFO] Completed transformer features for class {class_name}")

In [None]:
from classes import XGBoostTrainer

import os

# Destination of the trained classifier. Create its parent folder up front:
# nothing earlier in the notebook creates "saved_models", and it is not visible
# here whether XGBoostTrainer.train does.
MODEL_PATH = os.path.join("saved_models", "kc_classifier.pkl")
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

# Train on the per-class feature folders produced under "dataset" by the
# cells above.
trainer = XGBoostTrainer(dataset_root="dataset")
xgb_model = trainer.train(MODEL_PATH)