In [None]:
# merged_file.py

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from glob import glob

# =========================
# PART 1: Data Loading Functions
# =========================













In [None]:


def load_calories(file_path):
    """Load a ``{food name: calories}`` mapping from a text file.

    Each non-empty line is expected to look like ``<name>: ... ~<number> ...``.
    The food name is the stripped text before the first ``:``. The calorie
    value is the first space-separated token of the segment between the first
    and second ``~``. Lines without ``:`` or ``~``, and lines whose token is
    not a number, are skipped.

    Args:
        file_path: Path of the calories text file.

    Returns:
        dict[str, float]: Parsed calorie values keyed by food name. If the
        file cannot be opened or decoded, a warning is printed and whatever
        was parsed so far (possibly nothing) is returned.
    """
    calories = {}
    try:
        with open(file_path, "r") as f:
            for raw_line in f:
                line = raw_line.strip()
                # A line without ':' (including a blank line) carries no entry.
                if ":" not in line:
                    continue

                food_name, cal_info = line.split(":", 1)
                cal_info = cal_info.strip()
                if "~" not in cal_info:
                    continue

                cal_part = cal_info.split("~")[1].strip()
                cal_value_str = cal_part.split(" ")[0].strip()
                try:
                    calories[food_name.strip()] = float(cal_value_str)
                except ValueError:
                    # Non-numeric token (e.g. "~ n/a"): skip just this line.
                    continue
    except (OSError, UnicodeDecodeError):
        # Best effort: a missing/unreadable file yields a warning, not a crash.
        print(f"Warning: Could not load calories from {file_path}")

    return calories


def normalize_label(label):
    """Turn a folder/key label into a display name.

    Underscores become spaces and the result is title-cased,
    e.g. ``"apple_pie"`` -> ``"Apple Pie"``.
    """
    spaced = " ".join(label.split("_"))
    return spaced.title()


def _list_category_dirs(split_dir):
    """Return ``(category, path)`` pairs for the sub-directories of ``split_dir``, sorted by name.

    Returns an empty list when ``split_dir`` is not an existing directory.
    Sorting makes sample order independent of the filesystem's listing order.
    """
    if not os.path.isdir(split_dir):
        return []
    pairs = []
    for cat in sorted(os.listdir(split_dir)):
        cat_path = os.path.join(split_dir, cat)
        if os.path.isdir(cat_path):
            pairs.append((cat, cat_path))
    return pairs


def _list_files(directory):
    """Return the regular files directly inside ``directory``, sorted by path."""
    return sorted(p for p in glob(os.path.join(directory, "*")) if os.path.isfile(p))


def _collect_food_split(split_dir):
    """Collect image paths and labels for a ``<split_dir>/<category>/<image>`` layout."""
    images, labels = [], []
    for cat, cat_path in _list_category_dirs(split_dir):
        for img_path in _list_files(cat_path):
            images.append(img_path)
            labels.append(cat)
    return images, labels


def _collect_fruit_split(split_dir):
    """Collect image paths, expected mask paths and labels for one Fruit split.

    Layout: ``<split_dir>/<category>/Images/<name>.<ext>``, with the mask
    expected at ``<split_dir>/<category>/Mask/<name>_mask.png``. Mask paths
    are constructed, not checked for existence.
    """
    images, masks, labels = [], [], []
    for cat, cat_path in _list_category_dirs(split_dir):
        images_dir = os.path.join(cat_path, "Images")
        masks_dir = os.path.join(cat_path, "Mask")
        if not os.path.isdir(images_dir):
            continue
        for img_path in _list_files(images_dir):
            # splitext drops only the final extension ("a.b.jpg" -> "a.b"),
            # whereas split('.')[0] would truncate the stem to "a".
            stem = os.path.splitext(os.path.basename(img_path))[0]
            images.append(img_path)
            masks.append(os.path.join(masks_dir, f"{stem}_mask.png"))
            labels.append(cat)
    return images, masks, labels


