In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import pytesseract
from tqdm import tqdm

test_df = pd.read_csv("data/pillbox_heldout.csv")
print(f"Loaded {len(test_df)} samples")

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f'Using device: {device}')

test_df.head()

Loaded 1959 samples
Using device: mps


Unnamed: 0.1,Unnamed: 0,ID,splcolor_text,splshape_text,product_code,medicine_name,dosage_form,splimage,splimprint,original_name,num_colors,num_imprints,image_path,shape_label,color_label
0,6827,23538,WHITE,ROUND,23155-103,METFORMIN HYDROCHLORIDE,C42998,23155-0103-10_NLMIMAGE10_793C3C81,H;103,23155-0103-10_NLMIMAGE10_793C3C81.jpg,1,2.0,data/pillbox_production_images_full_202008/231...,10,10
1,9397,33762,WHITE,ROUND,60429-763,CILOSTAZOL,C42998,604290763,cor;159,604290763.jpg,1,2.0,data/pillbox_production_images_full_202008/604...,10,10
2,9392,3836,YELLOW,OVAL,43353-338,PANTOPRAZOLE SODIUM,C42905,433530338,P;40,433530338.jpg,1,2.0,data/pillbox_production_images_full_202008/433...,7,11
3,1255,5656,PINK,TRAPEZOID,68462-104,FLUCONAZOLE,C42998,684620104,200,684620104.jpg,1,1.0,data/pillbox_production_images_full_202008/684...,13,6
4,7161,19422,WHITE,ROUND,0185-4350,ISONIAZID,C42998,00185-4350-01_6B083581,E;4350,00185-4350-01_6B083581.jpg,1,2.0,data/pillbox_production_images_full_202008/001...,10,10


# Load in pre-trained Resnet Model

In [2]:
class MultiTaskResNet(nn.Module):
    def __init__(self, num_shapes, num_colors):
        super().__init__()
        self.backbone = models.resnet18(weights=None)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.shape_head = nn.Linear(in_features, num_shapes)
        self.color_head = nn.Linear(in_features, num_colors)
    
    def forward(self, x):
        features = self.backbone(x)
        shape_out = self.shape_head(features)
        color_out = self.color_head(features)
        return shape_out, color_out

checkpoint = torch.load('resnet_model/pill_classifier_full.pth', weights_only=False)

model = MultiTaskResNet(
    checkpoint['num_shape_classes'],
    checkpoint['num_color_classes']
).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

shape_encoder = checkpoint['shape_encoder']
color_encoder = checkpoint['color_encoder']

print("ResNet model loaded")
print(f"Shape classes: {len(shape_encoder.classes_)}")
print(f"Color classes: {len(color_encoder.classes_)}")

ResNet model loaded
Shape classes: 15
Color classes: 12


In [3]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def predict_pill(image_path):
    """Predict shape, color, and extract text from pill image"""
    image = Image.open(image_path).convert('RGB')
    
    # Vision prediction
    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        shape_out, color_out = model(image_tensor)
        pred_shape_idx = shape_out.argmax(dim=1).item()
        pred_color_idx = color_out.argmax(dim=1).item()
    
    pred_shape = shape_encoder.inverse_transform([pred_shape_idx])[0]
    pred_color = color_encoder.inverse_transform([pred_color_idx])[0]
    
    # OCR text extraction
    pred_text = pytesseract.image_to_string(image, config='--psm 6').strip()
    
    return pred_shape, pred_color, pred_text

def predict_pill_timed(image_path):
    """Predict shape, color, and extract text with timing breakdown"""
    import time
    
    image = Image.open(image_path).convert('RGB')
    
    # Vision prediction (ResNet)
    resnet_start = time.perf_counter()
    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        shape_out, color_out = model(image_tensor)
        pred_shape_idx = shape_out.argmax(dim=1).item()
        pred_color_idx = color_out.argmax(dim=1).item()
    pred_shape = shape_encoder.inverse_transform([pred_shape_idx])[0]
    pred_color = color_encoder.inverse_transform([pred_color_idx])[0]
    resnet_time = time.perf_counter() - resnet_start
    
    # OCR text extraction
    ocr_start = time.perf_counter()
    pred_text = pytesseract.image_to_string(image, config='--psm 6').strip()
    ocr_time = time.perf_counter() - ocr_start
    
    return pred_shape, pred_color, pred_text, resnet_time, ocr_time

In [4]:
def levenshtein_distance(s1, s2):
    """Calculate Levenshtein distance between two strings"""
    if len(s1) < len(s2):
        return levenshtein_distance(s2, s1)
    if len(s2) == 0:
        return len(s1)
    
    previous_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row
    
    return previous_row[-1]

def count_overlapping_chars(s1, s2):
    """Count overlapping characters between two strings"""
    from collections import Counter
    counter1 = Counter(s1.lower())
    counter2 = Counter(s2.lower())
    overlap = sum((counter1 & counter2).values())
    return overlap

