# Pill Identification Pipeline
Implementation based on: "An Accurate Deep Learning-Based System for Automatic Pill Identification: Model Development and Validation" (Heo et al., 2023)

Pipeline: **Raw Image → YOLO (Imprint Detection) → ResNet (Feature Recognition) → RNN (Imprint Correction) → Database Retrieval**

## 1. Imports and Setup

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from ultralytics import YOLO
from PIL import Image
import pandas as pd
import numpy as np
from pathlib import Path
from collections import Counter
import sys

# Add parent directory to path for imports
sys.path.append('..')

from models.rnn_model import Seq2SeqWithAttention
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches

device = torch.device(
    'mps' if torch.backends.mps.is_available() else
    'cuda' if torch.cuda.is_available() else 'cpu'
)
print(f'Using device: {device}')

## 2. Define ResNet Model

In [2]:
class MultiTaskResNet(nn.Module):
    """ResNet-18 with multitask heads for shape, color, and form"""
    
    def __init__(self, num_shapes, num_colors, num_forms):
        super().__init__()
        self.backbone = models.resnet18(weights=None)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.shape_head = nn.Linear(in_features, num_shapes)
        self.color_head = nn.Linear(in_features, num_colors)
        self.form_head = nn.Linear(in_features, num_forms)
    
    def forward(self, x):
        features = self.backbone(x)
        return self.shape_head(features), self.color_head(features), self.form_head(features)

## 3. Load All Models

In [None]:
print("Loading YOLO model...")
yolo_model = YOLO('../models/yolo/best.pt')

print("Loading ResNet model...")
resnet_checkpoint = torch.load('../models/resnet/pill_classifier_full.pth', weights_only=False)
resnet_model = MultiTaskResNet(
    resnet_checkpoint['num_shape_classes'],
    resnet_checkpoint['num_color_classes'],
    resnet_checkpoint['num_form_classes']
).to(device)
resnet_model.load_state_dict(resnet_checkpoint['model_state_dict'])
resnet_model.eval()

shape_encoder = resnet_checkpoint['shape_encoder']
color_encoder = resnet_checkpoint['color_encoder']
form_encoder = resnet_checkpoint['form_encoder']

resnet_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

print("Loading RNN dataset info...")
rnn_dataset = torch.load('../data/predictions/rnn_dataset.pt', weights_only=False)

char_to_idx = rnn_dataset['char_to_idx']
idx_to_char = rnn_dataset['idx_to_char']
EOS_IDX = rnn_dataset['EOS_IDX']
char_encoder = rnn_dataset['char_encoder']
shape_encoder_rnn = rnn_dataset['shape_encoder']
color_encoder_rnn = rnn_dataset['color_encoder']
form_encoder_rnn = rnn_dataset['form_encoder']

print("Loading RNN model...")
rnn_checkpoint = torch.load('../models/rnn/rnn_imprint_correction_final.pt', weights_only=False)

rnn_model = Seq2SeqWithAttention(
    input_dim=rnn_dataset['feature_dim'],
    vocab_size=len(rnn_dataset['ALL_CHARS']),
    embedding_dim=45,
    hidden_dim=256,
    dropout=0.1,
    sos_idx=rnn_dataset['SOS_IDX'],
    eos_idx=rnn_dataset['EOS_IDX']
).to(device)
rnn_model.load_state_dict(rnn_checkpoint['model_state_dict'])
rnn_model.eval()

print("Loading database...")
database = pd.read_csv('../data/splits/pillbox_full.csv')

print(f"\n✓ All models loaded!")
print(f"  - YOLO: Imprint detection")
print(f"  - ResNet: {len(shape_encoder.classes_)} shapes, {len(color_encoder.classes_)} colors, {len(form_encoder.classes_)} forms")
print(f"  - RNN: Imprint correction (input_dim={rnn_dataset['feature_dim']})")
print(f"  - Database: {len(database)} pills")

## 4. Helper Functions