def load_dataset(root_path):
    """Collect file paths, labels and calorie tables for the Food and Fruit datasets.

    Expected layout under ``root_path``::

        Food/Train/<category>/<image>
        Food/Validation/<category>/<image>
        Food/Train Calories.txt
        Food/Val Calories.txt
        Fruit/{Train,Validation}/<category>/Images/<name>.<ext>
        Fruit/{Train,Validation}/<category>/Mask/<name>_mask.png
        Fruit/calories.txt

    Only paths are collected; no images are read. Missing split directories
    yield empty lists. Categories and files are returned in sorted order.

    Args:
        root_path: Dataset root directory.

    Returns:
        dict with ``food_*`` and ``fruit_*`` path/label lists (plus fruit mask
        paths) and calorie dicts whose keys are passed through ``normalize_label``.
    """
    root_path = os.path.normpath(root_path)
    food_dir = os.path.join(root_path, "Food")
    fruit_dir = os.path.join(root_path, "Fruit")

    def _normalized_calories(path):
        # Keys are normalized so they compare equal to display-style labels.
        return {normalize_label(k): v for k, v in load_calories(path).items()}

    # ---- FOOD ----
    food_train_images, food_train_labels = _collect_food_split(os.path.join(food_dir, "Train"))
    food_val_images, food_val_labels = _collect_food_split(os.path.join(food_dir, "Validation"))

    food_train_cal_normalized = _normalized_calories(os.path.join(food_dir, "Train Calories.txt"))
    food_val_cal_normalized = _normalized_calories(os.path.join(food_dir, "Val Calories.txt"))

    # ---- FRUIT ----
    fruit_train_images, fruit_train_masks, fruit_train_labels = _collect_fruit_split(
        os.path.join(fruit_dir, "Train")
    )
    fruit_val_images, fruit_val_masks, fruit_val_labels = _collect_fruit_split(
        os.path.join(fruit_dir, "Validation")
    )

    fruit_calories_normalized = _normalized_calories(os.path.join(fruit_dir, "calories.txt"))

    return {
        # FOOD
        "food_train_images": food_train_images,
        "food_train_labels": food_train_labels,
        "food_val_images":   food_val_images,
        "food_val_labels":   food_val_labels,
        "food_train_cal":    food_train_cal_normalized,
        "food_val_cal":      food_val_cal_normalized,

        # FRUIT
        "fruit_train_images": fruit_train_images,
        "fruit_train_masks":  fruit_train_masks,
        "fruit_train_labels": fruit_train_labels,
        "fruit_val_images":   fruit_val_images,
        "fruit_val_masks":    fruit_val_masks,
        "fruit_val_labels":   fruit_val_labels,
        "fruit_calories":     fruit_calories_normalized
    }


In [None]:

def read_images(image_paths, target_size=(224, 224)):
    """Read images with OpenCV, resize them and stack them into one RGB array.

    Args:
        image_paths: Iterable of image file paths.
        target_size: ``(width, height)``, passed straight to ``cv2.resize``.

    Returns:
        ``uint8`` array of shape ``(N, height, width, 3)`` in RGB order.
        An unreadable image is replaced by an all-zero image (and a warning is
        printed), so the output stays index-aligned with ``image_paths``.
    """
    width, height = target_size
    images = []
    for path in image_paths:
        img = cv2.imread(path)  # BGR, or None if the file cannot be read
        if img is None:
            print(f"Warning: Could not read image {path}")
            # cv2.resize yields (height, width); the placeholder must match it,
            # otherwise non-square sizes give a ragged batch.
            images.append(np.zeros((height, width, 3), dtype=np.uint8))
            continue
        img = cv2.resize(img, target_size)
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    if not images:
        # Keep rank/dtype consistent for empty input (np.array([]) is 1-D float).
        return np.empty((0, height, width, 3), dtype=np.uint8)
    return np.array(images)


