In [2]:
 import torch
from torchvision import models, transforms
from PIL import Image
import numpy as np

def is_red(p):
    """Return whether pixel *p* counts as red.

    A pixel is red when its first channel is above 100 and its second and
    third channels are both below 100. Channels are checked left to right
    and checking stops at the first one that fails.
    """
    first_is_strong = p[0] > 100
    if not first_is_strong:
        return first_is_strong
    second_is_weak = p[1] < 100
    if not second_is_weak:
        return second_is_weak
    return p[2] < 100

def is_white(p):
    """Return whether pixel *p* counts as white.

    A pixel is white when each of its first three channels is above 200.
    Channels are checked left to right and checking stops at the first one
    that fails.
    """
    first_is_bright = p[0] > 200
    if not first_is_bright:
        return first_is_bright
    second_is_bright = p[1] > 200
    if not second_is_bright:
        return second_is_bright
    return p[2] > 200

def compute_mean_position(a, f):
    """Return the mean row index of the elements of *a* accepted by *f*.

    Every ``a[row, col]`` over the first two axes is passed to the predicate
    *f*; each accepted element contributes its row index once. Returns
    ``None`` when no element is accepted.
    """
    matching_rows = [
        row
        for row in range(a.shape[0])
        for col in range(a.shape[1])
        if f(a[row, col])
    ]
    return np.mean(matching_rows) if matching_rows else None

def determine_flag_in_area(image_array):
    """Classify an image region as the Indonesian or Polish flag.

    Compares the mean row of red pixels with the mean row of white pixels:
    red above white is reported as Indonesia, white above red as Poland.
    Returns a human-readable message; a distinct message is returned when
    either colour is absent or when both means coincide.
    """
    red_row = compute_mean_position(image_array, is_red)
    white_row = compute_mean_position(image_array, is_white)
    # Debug output of the two mean row positions.
    print("Mean Red Position:", red_row, "Mean White Position:", white_row)

    # Without both colours present there is nothing to compare.
    if red_row is None or white_row is None:
        return "No clear flag detected in the image."

    # Smaller row index means closer to the top of the region.
    if red_row < white_row:
        return "The flag is Indonesia."
    if white_row < red_row:
        return "The flag is Poland."
    return "The image does not match the Indonesia or Poland flag pattern."

def main(image_path='test_image.png', confidence_threshold=0.2):
    """Run SSD detection on an image and classify the first confident crop.

    Loads a pretrained SSD300-VGG16 detector, runs it on the image at
    *image_path*, takes the first detection whose score exceeds
    *confidence_threshold*, crops that bounding box from the image and
    prints the result of ``determine_flag_in_area`` for the crop. Prints
    "No flag detected." when no detection passes the threshold.

    Args:
        image_path: Path of the image file to analyse.
        confidence_threshold: Minimum detection score (exclusive) for a
            box to be considered.
    """
    # Load SSD model in evaluation mode.
    # NOTE: newer torchvision releases deprecate `pretrained=` in favour of
    # `weights=`; kept as-is so older installations continue to work.
    model = models.detection.ssd300_vgg16(pretrained=True).eval()

    # Transform for input image
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Load the image and add a leading batch dimension for the model.
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0)

    # Inference only: disable autograd so no graph is built for the outputs.
    with torch.no_grad():
        detections = model(image_tensor)[0]

    # Convert image to numpy array for cropping.
    image_array = np.array(image)

    # The detector is not flag-specific: any label is accepted here. A custom
    # model would additionally filter on its flag label ID.
    for box, score in zip(detections['boxes'], detections['scores']):
        if score.item() > confidence_threshold:
            # Bounding box corners, truncated to integer pixel indices.
            x1, y1, x2, y2 = box.cpu().numpy().astype(int)

            # Crop the detected region (rows are y, columns are x).
            flag_region = image_array[y1:y2, x1:x2]

            # Determine flag type
            flag_type = determine_flag_in_area(flag_region)
            print(f"Flag detected: {flag_type}")
            # Only the first detection above the threshold is used.
            break
    else:
        # Loop finished without a break: nothing passed the threshold.
        print("No flag detected.")

# Run the detection pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


Mean Red Position: 27.870189431704887 Mean White Position: 74.39209621993128
Flag detected: The flag is Indonesia.