In [4]:
def sort_boxes_left_to_right(boxes):
    """Sort bounding boxes left-to-right, top-to-bottom"""
    if len(boxes) == 0:
        return []
    centers = [(box['bbox'][0], box['bbox'][1]) for box in boxes]
    sorted_indices = sorted(range(len(centers)), key=lambda i: (round(centers[i][1] * 10), centers[i][0]))
    return [boxes[i] for i in sorted_indices]


def encode_features_for_rnn(yolo_detections, shape, color, form):
    """Encode YOLO detections + ResNet features for RNN input
    Format: [x, y, char_OHE, shape_OHE, color_OHE, form_OHE] per character"""
    sorted_boxes = sort_boxes_left_to_right(yolo_detections)
    
    shape_ohe = shape_encoder_rnn.transform([[shape]])[0]
    color_ohe = color_encoder_rnn.transform([[color]])[0]
    form_ohe = form_encoder_rnn.transform([[form]])[0]
    
    sequences = []
    for det in sorted_boxes:
        x_center, y_center, w, h = det['bbox']
        char = det['class_name'].upper()
        char_ohe = char_encoder.transform([[char]])[0]
        
        feature_vector = np.concatenate([[x_center, y_center], char_ohe, shape_ohe, color_ohe, form_ohe])
        sequences.append(feature_vector)
    
    return np.array(sequences) if sequences else np.array([])


def create_mask(X):
    """Create mask for valid (non-padded) positions"""
    return (X.sum(dim=2) != 0).float()


def decode_sequence(indices):
    """Convert token indices to string, stopping at EOS"""
    chars = []
    for idx in indices:
        idx_val = idx.item() if torch.is_tensor(idx) else idx
        if idx_val == EOS_IDX:
            break
        if idx_val in idx_to_char:
            char = idx_to_char[idx_val]
            if char not in ['<SOS>', '<PAD>', '<EOS>']:
                chars.append(char)
    return ''.join(chars)


def levenshtein_distance(s1, s2):
    """Calculate edit distance between two strings"""
    if len(s1) < len(s2):
        return levenshtein_distance(s2, s1)
    if len(s2) == 0:
        return len(s1)
    
    previous_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row
    return previous_row[-1]


def count_overlapping_chars(s1, s2):
    """Count overlapping characters between two strings"""
    counter1 = Counter(s1.lower())
    counter2 = Counter(s2.lower())
    return sum((counter1 & counter2).values())


def calculate_similarity(prediction, db_pill):
    """Calculate similarity score: shape/color/form/imprint matching"""
    shape_score = 1/4 if prediction['shape'] == db_pill['shape'] else 0
    color_score = 1/4 if prediction['color'] == db_pill['color'] else 0
    form_score = 1/4 if prediction['form'] == db_pill['form'] else 0
    
    pred_text = str(prediction['imprint']).upper()
    db_text = str(db_pill['imprint']).upper()
    
    edit_dist = levenshtein_distance(pred_text, db_text)
    total_len = max(len(pred_text), len(db_text))
    edit_score = (1 - edit_dist / total_len) if total_len > 0 else 0
    
    overlap = count_overlapping_chars(pred_text, db_text)
    overlap_score = (overlap * 2) / (len(pred_text) + len(db_text)) if (len(pred_text) + len(db_text)) > 0 else 0
    
    return shape_score + color_score + form_score + edit_score + overlap_score


print("✓ Helper functions defined")

✓ Helper functions defined


## 5. Main Pipeline Function