def read_masks(mask_paths, target_size=(224, 224)):
    """Read grayscale masks with OpenCV, resize them and stack them into one array.

    Args:
        mask_paths: Iterable of mask file paths.
        target_size: ``(width, height)``, passed straight to ``cv2.resize``.

    Returns:
        ``uint8`` array of shape ``(N, height, width)``. An unreadable mask is
        replaced by an all-zero mask (and a warning is printed), so the output
        stays index-aligned with ``mask_paths``.
    """
    width, height = target_size
    masks = []
    for path in mask_paths:
        mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"Warning: Could not read mask {path}")
            # Match cv2.resize's (height, width) output layout.
            masks.append(np.zeros((height, width), dtype=np.uint8))
            continue
        # Nearest-neighbour keeps label values crisp; bilinear would create
        # intermediate values along edges that a `> 0` threshold counts as foreground.
        masks.append(cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST))

    if not masks:
        return np.empty((0, height, width), dtype=np.uint8)
    return np.array(masks)


def convert_rgb_to_bgr(food_data):
    """Reverse the channel order (RGB -> BGR) of the train/val images, e.g. for VGG16.

    The ``"train_images"`` and ``"val_images"`` entries of ``food_data`` are
    replaced in place by channel-reversed arrays; the same dict is returned.
    """
    for key in ("train_images", "val_images"):
        # Reversing the last axis swaps the R and B channels.
        food_data[key] = food_data[key][..., ::-1]
    return food_data


def readTestData(imgPath):
    """Load a single image for inference.

    Reads ``imgPath`` with OpenCV, resizes it to 224x224 and converts it
    from BGR to RGB.

    Returns:
        The RGB image array, or ``None`` when OpenCV cannot read the file.
    """
    bgr_img = cv2.imread(imgPath)
    if bgr_img is None:
        return None
    return cv2.cvtColor(cv2.resize(bgr_img, (224, 224)), cv2.COLOR_BGR2RGB)



In [None]:
# =========================
# PART 2: U-Net Model and Training (Part D)
# =========================

