In [35]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm
from torchvision import datasets 
import torch
import cv2
import os
from PIL import Image
from ultralytics import YOLO
import pandas as pd


In [36]:
# --- Configuration ---
DETECTOR_MODEL_PATH = 'runs/detect/yolov8n-detector/weights/best.pt'
PROTO_MODEL_PATH = "prototypical_networks_2.pth"
PROTOTYPES_PATH = "prototypes_resnet.pth"
INDEX_TO_LABEL_CSV = "index_to_label.csv"
OUTPUT_FOLDER = "detected_characters"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [38]:
detector = YOLO(DETECTOR_MODEL_PATH)

In [39]:
# --- Model Definitions ---
class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def forward(self, support_images, support_labels, query_images):
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)
        n_way = len(torch.unique(support_labels))
        z_proto = torch.cat([z_support[torch.nonzero(support_labels == label)].mean(0) for label in range(n_way)])
        dists = torch.cdist(z_query, z_proto)
        return -dists

In [40]:
# 2. Load Prototypical Network
def load_models():
    backbone = models.resnet18(pretrained=True)
    backbone.fc = nn.Flatten()
    model = PrototypicalNetworks(backbone).to(device)
    model.load_state_dict(torch.load(PROTO_MODEL_PATH, map_location=device))
    prototypes = torch.load(PROTOTYPES_PATH, map_location=device)
    return model, prototypes

model, prototypes = load_models()

  model.load_state_dict(torch.load(PROTO_MODEL_PATH, map_location=device))
  prototypes = torch.load(PROTOTYPES_PATH, map_location=device)


In [41]:
from PIL import Image
from torchvision import transforms

# Define the same preprocessing used during training/support
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # convert to 3 channels
    transforms.Resize((28, 28)),  # or your backbone's expected size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])




In [42]:
def process_image(image_path):
    # 1. Run YOLOv8 Detection
    detector = YOLO(DETECTOR_MODEL_PATH)
    results = detector(image_path)
    
    # Create a copy of the image for drawing
    output_img = results[0].orig_img.copy()
    
    # Get detection results
    boxes = results[0].boxes
    
    # 2. Process each detection
    predictions = []
    for i, box in enumerate(boxes):
        # Extract box info
        x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
        confidence = box.conf.item()
        
        # Crop character region
        crop = output_img[y1:y2, x1:x2]
        
        try:
            # Preprocess and classify
            crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
            crop_tensor = preprocess(crop_pil).unsqueeze(0).to(device)
            
            with torch.no_grad():
                embedding = model.backbone(crop_tensor)
                dists = torch.cdist(embedding, prototypes)
                class_idx = torch.argmax(-dists).item()
            
            # Get class label
            df = pd.read_csv(INDEX_TO_LABEL_CSV)
            label = df[df['Index'] == class_idx]['Label Name'].values[0]
            
            # Store prediction
            predictions.append({
                "bbox": [x1, y1, x2, y2],
                "class": label,
                "confidence": confidence
            })
            
            # Draw bounding box and label
            box_color = (0, 255, 0)  # Green
            text_color = (0, 0, 255)  # Red
            
            # Draw bounding box
            cv2.rectangle(output_img, (x1, y1), (x2, y2), box_color, 2)
            
            # Draw label background
            label_text = f"{label} {confidence:.2f}"
            (text_width, text_height), _ = cv2.getTextSize(label_text, 
                                                         cv2.FONT_HERSHEY_SIMPLEX, 
                                                         0.7, 2)
            cv2.rectangle(output_img, 
                         (x1, y1 - text_height - 10),
                         (x1 + text_width, y1),
                         box_color, -1)  # Filled rectangle
            
            # Put text
            cv2.putText(output_img, label_text,
                       (x1, y1 - 5), 
                       cv2.FONT_HERSHEY_SIMPLEX, 
                       0.7, text_color, 2)
            
        except Exception as e:
            print(f"Error processing detection {i}: {str(e)}")
            continue
    
    # Save output image
    output_path = os.path.join(OUTPUT_FOLDER, os.path.basename(image_path))
    cv2.imwrite(output_path, output_img)
    
    # Save predictions to JSON
    json_path = os.path.join(OUTPUT_FOLDER, f"{os.path.splitext(os.path.basename(image_path))[0]}_preds.json")
    with open(json_path, 'w') as f:
        json.dump(predictions, f, indent=2)
    
    return predictions, output_path


In [43]:

# --- Run Pipeline ---
if __name__ == "__main__":
    image_path = "test_3.jpg"
    predictions, output_img_path = process_image(image_path)
    
    print(f"\nDetection results saved to: {output_img_path}")
    print("\nPredicted Characters:")
    for i, pred in enumerate(predictions):
        print(f"{i+1}. {pred['class']} (Confidence: {pred['confidence']:.2f}) at position {pred['bbox']}")


image 1/1 c:\Users\acer\Desktop\prototypical_nets\test_3.jpg: 320x256 8 items, 15.2ms
Speed: 2.5ms preprocess, 15.2ms inference, 4.6ms postprocess per image at shape (1, 3, 320, 256)
Error processing detection 0: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])
Error processing detection 1: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])
Error processing detection 2: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])
Error processing detection 3: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])
Error processing detection 4: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])
Error processing detection 5: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])
Error processing detection 6: Expected more than 1 value per

NameError: name 'json' is not defined

In [None]:
# --- Fixed Functions ---
def prepare_query_image(image_path):
    """Convert image to preprocessed tensor with batch dim"""
    img = Image.open(image_path).convert("RGB")
    img_tensor = preprocess(img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]
    return img_tensor

def infer_single_query(backbone, prototypes, query_image_tensor):
    """Classify a single query image"""
    backbone.eval()
    with torch.no_grad():
        z_query = backbone(query_image_tensor)  # [1, embedding_dim]
        dists = torch.cdist(z_query, prototypes)  # [1, n_way]
        pred = torch.argmax(-dists, dim=1).item()
    return pred

In [None]:
# --- Main Pipeline ---
def process_image(image_path):
    img = cv2.imread(image_path)
    results = detector(img)
    predictions = []
    
    for i, box in enumerate(results[0].boxes.xyxy.cpu().numpy()):
        x1, y1, x2, y2 = map(int, box)
        crop = img[y1:y2, x1:x2]
        crop_path = os.path.join(OUTPUT_FOLDER, f'char_{i}.png')
        cv2.imwrite(crop_path, crop)
        
        try:
            # Fixed preprocessing and inference
            crop_tensor = prepare_query_image(crop_path)
            class_idx = infer_single_query(model.backbone, prototypes, crop_tensor)
            
            df = pd.read_csv(INDEX_TO_LABEL_CSV)
            label = df[df['Index'] == class_idx]['Label Name'].values[0]
            
            predictions.append({
                "bbox": [x1, y1, x2, y2],
                "class": label,
                "confidence": box.conf.item()
            })
            
            # Draw results
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(img, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0,255,0), 2)
        except Exception as e:
            print(f"Error processing crop {i}: {e}")
    
    output_path = os.path.join(OUTPUT_FOLDER, "annotated.jpg")
    cv2.imwrite(output_path, img)
    return predictions

In [None]:
# --- Execution ---
if __name__ == "__main__":
    predictions = process_image("test_3.jpg")
    for pred in predictions:
        print(f"{pred['class']}: {pred['bbox']} (conf: {pred['confidence']:.2f})")