In [74]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm
from torchvision import datasets 
import torch
import cv2
import os
from PIL import Image
from ultralytics import YOLO
import pandas as pd
import json

In [75]:
# --- Configuration ---
DETECTOR_MODEL_PATH = 'runs/detect/yolov8n-detector/weights/best.pt'
PROTO_MODEL_PATH = "prototypical_networks_2.pth"
PROTOTYPES_PATH = "prototypes_resnet.pth"
INDEX_TO_LABEL_CSV = "index_to_label.csv"
OUTPUT_FOLDER = "detected_characters"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

In [76]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [77]:
detector = YOLO(DETECTOR_MODEL_PATH)

In [78]:
# --- Model Definitions ---
class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def forward(self, support_images, support_labels, query_images):
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)
        n_way = len(torch.unique(support_labels))
        z_proto = torch.cat([z_support[torch.nonzero(support_labels == label)].mean(0) for label in range(n_way)])
        dists = torch.cdist(z_query, z_proto)
        return -dists

In [79]:
# 2. Load Prototypical Network
def load_models():
    backbone = models.resnet18(pretrained=True)
    backbone.fc = nn.Flatten()
    model = PrototypicalNetworks(backbone).to(device)
    model.load_state_dict(torch.load(PROTO_MODEL_PATH, map_location=device))
    prototypes = torch.load(PROTOTYPES_PATH, map_location=device)
    return model, prototypes

model, prototypes = load_models()

  model.load_state_dict(torch.load(PROTO_MODEL_PATH, map_location=device))
  prototypes = torch.load(PROTOTYPES_PATH, map_location=device)


In [80]:
from PIL import Image
from torchvision import transforms

# Define the same preprocessing used during training/support
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # convert to 3 channels
    transforms.Resize((28, 28)),  # or your backbone's expected size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])




In [81]:
# --- Handle 28x28 input for standard ResNet ---
def adapt_resnet_for_28x28():
    # Modify first layer to handle small inputs
    model.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.backbone.maxpool = nn.Identity()  # Remove first maxpool
    model.backbone = model.backbone.to(device)

adapt_resnet_for_28x28()