def similarity_score(prediction, database_pill):
    """
    Calculate similarity score combining shape, color, edit distance, and character overlap
    Max score: 3.0 (0.5 + 0.5 + 1.0 + 1.0)
    """
    # Feature scores (1/2 each for exact match)
    shape_score = 0.5 if prediction['shape'] == database_pill['shape'] else 0
    color_score = 0.5 if prediction['color'] == database_pill['color'] else 0
    
    # Text similarity (normalized edit distance)
    pred_text = str(prediction['text']).upper()
    db_text = str(database_pill['text']).upper()
    
    edit_dist = levenshtein_distance(pred_text, db_text)
    total_len = max(len(pred_text), len(db_text))
    edit_score = (1 - edit_dist / total_len) if total_len > 0 else 0
    
    # Character overlap
    overlap = count_overlapping_chars(pred_text, db_text)
    overlap_score = (overlap * 2) / (len(pred_text) + len(db_text)) if (len(pred_text) + len(db_text)) > 0 else 0
    
    # Total (2 features + 2 text components)
    total = shape_score + color_score + edit_score + overlap_score
    return total

# Retrieval

In [5]:
all_pills_df = pd.read_csv('data/pillbox.csv')
print(f"Database size: {len(all_pills_df)} pills")

def retrieve_top_k(prediction, database, k=3):
    """
    Fast retrieval with 2-stage soft filtering
    
    Stage 1: Filter to candidates with matching shape OR color
    Stage 2: Full similarity scoring on candidates only
    """
    # Stage 1: Soft filter - match shape OR color
    mask = (
        (database['splshape_text'] == prediction['shape']) |
        (database['splcolor_text'] == prediction['color'])
    )
    candidates = database[mask]
    
    # Fallback: if too few candidates, expand search
    if len(candidates) < k * 10:
        candidates = database
    
    # Stage 2: Full similarity scoring on candidates only
    scores = []
    for idx in range(len(candidates)):
        pill = candidates.iloc[idx]
        
        db_pill = {
            'shape': pill['splshape_text'],
            'color': pill['splcolor_text'],
            'text': pill['splimprint']
        }
        
        score = similarity_score(prediction, db_pill)
        scores.append((pill['medicine_name'], score))
    
    scores.sort(key=lambda x: x[1], reverse=True)
    return scores[:k]

Database size: 83925 pills


  all_pills_df = pd.read_csv('data/pillbox.csv')


In [6]:
import time

results = []
total_resnet_time = 0
total_ocr_time = 0
total_retrieval_time = 0

start_time = time.perf_counter()

for idx in tqdm(range(len(test_df)), desc="Processing pills"):
    row = test_df.iloc[idx]
    image_path = row['image_path']
    
    # Prediction with timing
    pred_shape, pred_color, pred_text, resnet_time, ocr_time = predict_pill_timed(image_path)
    
    total_resnet_time += resnet_time
    total_ocr_time += ocr_time
    
    # Create prediction dict for retrieval
    prediction = {
        'shape': pred_shape,
        'color': pred_color,
        'text': pred_text
    }
    
    # Retrieve top 3 matches
    retrieval_start = time.perf_counter()
    top_3 = retrieve_top_k(prediction, all_pills_df, k=3)
    retrieval_time = time.perf_counter() - retrieval_start
    total_retrieval_time += retrieval_time
    
    results.append({
        'image_path': image_path,
        'true_medicine': row['medicine_name'],
        'true_shape': row['splshape_text'],
        'true_color': row['splcolor_text'],
        'true_text': row['splimprint'],
        'pred_shape': pred_shape,
        'pred_color': pred_color,
        'pred_text': pred_text,
        'top1_medicine': top_3[0][0],
        'top1_score': top_3[0][1],
        'top2_medicine': top_3[1][0],
        'top2_score': top_3[1][1],
        'top3_medicine': top_3[2][0],
        'top3_score': top_3[2][1]
    })

total_time = time.perf_counter() - start_time
avg_time_per_image = total_time / len(test_df)

results_df = pd.DataFrame(results)

print("\n" + "=" * 60)
print("PROCESSING COMPLETE - TIMING BREAKDOWN")
print("=" * 60)
print(f"Total samples: {len(test_df)}")
print(f"\nTotal time: {total_time:.2f}s ({total_time/60:.2f} min)")
print(f"Average per image: {avg_time_per_image:.3f}s")
print(f"Throughput: {len(test_df)/total_time:.2f} images/sec")
print("\n" + "-" * 60)
print("BREAKDOWN BY COMPONENT:")
print("-" * 60)
print(f"ResNet inference:  {total_resnet_time:.2f}s  (avg: {total_resnet_time/len(test_df)*1000:.1f}ms/image) - {total_resnet_time/total_time*100:.1f}%")
print(f"OCR (Tesseract):   {total_ocr_time:.2f}s  (avg: {total_ocr_time/len(test_df)*1000:.1f}ms/image) - {total_ocr_time/total_time*100:.1f}%")
print(f"Database retrieval: {total_retrieval_time:.2f}s  (avg: {total_retrieval_time/len(test_df)*1000:.1f}ms/image) - {total_retrieval_time/total_time*100:.1f}%")
print("=" * 60)

results_df.head(10)

