# **Imports & Configuration**

In [9]:
# ======================== IMPORTS & CONFIG ========================
import os
import re
import numpy as np
from glob import glob
from PIL import Image
import json
import cv2
import torch
import torch.nn as nn
from torchvision import transforms, models

import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
# from tensorflow.keras.applications.xception import preprocess_input
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from tensorflow.keras.layers import UnitNormalization

# ================= CONFIG =================
TEST_IMAGES_FOLDER = "TestCases/Integerated Test"
OUTPUT_ROOT = "Integrated_Test_Output"

# ---------- Stage 1: Food vs Fruit ----------
FOOD_FRUIT_MODEL_PATH = "Models/part_a_best_mobilenet.pth"
IMG_SIZE_FF = 224
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Food Recognition (Siamese) ----------
SIAMESE_MODEL_PATH = "Models/PartB_EfficientNetB2.h5"
TRAIN_DIR_FOOD = "Project Data/Food/Train"
VALID_DIR_FOOD = "Project Data/Food/Validation"


IMG_SIZE_SIAMESE = (300, 300)
SAMPLES_PER_CLASS = 10

# ---------- Fruit Classification ----------
FRUIT_MODEL_PATH = "Models/MobileNetV2_PartC.keras"
TRAIN_DIR_FRUIT = "Project Data/Fruit/Train"
IMG_SIZE_FRUIT = 350

# ---------- Calories (SEPARATE FILES) ----------
FOOD_CALORIES_FILE_TRAIN = "Project Data/Food/Train Calories.txt"
FOOD_CALORIES_FILE_VALID = "Project Data/Food/Val Calories.txt"
FRUIT_CALORIES_FILE = "Project Data/Fruit/Calories.txt"

# ---------- Binary Segmentation ----------
SEG_MODEL_PATH = "Models/segnet_best.keras"
SEG_IMAGE_SIZE = (256, 256)

# ---------- Multi Segmentation ----------
MULTI_SEG_MODEL_PATH = "Models/best_multiclass_resnet50_unet.keras"
CLASS_MAPPING_FILE = "Models/class_mapping.json"
COLOR_MAPPING_FILE = "Models/color_mapping.json"
MULTI_SEG_IMG_SIZE = 224
NUM_CLASSES = 31



# **Food VS Fruit**

In [10]:
class FoodFruitClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.mobilenet = models.mobilenet_v2(weights=None)
        self.mobilenet.classifier[1] = nn.Linear(
            self.mobilenet.classifier[1].in_features, 2
        )

    def forward(self, x):
        return self.mobilenet(x)


food_fruit_model = FoodFruitClassifier()

checkpoint = torch.load(FOOD_FRUIT_MODEL_PATH, map_location=DEVICE)

# ✅ This will now load correctly
food_fruit_model.load_state_dict(checkpoint["model_state_dict"])

food_fruit_model.to(DEVICE)
food_fruit_model.eval()