In [5]:
def predict_pill(image_path):
    """Run complete pipeline: YOLO -> ResNet -> RNN"""
    image_path = Path(image_path)
    image = Image.open(image_path).convert('RGB')
    
    yolo_results = yolo_model.predict(image_path, conf=0.15, agnostic_nms=True, verbose=False)[0]
    detections = []
    for box in yolo_results.boxes:
        detections.append({
            'class_id': int(box.cls),
            'class_name': yolo_results.names[int(box.cls)].upper(),
            'confidence': float(box.conf),
            'bbox': box.xywhn.tolist()[0]
        })
    
    image_tensor = resnet_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        shape_out, color_out, form_out = resnet_model(image_tensor)
        pred_shape_idx = shape_out.argmax(dim=1).item()
        pred_color_idx = color_out.argmax(dim=1).item()
        pred_form_idx = form_out.argmax(dim=1).item()
    
    pred_shape = shape_encoder.inverse_transform([pred_shape_idx])[0]
    pred_color = color_encoder.inverse_transform([pred_color_idx])[0]
    pred_form = form_encoder.inverse_transform([pred_form_idx])[0]
    
    corrected_imprint = ""
    if detections:
        features = encode_features_for_rnn(detections, pred_shape, pred_color, pred_form)
        if len(features) > 0:
            X = torch.FloatTensor(features).unsqueeze(0).to(device)
            max_len = 48
            if X.shape[1] < max_len:
                padding = torch.zeros(1, max_len - X.shape[1], X.shape[2]).to(device)
                X = torch.cat([X, padding], dim=1)
            
            src_mask = create_mask(X).to(device)
            with torch.no_grad():
                predictions, _, lengths = rnn_model.predict(X, max_len=50, src_mask=src_mask)
            corrected_imprint = decode_sequence(predictions[0][:lengths[0]].cpu())
    
    return {
        'shape': pred_shape,
        'color': pred_color,
        'form': pred_form,
        'imprint': corrected_imprint,
        'yolo_detections': detections
    }


def retrieve_top_k(prediction, k=3):
    """Retrieve top-k most similar pills from database"""
    scores = []
    for idx in range(len(database)):
        pill = database.iloc[idx]
        score = calculate_similarity(prediction, {
            'shape': pill['splshape_text'],
            'color': pill['splcolor_text'],
            'form': pill['dosage_form'],
            'imprint': pill['splimprint_clean']
        })
        scores.append((pill, score))
    
    scores.sort(key=lambda x: x[1], reverse=True)
    return scores[:k]


print("✓ Pipeline functions defined")

✓ Pipeline functions defined


## 6. Load Test Images

In [None]:
# Load test images from final_test.csv
test_df = pd.read_csv('../data/splits/final_test.csv')

# Remove underscores from splimprint_clean and save to 'label' column
test_df['label'] = test_df['splimprint_clean'].str.replace('_', '', regex=False)

test_images = test_df['original_name'].tolist()
image_dir = Path('../data/pillbox_production_images_full_202008')

print(f"Loaded {len(test_images)} test images from final_test.csv")
print(f"First 5 images: {test_images[:5]}")
print(f"\nGround truth labels (first 5):")
for i in range(5):
    print(f"  {test_images[i]}: {test_df.iloc[i]['label']}")

## 7. Run Pipeline on Test Images

In [7]:
results = []

print(f"\n{'='*80}")
print(f"RUNNING PIPELINE ON {len(test_images)} TEST IMAGES")
print(f"{'='*80}\n")

for img_name in tqdm(test_images):
    img_path = image_dir / img_name
    if not img_path.exists():
        continue
    
    prediction = predict_pill(img_path)
    top_3 = retrieve_top_k(prediction, k=3)
    
    results.append({
        'image_name': img_name,
        'pred_shape': prediction['shape'],
        'pred_color': prediction['color'],
        'pred_form': prediction['form'],
        'pred_imprint': prediction['imprint'],
        'top1_medicine': top_3[0][0]['medicine_name'],
        'top1_score': top_3[0][1],
        'top2_medicine': top_3[1][0]['medicine_name'],
        'top2_score': top_3[1][1],
        'top3_medicine': top_3[2][0]['medicine_name'],
        'top3_score': top_3[2][1]
    })

results_df = pd.DataFrame(results)
print(f"\n✓ Pipeline completed on {len(results_df)} images")


RUNNING PIPELINE ON 797 TEST IMAGES



100%|██████████| 797/797 [04:52<00:00,  2.73it/s]


✓ Pipeline completed on 797 images





## 8. Display Results

In [8]:
print("\n" + "="*80)
print("TOP 3 PREDICTIONS FOR EACH IMAGE")
print("="*80 + "\n")