Processing pills: 100%|██████████| 1959/1959 [1:32:46<00:00,  2.84s/it]  


PROCESSING COMPLETE - TIMING BREAKDOWN
Total samples: 1959

Total time: 5566.35s (92.77 min)
Average per image: 2.841s
Throughput: 0.35 images/sec

------------------------------------------------------------
BREAKDOWN BY COMPONENT:
------------------------------------------------------------
ResNet inference:  81.89s  (avg: 41.8ms/image) - 1.5%
OCR (Tesseract):   388.00s  (avg: 198.1ms/image) - 7.0%
Database retrieval: 5082.17s  (avg: 2594.3ms/image) - 91.3%





Unnamed: 0,image_path,true_medicine,true_shape,true_color,true_text,pred_shape,pred_color,pred_text,top1_medicine,top1_score,top2_medicine,top2_score,top3_medicine,top3_score
0,data/pillbox_production_images_full_202008/231...,METFORMIN HYDROCHLORIDE,ROUND,WHITE,H;103,ROUND,WHITE,"1\n\n1\n\n1""\n\n1\n\n9\n\n8\n\n7\n\n6\n\n5\n\n...",Olmesartan Medoxomil,1.344474,Hydroxyzine Hydrochloride,1.328437,Phendimetrazine Tartrate,1.31
1,data/pillbox_production_images_full_202008/604...,CILOSTAZOL,ROUND,WHITE,cor;159,ROUND,WHITE,@ ©\nmm)! TT TTT,Tolterodine Tartrate,1.5,Tolterodine Tartrate,1.5,Tolterodine Tartrate,1.5
2,data/pillbox_production_images_full_202008/433...,PANTOPRAZOLE SODIUM,OVAL,YELLOW,P;40,OVAL,YELLOW,,Risperidone,1.0,SEVERE COLD AND FLU RELIEF,1.0,guanfacine,1.0
3,data/pillbox_production_images_full_202008/684...,FLUCONAZOLE,TRAPEZOID,PINK,200,TRAPEZOID,PINK,,Fluconazole,1.0,Fluconazole,1.0,Diflucan,1.0
4,data/pillbox_production_images_full_202008/001...,ISONIAZID,ROUND,WHITE,E;4350,ROUND,WHITE,"12:\n\nEi ""\n9\n\n8\n\n7\n\n6\n\n5\n\n4\n\n3\n...",Trazodone Hydrochloride,1.235511,Trazodone Hydrochloride,1.235511,Trazodone Hydrochloride,1.235511
5,data/pillbox_production_images_full_202008/669...,DOXERCALCIFEROL,OVAL,YELLOW,g,ROUND,YELLOW,,Butalbital and Acetaminophen,1.0,Solifenacin Succinate,1.0,Stay Awake,1.0
6,data/pillbox_production_images_full_202008/290...,THEOPHYLLINE,ROUND,WHITE,N;T4,ROUND,WHITE,"""\n9\n8\n7\n6\na\n4\n3\n2\n1\n0\nmm\nt) 1 2 3 ...",Actoplus Met,1.28046,Trazodone Hydrochloride,1.270936,Trazodone Hydrochloride,1.270936
7,data/pillbox_production_images_full_202008/001...,SERTRALINE,OVAL,YELLOW,I;G;214,OVAL,YELLOW,,Risperidone,1.0,SEVERE COLD AND FLU RELIEF,1.0,guanfacine,1.0
8,data/pillbox_production_images_full_202008/656...,XIFAXAN,ROUND,PINK,Sx,ROUND,RED,aR: i ats\nRo Bee nie\nSE Hage a Ue as eee ert...,Amlodipine and Olmesartan Medoxomil,1.346,TYLENOL,1.229292,Tranylcypromine Sulfate,1.225455
9,data/pillbox_production_images_full_202008/no_...,HYDRASTIS CAN,ROUND,WHITE,,ROUND,WHITE,No Image\nAvailable,Norgesic,1.691571,Gaviscon,1.683761,Gaviscon,1.683761


# Evaluation

In [7]:
# Calculate Top-1 and Top-3 Accuracy
top1_correct = (results_df['true_medicine'] == results_df['top1_medicine']).sum()
top3_correct = (
    (results_df['true_medicine'] == results_df['top1_medicine']) |
    (results_df['true_medicine'] == results_df['top2_medicine']) |
    (results_df['true_medicine'] == results_df['top3_medicine'])
).sum()

total = len(results_df)
top1_accuracy = top1_correct / total
top3_accuracy = top3_correct / total

print("=" * 50)
print("RETRIEVAL EVALUATION RESULTS")
print("=" * 50)
print(f"Total samples: {total}")
print(f"\nTop-1 Accuracy: {top1_accuracy:.4f} ({top1_correct}/{total})")
print(f"Top-3 Accuracy: {top3_accuracy:.4f} ({top3_correct}/{total})")
print("=" * 50)

RETRIEVAL EVALUATION RESULTS
Total samples: 1959

Top-1 Accuracy: 0.0031 (6/1959)
Top-3 Accuracy: 0.0077 (15/1959)