class FruitBinaryDataset(Dataset):
    """Binary-segmentation dataset over fruit images and their masks.

    ``images`` and ``masks`` are parallel sequences. Each item is either a file
    path (loaded lazily with OpenCV) or an already-loaded array: image
    ``(H, W, 3)`` RGB, mask ``(H, W)`` or ``(H, W, C)``.

    ``__getitem__`` returns ``(image, mask)`` float32 tensors:

    - image of shape ``(3, H, W)``, in [0, 1] for uint8 input;
    - mask of shape ``(1, H, W)`` with values 0/1, where any non-zero pixel is foreground.
    """

    def __init__(self, images, masks):
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        mask = self.masks[idx]

        # ---- Load from disk when given paths ----
        if isinstance(img, str):
            img_path = img
            img = cv2.imread(img_path)
            if img is None:
                raise FileNotFoundError(f"Could not read image: {img_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if isinstance(mask, str):
            mask_path = mask
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # ---- Image: HWC -> CHW float32 scaled by 1/255 ----
        # Conversion goes through Python lists rather than torch.from_numpy,
        # presumably to avoid a NumPy/PyTorch interop issue in the original
        # environment. NOTE(review): confirm before switching to from_numpy.
        img = img.astype("float32") / 255.0
        img = torch.tensor(img.tolist()).permute(2, 0, 1)

        # ---- Mask: keep first channel if multi-channel, binarize ----
        mask = torch.tensor(mask.tolist(), dtype=torch.float32)
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        mask = (mask.unsqueeze(0) > 0).float()

        return img, mask


class DoubleConv(nn.Module):
    """Two 3x3 convolutions with padding 1, each followed by an in-place ReLU.

    Padding 1 with a 3x3 kernel keeps the spatial size unchanged. The layers
    are held in ``self.conv`` (an ``nn.Sequential``), so parameters are named
    ``conv.0.*`` and ``conv.2.*``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First conv maps in_ch -> out_ch, second maps out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            layers.append(nn.Conv2d(src_ch, out_ch, 3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """Three-level U-Net producing a single output channel for binary segmentation.

    Input ``(N, 3, H, W)`` -> output ``(N, 1, H, W)``. The final 1x1 conv has
    no activation, so the output is raw logits (apply ``sigmoid`` for
    probabilities).

    Any ``H``/``W`` is accepted. When max-pooling floors an odd size, the
    upsampled decoder feature is zero-padded to its skip connection's size
    before concatenation, so sizes not divisible by 8 no longer crash.
    """

    def __init__(self):
        super().__init__()
        self.down1 = DoubleConv(3, 64)
        self.down2 = DoubleConv(64, 128)
        self.down3 = DoubleConv(128, 256)

        self.pool = nn.MaxPool2d(2)

        self.middle = DoubleConv(256, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv3 = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = DoubleConv(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv1 = DoubleConv(128, 64)

        self.final = nn.Conv2d(64, 1, 1)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` on its last two dims to ``skip``'s size (no-op when equal)."""
        diff_h = skip.shape[-2] - up.shape[-2]
        diff_w = skip.shape[-1] - up.shape[-1]
        if diff_h == 0 and diff_w == 0:
            return up
        # F.pad order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(
            up,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(self.pool(d1))
        d3 = self.down3(self.pool(d2))

        mid = self.middle(self.pool(d3))

        u3 = self._match_size(self.up3(mid), d3)
        u3 = self.conv3(torch.cat([u3, d3], dim=1))

        u2 = self._match_size(self.up2(u3), d2)
        u2 = self.conv2(torch.cat([u2, d2], dim=1))

        u1 = self._match_size(self.up1(u2), d1)
        u1 = self.conv1(torch.cat([u1, d1], dim=1))

        return self.final(u1)


def dice_score(pred, target, smooth=1e-6):
    """Dice coefficient between thresholded predictions and a binary target.

    ``pred`` is binarized with ``> 0.5``, so pass probabilities, not logits.
    The score is pooled over every element of the inputs; it is not averaged
    per sample. ``smooth`` avoids 0/0 and makes the score 1 when both
    prediction and target are empty.
    """
    hard_pred = pred.gt(0.5).float()
    overlap = torch.sum(hard_pred * target)
    total = hard_pred.sum() + target.sum()
    return (2 * overlap + smooth) / (total + smooth)


In [None]:



def _partD_train_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch; return the mean batch loss (0.0 for an empty loader)."""
    model.train()
    total_loss = 0.0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)

        loss = criterion(model(imgs), masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / max(len(loader), 1)


def _partD_evaluate(model, loader, device):
    """Return the mean per-batch Dice score on ``loader`` (0.0 for an empty loader)."""
    model.eval()
    dice_total = 0.0
    with torch.no_grad():
        for imgs, masks in loader:
            imgs, masks = imgs.to(device), masks.to(device)
            probs = torch.sigmoid(model(imgs))
            dice_total += dice_score(probs, masks).item()
    return dice_total / max(len(loader), 1)


def run_partD(fruit, epochs=20, batch_size=2, lr=1e-4):
    """Train a U-Net for binary fruit segmentation (Part D).

    Args:
        fruit: dict with ``"train_images"``, ``"train_masks"``, ``"val_images"``
            and ``"val_masks"`` (arrays or paths accepted by ``FruitBinaryDataset``).
        epochs: Number of training epochs.
        batch_size: Mini-batch size for both the train and val loaders.
        lr: Adam learning rate.

    Returns:
        ``(model, partD_test_script)``, where ``partD_test_script(image,
        save_path="binary_mask.png")`` predicts a 0/255 mask for one image,
        writes it to ``save_path`` and returns it.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ds = FruitBinaryDataset(fruit["train_images"], fruit["train_masks"])
    val_ds = FruitBinaryDataset(fruit["val_images"], fruit["val_masks"])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    model = UNet().to(device)
    criterion = nn.BCEWithLogitsLoss()  # model outputs logits
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        train_loss = _partD_train_epoch(model, train_loader, criterion, optimizer, device)
        val_dice = _partD_evaluate(model, val_loader, device)
        print(
            f"Epoch [{epoch+1}/{epochs}] "
            f"Loss: {train_loss:.4f} "
            f"Dice: {val_dice:.4f}"
        )

    def partD_test_script(image, save_path="binary_mask.png", input_size=(224, 224)):
        """Predict a binary mask for one image, save it and return it.

        Args:
            image: Path to an image file, or an RGB ``(H, W, 3)`` uint8 array.
            save_path: Where the 0/255 mask is written with ``cv2.imwrite``.
            input_size: ``(width, height)`` the image is resized to before
                inference; the default matches ``read_images``' default size.

        Returns:
            ``uint8`` mask with values 0/255 at the input image's original resolution.

        Raises:
            FileNotFoundError: if ``image`` is a path OpenCV cannot read.
        """
        model.eval()

        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise FileNotFoundError(f"Could not read image: {image}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            img = image  # already-loaded RGB array

        orig_h, orig_w = img.shape[:2]
        # Run at the training resolution; arbitrary sizes may not fit the U-Net.
        img = cv2.resize(img, tuple(input_size))
        img = img.astype("float32") / 255.0
        x = torch.tensor(img.tolist()).permute(2, 0, 1).unsqueeze(0).to(device)

        with torch.no_grad():
            prob = torch.sigmoid(model(x))[0, 0]

        mask = ((prob > 0.5).cpu().numpy() * 255).astype("uint8")
        # Back to the caller's resolution; nearest keeps values strictly 0/255.
        mask = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(save_path, mask)

        return mask

    return model, partD_test_script


# =========================
# PART 3: Main Execution
# =========================

def visualize_samples(food, fruit, n_samples=3):
    """Plot the first ``n_samples`` food images, fruit images and fruit masks.

    Shows two figures:

    - a 2-row grid with food images on top and fruit images below;
    - a row with the corresponding fruit masks in grayscale.

    Panels beyond the number of available samples are left blank instead of
    raising ``IndexError``.

    Args:
        food: dict with ``"train_images"`` and ``"train_labels"``.
        fruit: dict with ``"train_images"``, ``"train_masks"`` and ``"train_labels"``.
        n_samples: Number of columns per figure (default 3).
    """
    if n_samples <= 0:
        return

    # squeeze=False keeps `axes` 2-D even when n_samples == 1.
    _, axes = plt.subplots(2, n_samples, figsize=(4 * n_samples, 8), squeeze=False)
    rows = [
        (food["train_images"], food["train_labels"], "Food"),
        (fruit["train_images"], fruit["train_labels"], "Fruit"),
    ]
    for row_axes, (images, labels, kind) in zip(axes, rows):
        for i, ax in enumerate(row_axes):
            ax.axis("off")
            if i < len(images):
                ax.imshow(images[i])
                ax.set_title(f"{kind}: {labels[i]}")
    plt.tight_layout()
    plt.show()

    # Masks in a separate figure.
    _, axes = plt.subplots(1, n_samples, figsize=(4 * n_samples, 4), squeeze=False)
    masks = fruit["train_masks"]
    for i, ax in enumerate(axes[0]):
        ax.axis("off")
        if i < len(masks):
            ax.imshow(masks[i], cmap="gray")
            ax.set_title(f"Mask: {fruit['train_labels'][i]}")
    plt.tight_layout()
    plt.show()


def main(root_path="Project Data", task="fruit", epochs=10, batch_size=4):
    """Load the Food/Fruit data, print a summary, then run the selected task.

    Args:
        root_path: Dataset root containing the ``Food`` and ``Fruit`` folders
            and the sample test images listed below.
        task: ``"fruit"`` trains the Part D U-Net and predicts a mask for a
            sample validation image. ``"food"`` only converts the food images
            to BGR; Part B code is not included in this file.
        epochs: Part D training epochs.
        batch_size: Part D batch size.

    Raises:
        ValueError: for an unknown ``task``.
    """
    data = load_dataset(root_path)

    # Read actual images using OpenCV
    print("Loading Food images...")
    food_train_images = read_images(data["food_train_images"])
    food_val_images = read_images(data["food_val_images"])

    print("Loading Fruit images...")
    fruit_train_images = read_images(data["fruit_train_images"])
    fruit_val_images = read_images(data["fruit_val_images"])

    print("Loading Fruit masks...")
    fruit_train_masks = read_masks(data["fruit_train_masks"])
    fruit_val_masks = read_masks(data["fruit_val_masks"])

    food = {
        "train_images": food_train_images,
        "train_labels": data["food_train_labels"],
        "val_images": food_val_images,
        "val_labels": data["food_val_labels"],
        "train_cal": data["food_train_cal"],
        "val_cal": data["food_val_cal"]
    }

    fruit = {
        "train_images": fruit_train_images,
        "train_masks": fruit_train_masks,
        "train_labels": data["fruit_train_labels"],
        "val_images": fruit_val_images,
        "val_masks": fruit_val_masks,
        "val_labels": data["fruit_val_labels"],
        "calories": data["fruit_calories"]
    }

    # Data inspection
    print("\n=== Data Summary ===")
    print(f"Food - Training: {len(food['train_images'])} images, {len(set(food['train_labels']))} categories")
    print(f"Food - Validation: {len(food['val_images'])} images")
    print(f"Fruit - Training: {len(fruit['train_images'])} images, {len(set(fruit['train_labels']))} categories")
    print(f"Fruit - Validation: {len(fruit['val_images'])} images")
    print(f"Fruit masks - Training: {len(fruit['train_masks'])} masks")
    print(f"Fruit masks - Validation: {len(fruit['val_masks'])} masks")

    # Shape checks are skipped for empty splits instead of raising IndexError.
    if len(food["train_images"]):
        print(f"\nImage shape: {food['train_images'][0].shape}")
    if len(fruit["train_masks"]):
        print(f"Mask shape: {fruit['train_masks'][0].shape}")

    print("\nData loading completed successfully!")

    # Sample test images, resolved relative to the dataset root.
    imgPaths = [
        os.path.join(root_path, "Food", "Validation", "ceviche", "217909.jpg"),
        os.path.join(root_path, "Lichi.jpg"),
        os.path.join(root_path, "Guava.jpg"),
        os.path.join(root_path, "persimmons.jpg"),
        os.path.join(root_path, "mango.jpg"),
        os.path.join(root_path, "Fruit", "Validation", "Banana", "Images", "76.jpg")
    ]

    task = task.lower()
    if task == "fruit":
        print("it's Fruit")
        print("in Masks")

        # Train the segmentation model, then predict a mask for one sample image.
        _seg_model, partD_test = run_partD(fruit, epochs=epochs, batch_size=batch_size)
        output_path = "output_mask_partD.png"
        partD_test(imgPaths[5], output_path)
        print(f"Mask saved to {output_path}")
    elif task == "food":
        print("it's Food")
        food_BGR = convert_rgb_to_bgr(food)
        # TODO: Part B (food classification) is not part of this file;
        # plug it in here using food_BGR.
    else:
        raise ValueError(f"Unknown task {task!r}; expected 'fruit' or 'food'.")


if __name__ == "__main__":
    # Print a banner, then run the full pipeline.
    rule = "=" * 60
    for text in ("\n" + rule, "STARTING APPLICATION", rule):
        print(text)
    main()