In [82]:
def process_image(image_path):
    # 1. Run YOLOv8 Detection
    detector = YOLO(DETECTOR_MODEL_PATH)
    results = detector(image_path)
    
    output_img = results[0].orig_img.copy()
    height, width = output_img.shape[:2]
    
    predictions = []
    for i, box in enumerate(results[0].boxes):
        x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
        confidence = box.conf.item()
        
        if x2 <= x1 or y2 <= y1 or (x2-x1) < 5 or (y2-y1) < 5:
            continue
            
        try:
            crop = output_img[max(0,y1-2):min(height,y2+2), max(0,x1-2):min(width,x2+2)]
            crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
            crop_tensor = preprocess(crop_pil).unsqueeze(0).to(device)
            
            with torch.no_grad():
                embedding = model.backbone(crop_tensor)
                if embedding.dim() == 1:
                    embedding = embedding.unsqueeze(0)
                dists = torch.cdist(embedding, prototypes)
                class_idx = torch.argmax(-dists).item()
            
            df = pd.read_csv(INDEX_TO_LABEL_CSV)
            label = df[df['Index'] == class_idx]['Label Name'].values[0]
            
            predictions.append({
                "bbox": [x1, y1, x2, y2],
                "class": label,
                "confidence": confidence
            })
            
            # Visualization
            cv2.rectangle(output_img, (x1, y1), (x2, y2), (0,255,0), 2)
            label_text = f"{label} {confidence:.2f}"
            (w, h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 2, 4 )
            # Adjust label background rectangle
            cv2.rectangle(output_img, 
                        (x1, y1 - h - 10),  # Increased padding
                        (x1 + w, y1),
                        (0, 255, 0), -1)

            # Draw text with larger font
            cv2.putText(output_img, label_text,
                    (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    2,  
                    (0, 0, 255), 
                    4)  
            
        except Exception as e:
            print(f"Error processing detection {i}: {str(e)}")
            continue
    
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    output_img_path = os.path.join(OUTPUT_FOLDER, f"{base_name}_annotated.jpg")
    cv2.imwrite(output_img_path, output_img)
    
    with open(os.path.join(OUTPUT_FOLDER, f"{base_name}_preds.json"), 'w') as f:
        json.dump(predictions, f, indent=2)
    
    return predictions, output_img_path


In [83]:
if __name__ == "__main__":
    image_path = "test_3.jpg"
    predictions, output_path = process_image(image_path)
    print(f"\nResults saved to: {output_path}")
    print("\nDetected Characters:")
    for i, pred in enumerate(predictions):
        print(f"{i+1}. {pred['class']} (Confidence: {pred['confidence']:.2f}) at {pred['bbox']}")



image 1/1 c:\Users\acer\Desktop\prototypical_nets\test_3.jpg: 320x256 8 items, 15.1ms
Speed: 2.2ms preprocess, 15.1ms inference, 2.4ms postprocess per image at shape (1, 3, 320, 256)

Results saved to: detected_characters\test_3_annotated.jpg

Detected Characters:
1. maithili_dhaa (Confidence: 0.78) at [356, 2901, 986, 3669]
2. maithili_dhaa (Confidence: 0.78) at [1940, 1978, 2514, 2623]
3. maithili_dhaa (Confidence: 0.74) at [362, 2105, 949, 2780]
4. tibetan-gha (Confidence: 0.70) at [1782, 2976, 2371, 3730]
5. maithili_dhaa (Confidence: 0.68) at [1036, 1497, 1461, 1998]
6. maithili_dhaa (Confidence: 0.61) at [355, 755, 923, 1412]
7. maithili_dhaa (Confidence: 0.52) at [1802, 137, 2373, 953]
8. maithili_dhaa (Confidence: 0.38) at [1989, 1140, 2676, 1898]


In [84]:
# --- Fixed Functions ---
def prepare_query_image(image_path):
    """Convert image to preprocessed tensor with batch dim"""
    img = Image.open(image_path).convert("RGB")
    img_tensor = preprocess(img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]
    return img_tensor

def infer_single_query(backbone, prototypes, query_image_tensor):
    """Classify a single query image"""
    backbone.eval()
    with torch.no_grad():
        z_query = backbone(query_image_tensor)  # [1, embedding_dim]
        dists = torch.cdist(z_query, prototypes)  # [1, n_way]
        pred = torch.argmax(-dists, dim=1).item()
    return pred

In [85]:
# --- Main Pipeline ---
def process_image(image_path):
    img = cv2.imread(image_path)
    results = detector(img)
    predictions = []
    
    for i, box in enumerate(results[0].boxes.xyxy.cpu().numpy()):
        x1, y1, x2, y2 = map(int, box)
        crop = img[y1:y2, x1:x2]
        crop_path = os.path.join(OUTPUT_FOLDER, f'char_{i}.png')
        cv2.imwrite(crop_path, crop)
        
        try:
            # Fixed preprocessing and inference
            crop_tensor = prepare_query_image(crop_path)
            class_idx = infer_single_query(model.backbone, prototypes, crop_tensor)
            
            df = pd.read_csv(INDEX_TO_LABEL_CSV)
            label = df[df['Index'] == class_idx]['Label Name'].values[0]
            
            predictions.append({
                "bbox": [x1, y1, x2, y2],
                "class": label,
                "confidence": box.conf.item()
            })
            
            # Draw results
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(img, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0,255,0), 2)
        except Exception as e:
            print(f"Error processing crop {i}: {e}")
    
    output_path = os.path.join(OUTPUT_FOLDER, "annotated.jpg")
    cv2.imwrite(output_path, img)
    return predictions

In [86]:
# --- Execution ---
if __name__ == "__main__":
    predictions = process_image("test_3.jpg")
    for pred in predictions:
        print(f"{pred['class']}: {pred['bbox']} (conf: {pred['confidence']:.2f})")


0: 320x256 8 items, 11.3ms
Speed: 1.1ms preprocess, 11.3ms inference, 1.9ms postprocess per image at shape (1, 3, 320, 256)
Error processing crop 0: 'numpy.ndarray' object has no attribute 'conf'
Error processing crop 1: 'numpy.ndarray' object has no attribute 'conf'
Error processing crop 2: 'numpy.ndarray' object has no attribute 'conf'
Error processing crop 3: 'numpy.ndarray' object has no attribute 'conf'
Error processing crop 4: 'numpy.ndarray' object has no attribute 'conf'
Error processing crop 5: 'numpy.ndarray' object has no attribute 'conf'
Error processing crop 6: 'numpy.ndarray' object has no attribute 'conf'
Error processing crop 7: 'numpy.ndarray' object has no attribute 'conf'