for i in range(min(10, len(results_df))):
    row = results_df.iloc[i]
    print(f"{i+1}. {row['image_name']}")
    print(f"   Predicted: {row['pred_shape']}, {row['pred_color']}, {row['pred_form']}, '{row['pred_imprint']}'")
    print(f"   Top 3 matches:")
    print(f"      1. {row['top1_medicine'][:50]:50s} (score: {row['top1_score']:.3f})")
    print(f"      2. {row['top2_medicine'][:50]:50s} (score: {row['top2_score']:.3f})")
    print(f"      3. {row['top3_medicine'][:50]:50s} (score: {row['top3_score']:.3f})")
    print()


TOP 3 PREDICTIONS FOR EACH IMAGE

1. 003782722.jpg
   Predicted: ROUND, BLUE, C42998, 'AMG212'
   Top 3 matches:
      1. Oxycodone Hydrochloride                            (score: 1.850)
      2. Hyoscyamine Sulfate                                (score: 1.750)
      3. Levsin                                             (score: 1.750)

2. 625590149.jpg
   Predicted: CAPSULE, GREEN, C25158, '104'
   Top 3 matches:
      1. bethanechol chloride                               (score: 2.000)
      2. Sulfasalazine                                      (score: 2.000)
      3. dantrolene sodium                                  (score: 1.833)

3. 51862-0172-12_NLMIMAGE10_55462AE1.jpg
   Predicted: ROUND, PINK, C42998, '172'
   Top 3 matches:
      1. Fluoride                                           (score: 2.250)
      2. Fanapt                                             (score: 1.967)
      3. FANAPT                                             (score: 1.967)

4. CBR07490.jpg
   Predicted:

## 9. Visualize Example Predictions

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for idx, ax in enumerate(axes):
    img_name = results_df.iloc[idx]['image_name']
    img_path = image_dir / img_name
    image = Image.open(img_path)
    prediction = predict_pill(img_path)
    
    ax.imshow(image)
    
    img_width, img_height = image.size
    for det in prediction['yolo_detections']:
        x_center, y_center, w, h = det['bbox']
        x = (x_center - w/2) * img_width
        y = (y_center - h/2) * img_height
        
        rect = patches.Rectangle((x, y), w * img_width, h * img_height, 
                                linewidth=2, edgecolor='red', facecolor='none')
        ax.add_patch(rect)
        ax.text(x, y - 5, f"{det['class_name']} ({det['confidence']:.2f})",
                color='red', fontsize=10, fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7))
    
    ax.axis('off')
    ax.set_title(f"{prediction['shape']}, {prediction['color']}, {prediction['form']}\n"
                 f"Imprint: '{prediction['imprint']}'\n"
                 f"Top-1: {results_df.iloc[idx]['top1_medicine'][:30]}", fontsize=10)

plt.tight_layout()
plt.savefig('../validation_preview.jpg', dpi=150, bbox_inches='tight')
plt.show()
print("✓ Visualization saved")

## 10. Calculate Accuracy (Top-1 and Top-3)

In [10]:
# Load ground truth labels from test_df
ground_truth_lookup = dict(zip(test_df['original_name'], test_df['label']))

# Add ground truth labels to results
results_df['true_label'] = results_df['image_name'].map(ground_truth_lookup)

# Calculate imprint accuracy (exact match on imprint)
top1_correct = (results_df['true_label'] == results_df['pred_imprint']).sum()
top1_accuracy = top1_correct / len(results_df)

# Calculate medicine name accuracy from database
database['label'] = database['splimprint_clean'].str.replace('_', '', regex=False)
medicine_lookup = dict(zip(database['original_name'], database['medicine_name']))
results_df['true_medicine'] = results_df['image_name'].map(medicine_lookup)

# Medicine Top-1 Accuracy
med_top1_correct = (results_df['true_medicine'] == results_df['top1_medicine']).sum()
med_top1_accuracy = med_top1_correct / len(results_df)