transform_ff = transforms.Compose([
    transforms.Resize((IMG_SIZE_FF, IMG_SIZE_FF)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

def predict_food_or_fruit(img_path):
    img = Image.open(img_path).convert("RGB")
    img = transform_ff(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = food_fruit_model(img)
        pred = torch.argmax(out, dim=1).item()
    return "Food" if pred == 0 else "Fruit"


# **Food Recognition**

In [11]:
# # Load Siamese Encoder
encoder = load_model(SIAMESE_MODEL_PATH)

def get_embedding(img_path):
    img = load_img(img_path, target_size=IMG_SIZE_SIAMESE)
    img = img_to_array(img)
    img = preprocess_input(np.expand_dims(img, axis=0))
    return encoder.predict(img, verbose=0)[0]

# Load representatives from Train + Valid, one per class
def load_food_representatives(dirs=[VALID_DIR_FOOD, TRAIN_DIR_FOOD], samples_per_class=SAMPLES_PER_CLASS):
    reps = {}
    for d in dirs:
        if not os.path.exists(d):
            continue
        for cls in os.listdir(d):
            cls_path = os.path.join(d, cls)
            if not os.path.isdir(cls_path) or cls in reps:
                continue  # skip if already added
            imgs = glob(os.path.join(cls_path, "*.jpg"))[:samples_per_class]
            if imgs:
                reps[cls] = [get_embedding(imgs[0])]
    return reps

food_embeddings = load_food_representatives()

def recognize_food(img_path):
    emb = get_embedding(img_path)
    distances = {cls: np.mean([np.linalg.norm(emb - e) for e in embs])
                 for cls, embs in food_embeddings.items()}
    return min(distances, key=distances.get)








In [12]:
# import open_clip
# from PIL import Image

# # Load your fine-tuned CLIP model
# FINETUNED_MODEL_PATH = "Models/clip_finetuned_food_improved.pth"
# clip_model, _, preprocess = open_clip.create_model_and_transforms(
#     model_name="ViT-B-32",
#     pretrained="openai"
# )

# checkpoint = torch.load(FINETUNED_MODEL_PATH, map_location=DEVICE)
# clip_model.load_state_dict(checkpoint["clip_model_state_dict"])

# clip_model = clip_model.to(DEVICE).eval()
# print("✓ CLIP model loaded successfully!")

# # Load Food classes from Train + Valid
# food_classes = sorted(
#     set(os.listdir(TRAIN_DIR_FOOD)) | set(os.listdir(VALID_DIR_FOOD))
# )

# def recognize_food_clip(img_path, classes=food_classes):
#     """Recognize food class using fine-tuned CLIP"""
#     img = preprocess(Image.open(img_path)).unsqueeze(0).to(DEVICE)
    
#     with torch.no_grad():
#         img_features = clip_model.encode_image(img)
#         img_features /= img_features.norm(dim=-1, keepdim=True)

#         # Tokenize text labels
#         text_inputs = open_clip.tokenize(classes).to(DEVICE)
#         text_features = clip_model.encode_text(text_inputs)
#         text_features /= text_features.norm(dim=-1, keepdim=True)

#         # Compute cosine similarity
#         similarity = (100.0 * img_features @ text_features.T).softmax(dim=-1)
#         class_idx = similarity.argmax().item()
        
#     return classes[class_idx]


# len(food_classes)

In [13]:
# import open_clip
# from PIL import Image

# # Load your fine-tuned CLIP model
# FINETUNED_MODEL_PATH = "Models/clip_finetuned_5_Shots.pth"
# # =========================================================
# # LOAD CLIP MODEL
# # =========================================================
# clip_model, _, preprocess = open_clip.create_model_and_transforms(
#     model_name="ViT-B-32",
#     pretrained="openai"
# )

# checkpoint = torch.load(FINETUNED_MODEL_PATH, map_location=DEVICE)
# clip_model.load_state_dict(checkpoint["clip_model_state_dict"])

# clip_model = clip_model.to(DEVICE).eval()
# print("✓ CLIP model loaded successfully")

# # =========================================================
# # CLIP IMAGE EMBEDDING
# # =========================================================
# def get_clip_embedding(img_path):
#     img = preprocess(Image.open(img_path).convert("RGB")) \
#             .unsqueeze(0).to(DEVICE)

#     with torch.no_grad():
#         emb = clip_model.encode_image(img)
#         emb = emb / emb.norm(dim=-1, keepdim=True)

#     return emb

# # =========================================================
# # LOAD FOOD REPRESENTATIVES (TRAIN + VALID)
# # =========================================================
# def load_food_representatives(
#     dirs=[VALID_DIR_FOOD, TRAIN_DIR_FOOD],
#     samples_per_class=5
# ):
#     reps = {}

#     for d in dirs:
#         if not os.path.exists(d):
#             continue

#         for cls in os.listdir(d):
#             cls_path = os.path.join(d, cls)

#             if not os.path.isdir(cls_path) or cls in reps:
#                 continue

#             imgs = glob(os.path.join(cls_path, "*.jpg"))[:samples_per_class]

#             if imgs:
#                 reps[cls] = [get_clip_embedding(img) for img in imgs]

#     return reps


# print("Loading food representatives...")
# food_reps = load_food_representatives(
#     dirs=[TRAIN_DIR_FOOD, VALID_DIR_FOOD],
#     samples_per_class=5
# )
# print(f"✓ Loaded {len(food_reps)} food classes")

# # =========================================================
# # FOOD RECOGNITION (IMAGE → IMAGE)
# # =========================================================
# def recognize_food_clip(img_path):
#     """
#     Image-only food recognition using CLIP representatives
#     Returns: food class (string)
#     """
#     query_emb = get_clip_embedding(img_path)

#     best_cls = None
#     best_score = -1

#     for cls, emb_list in food_reps.items():
#         for emb in emb_list:
#             score = (query_emb @ emb.T).item()
#             if score > best_score:
#                 best_score = score
#                 best_cls = cls

#     return best_cls

# **Fruit Classification**

In [14]:
# ================= FRUIT CLASSIFICATION =================
fruit_model = tf.keras.models.load_model(FRUIT_MODEL_PATH)

fruit_classes = sorted([
    d for d in os.listdir(TRAIN_DIR_FRUIT)
    if os.path.isdir(os.path.join(TRAIN_DIR_FRUIT, d))
])

def preprocess_fruit(img_path):
    img = image.load_img(img_path, target_size=(IMG_SIZE_FRUIT, IMG_SIZE_FRUIT))
    img = image.img_to_array(img) / 255.0
    img = np.expand_dims(img, axis=0)
    return img

def recognize_fruit(img_path):
    pred = fruit_model.predict(preprocess_fruit(img_path), verbose=0)
    return fruit_classes[np.argmax(pred)]


# **Calculating Calories**

In [15]:
def load_calories(file_path):
    calories = {}
    with open(file_path) as f:
        for line in f:
            m = re.match(r"(.+?):\s*~?([\d.]+)", line.strip())
            if m:
                # normalize class names: lowercase + underscores
                cls_name = m.group(1).strip().lower().replace(" ", "_")
                calories[cls_name] = float(m.group(2))
    return calories

# Merge food calories from train + valid
food_calories = {}
food_calories.update(load_calories(FOOD_CALORIES_FILE_TRAIN))
food_calories.update(load_calories(FOOD_CALORIES_FILE_VALID))

# Load fruit calories
fruit_calories = load_calories(FRUIT_CALORIES_FILE)

def extract_grams(name):
    return int(re.search(r"(\d+)g", name).group(1))



In [16]:
len(food_calories)

93

# **Binary Segmentation**

In [17]:
seg_model = tf.keras.models.load_model(SEG_MODEL_PATH, compile=False)

def run_binary_segmentation(img_path, save_dir):
    filename = os.path.basename(img_path)
    img = cv2.imread(img_path)
    if img is None:
        print(f"Error reading {filename}")
        return
    h, w = img.shape[:2]
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_resized = cv2.resize(img_rgb, SEG_IMAGE_SIZE)
    img_input = np.expand_dims(img_resized.astype("float32") / 255.0, axis=0)
    mask = np.squeeze(seg_model.predict(img_input, verbose=0))
    mask_binary = (mask > 0.5).astype(np.uint8)
    mask_resized = cv2.resize(mask_binary, (w, h), interpolation=cv2.INTER_NEAREST)
    mask_final = mask_resized * 255
    save_path = os.path.join(save_dir, os.path.splitext(filename)[0] + "_mask.png")
    cv2.imwrite(save_path, mask_final)
    print(f"Saved segmentation mask: {save_path}")

# **Multi Segmentation**

In [18]:
# ======================== IMPORTS ========================
import os
import cv2
import numpy as np
import tensorflow as tf
from skimage import measure
from scipy import stats
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from tensorflow.keras.utils import register_keras_serializable
import json
import time

# ======================== CONFIG ========================
MULTI_SEG_MODEL_PATH = "Models/best_multiclass_resnet50_unet.keras"
CLASS_MAPPING_FILE = "Models/class_mapping.json"
COLOR_MAPPING_FILE = "Models/color_mapping.json"
MULTI_SEG_IMG_SIZE = 224

# Load class/color mappings
with open(CLASS_MAPPING_FILE, 'r') as f:
    class_mapping = json.load(f)
with open(COLOR_MAPPING_FILE, 'r') as f:
    color_mapping = json.load(f)
reverse_mapping = {v: k for k, v in class_mapping.items()}

# ======================== CUSTOM OBJECTS ========================
@register_keras_serializable()
class MultiClassDiceCoefficient(tf.keras.metrics.Metric):
    def __init__(self, num_classes=31, name='dice_coefficient', **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.total_dice = self.add_weight(name='total_dice', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.argmax(y_pred, axis=-1)
        y_true = tf.squeeze(y_true, axis=-1)
        y_pred = tf.cast(y_pred, tf.float32)
        y_true = tf.cast(y_true, tf.float32)

        dice_sum = 0.0
        valid_classes = 0.0

        for i in range(self.num_classes):
            class_idx = tf.cast(i, tf.float32)
            y_t = tf.cast(tf.equal(y_true, class_idx), tf.float32)
            y_p = tf.cast(tf.equal(y_pred, class_idx), tf.float32)

            inter = tf.reduce_sum(y_t * y_p)
            union = tf.reduce_sum(y_t) + tf.reduce_sum(y_p)
            present = tf.cast(tf.reduce_sum(y_t) > 0, tf.float32)

            dice = tf.math.divide_no_nan(2.0 * inter, union)
            dice_sum += dice * present
            valid_classes += present

        batch_dice = tf.math.divide_no_nan(dice_sum, valid_classes)
        self.total_dice.assign_add(batch_dice)
        self.count.assign_add(1.0)

    def result(self):
        return tf.math.divide_no_nan(self.total_dice, self.count)

    def reset_state(self):
        self.total_dice.assign(0.0)
        self.count.assign(0.0)

    def get_config(self):
        config = super().get_config()
        config.update({"num_classes": self.num_classes})
        return config

@register_keras_serializable()
def combined_loss(y_true, y_pred):
    ce_loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred))
    return ce_loss

def dice_loss(y_true, y_pred):
    return tf.reduce_mean(y_pred)

# ======================== HELPER FUNCTIONS ========================
def preprocess_image(img_rgb, img_size):
    img = tf.convert_to_tensor(img_rgb)
    img = tf.image.resize(img, (img_size, img_size))
    img = preprocess_input(img)
    img = tf.expand_dims(img, axis=0)
    return img

def clean_mask_advanced(pred_mask):
    if hasattr(pred_mask, 'numpy'):
        pred_mask = pred_mask.numpy()
    pred_mask = np.array(pred_mask)
    binary_mask = (pred_mask > 0).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
    labels = measure.label(binary_mask)
    cleaned_mask = pred_mask.copy()
    for region in measure.regionprops(labels):
        if region.area < 50:
            coords = region.coords
            cleaned_mask[coords[:,0], coords[:,1]] = 0
            continue
        coords = region.coords
        region_vals = pred_mask[coords[:,0], coords[:,1]]
        region_vals = region_vals[region_vals > 0]
        if len(region_vals) > 0:
            most_frequent_class = stats.mode(region_vals, keepdims=True)[0][0]
            cleaned_mask[coords[:,0], coords[:,1]] = most_frequent_class
    return cleaned_mask

def run_multiclass_segmentation(img_path, save_dir):
    img = cv2.imread(img_path)
    if img is None:
        print(f"Could not read {img_path}")
        return
    h, w = img.shape[:2]
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    inp = preprocess_image(img_rgb, MULTI_SEG_IMG_SIZE)

    start_time = time.time()
    pred = multi_seg_model.predict(inp, verbose=0)[0]
    inference_time = (time.time() - start_time) * 1000

    raw_mask = np.argmax(pred, axis=-1)
    mask_clean = clean_mask_advanced(raw_mask)

    total_pixels = MULTI_SEG_IMG_SIZE * MULTI_SEG_IMG_SIZE
    detected_indices = np.unique(mask_clean)
    detected_indices = detected_indices[detected_indices != 0]

    for idx in detected_indices:
        pixel_count = np.sum(mask_clean == idx)
        percentage = (pixel_count / total_pixels) * 100
        if percentage <= 1.0:
            mask_clean[mask_clean == idx] = 0

# Create colored mask
    colored = np.zeros((MULTI_SEG_IMG_SIZE, MULTI_SEG_IMG_SIZE, 3), dtype=np.uint8)
    class_names_detected = []
    for idx in detected_indices:
        cls_name = reverse_mapping.get(int(idx), "background")
        rgb_color = np.array(color_mapping.get(cls_name, [0,0,0]), dtype=np.uint8)
        bgr_color = rgb_color[::-1]  # Convert RGB to BGR for OpenCV
        colored[mask_clean == idx] = bgr_color
        class_names_detected.append(cls_name)

    # Resize to original size
    colored = cv2.resize(colored, (w, h), interpolation=cv2.INTER_NEAREST)

    # Create header above the image
    header_height = 30 + 20*len(class_names_detected)
    result_img = np.zeros((h + header_height, w, 3), dtype=np.uint8)

    # Fill header with black background
    result_img[:header_height, :, :] = 0

    # Write class names
    for i, cls_name in enumerate(class_names_detected):
        cv2.putText(
            result_img,
            cls_name,
            (10, 25 + i*20),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (255, 255, 255),
            2
        )

    # Place the colored mask below header
    result_img[header_height:, :, :] = colored

    # Save result
    base = os.path.splitext(os.path.basename(img_path))[0]
    save_path = os.path.join(save_dir, f"{base}_multiseg_mask.png")
    cv2.imwrite(save_path, result_img)

    return result_img




# ======================== LOAD MODEL ========================
multi_seg_model = tf.keras.models.load_model(
    MULTI_SEG_MODEL_PATH,
    custom_objects={
        "combined_loss": combined_loss,
        "dice_loss": dice_loss,
        "MultiClassDiceCoefficient": MultiClassDiceCoefficient
    },
    compile=False
)




# **Main Pipeline**

In [19]:
os.makedirs(OUTPUT_ROOT, exist_ok=True)

for img_name in os.listdir(TEST_IMAGES_FOLDER):
    if not img_name.lower().endswith((".jpg", ".png", ".jpeg")):
        continue

    img_path = os.path.join(TEST_IMAGES_FOLDER, img_name)
    out_dir = os.path.join(OUTPUT_ROOT, os.path.splitext(img_name)[0])
    os.makedirs(out_dir, exist_ok=True)

    main_class = predict_food_or_fruit(img_path)

    if main_class == "Food":
        sub = recognize_food(img_path)
        # sub = recognize_food_clip(img_path)
        cal = food_calories.get(sub.lower().replace(" ", "_"), 0)
    else:
        sub = recognize_fruit(img_path)
        cal = fruit_calories.get(sub.lower().replace(" ", "_"), 0)

        # Save BOTH masks for Fruit
        run_binary_segmentation(img_path, out_dir)
        run_multiclass_segmentation(img_path, out_dir)

    grams = extract_grams(img_name)
    total_cal = grams * cal

    with open(os.path.join(out_dir, "result.txt"), "w") as f:
        f.write(f"{main_class}\n{sub}\n{total_cal:.2f}\n")

    print(f"{img_name} → {main_class} | {sub} | {total_cal:.2f} kcal")


img1_75g.jpg → Food | creme_brulee | 262.50 kcal
img2_280g.jpg → Food | hot_and_sour_soup | 140.00 kcal
Saved segmentation mask: Integrated_Test_Output\img3_300g\img3_300g_mask.png
img3_300g.jpg → Fruit | Mango_Amrapali | 195.00 kcal
Saved segmentation mask: Integrated_Test_Output\img4_350g\img4_350g_mask.png
img4_350g.jpg → Fruit | Elephant Apple | 210.00 kcal
img5_700g.jpg → Food | creme_brulee | 2450.00 kcal