# Medicine Top-3 Accuracy
med_top3_correct = (
    (results_df['true_medicine'] == results_df['top1_medicine']) |
    (results_df['true_medicine'] == results_df['top2_medicine']) |
    (results_df['true_medicine'] == results_df['top3_medicine'])
).sum()
med_top3_accuracy = med_top3_correct / len(results_df)

# Display accuracy results
print("\n" + "="*80)
print("ACCURACY EVALUATION")
print("="*80)
print(f"Total images evaluated: {len(results_df)}")
print(f"Images with ground truth: {results_df['true_label'].notna().sum()}")

print(f"\n--- IMPRINT ACCURACY (RNN Output) ---")
print(f"Exact Match Accuracy: {top1_accuracy:.4f} ({top1_correct}/{len(results_df)})")

print(f"\n--- MEDICINE NAME ACCURACY (Database Retrieval) ---")
print(f"Top-1 Accuracy: {med_top1_accuracy:.4f} ({med_top1_correct}/{len(results_df)})")
print(f"Top-3 Accuracy: {med_top3_accuracy:.4f} ({med_top3_correct}/{len(results_df)})")
print("="*80)

# Show examples of correct and incorrect predictions
print("\n" + "-"*80)
print("CORRECT IMPRINT PREDICTIONS")
print("-"*80)
correct_preds = results_df[results_df['true_label'] == results_df['pred_imprint']].head(5)
for i, row in correct_preds.iterrows():
    print(f"✓ {row['image_name']}")
    print(f"  True: {row['true_label']}")
    print(f"  Pred: {row['pred_imprint']}")
    print(f"  Medicine: {row['true_medicine']}")
    print()

print("-"*80)
print("INCORRECT IMPRINT PREDICTIONS")
print("-"*80)
incorrect_preds = results_df[results_df['true_label'] != results_df['pred_imprint']].head(5)
for i, row in incorrect_preds.iterrows():
    print(f"✗ {row['image_name']}")
    print(f"  True: {row['true_label']}")
    print(f"  Pred: {row['pred_imprint']}")
    print(f"  Medicine: {row['true_medicine']}")
    print()


ACCURACY EVALUATION
Total images evaluated: 797
Images with ground truth: 797

--- IMPRINT ACCURACY (RNN Output) ---
Exact Match Accuracy: 0.2760 (220/797)

--- MEDICINE NAME ACCURACY (Database Retrieval) ---
Top-1 Accuracy: 0.5420 (432/797)
Top-3 Accuracy: 0.6863 (547/797)

--------------------------------------------------------------------------------
CORRECT IMPRINT PREDICTIONS
--------------------------------------------------------------------------------
✓ 51862-0172-12_NLMIMAGE10_55462AE1.jpg
  True: 172
  Pred: 172
  Medicine: Fluoride

✓ CBR07490.jpg
  True: I22
  Pred: I22
  Medicine: Linezolid

✓ 547380902.jpg
  True: MP112
  Pred: MP112
  Medicine: Sulindac

✓ 422910367.jpg
  True: J246
  Pred: J246
  Medicine: Lamotrigine

✓ 00074-6215-13_00210038.jpg
  True: ANS
  Pred: ANS
  Medicine: Depakote

--------------------------------------------------------------------------------
INCORRECT IMPRINT PREDICTIONS
------------------------------------------------------------------

## 11. Save Results

In [None]:
results_df.to_csv('../pipeline_results.csv', index=False)
print(f"✓ Results saved to pipeline_results.csv ({len(results_df)} rows)")

print("\n" + "="*80)
print("SUMMARY STATISTICS")
print("="*80)
print(f"Total images processed: {len(results_df)}")
print(f"Average top-1 score: {results_df['top1_score'].mean():.3f}")
print(f"Average top-3 score: {results_df[['top1_score', 'top2_score', 'top3_score']].mean().mean():.3f}")
print(f"\nImprint Exact Match Accuracy: {top1_accuracy:.4f}")
print(f"Medicine Top-1 Accuracy: {med_top1_accuracy:.4f}")
print(f"Medicine Top-3 Accuracy: {med_top3_accuracy:.4f}")
print(f"\nScore distribution:")
print(results_df['top1_score'].describe())