In [1]:
import torch
from torch import nn
import timm
from torchvision import transforms
import pickle
import torch.nn.functional as F

In [2]:
with open('pill_model/traintestsplit.pkl', 'rb') as splits:
    traintestsplit = pickle.load(splits)

train_paths = traintestsplit['train_paths']
train_labels = traintestsplit['train_labels']
test_paths = traintestsplit['test_paths']
test_labels = traintestsplit['test_labels']
class_to_label = traintestsplit['class_to_label']
label_to_class = {class_to_label[c]: c for c in class_to_label}

In [85]:
class_to_label

{'Aciclovir_Herpex_800_mg': 0,
 'Ascorbic_Acid_BewellC': 1,
 'Azithromycin_as_dihydrate_Zithromax': 2,
 'Carbocisteine_Marluxyn': 3,
 'Carbocisteine_Solmux': 4,
 'DoloNeurobion_3B': 5,
 'Iron__Pyridoxine__Cyanacobalamin__Folic_Acid_Hemarate': 6,
 'KremilS': 7,
 'Loperamide_Hydrochloride_Diatabs': 8,
 'Loratadine_Allerta': 9,
 'Loratadine_Claritin': 10,
 'Losartan_Potassium_Medzart': 11,
 'MX3': 12,
 'Meclizine_Hydrochloride_Bonamine': 13,
 'Mefenamic_Acid_Myrefen': 14,
 'Montelukast_as_sodium__Levocetirizine_hydrochloride_Allerkast': 15,
 'Multivatimins__Minerals_Centrum_Advance': 16,
 'Naproxen_Sodium_Skelan_550': 17,
 'Paracetamol_Biogesic': 18,
 'Paracetamol_Tempaid': 19,
 'Paracetamol_Tempra': 20,
 'Phenylephrine_Hydrochloride__Chlorphenamine_Maleate__Paracetamol_Bioflu': 21,
 'Rosuvastatin_Rosusaph10': 22,
 'Sambong_Leaf_Uricare_500mg': 23,
 'Sinecod_Forte': 24,
 'Sodium_Ascorbate__Zinc_ImmunPro': 25,
 'Vitex_Negundo_L_Lagundi_Leaf_Ascof_Forte': 26,
 'Wild_Alaskan_Fish_Oil_1400_mg

In [87]:
import fuzzywuzzy

In [None]:
# pd_drug_names = list(drug_list['Name'])
# c2l_names = list(class_to_label.keys())

# from fuzzywuzzy import process

# def fuzzy_match_lists(reference_list, query_list, threshold=80):
#     matched_list = []
#     for query in query_list:
#         match, score = process.extractOne(query, reference_list)
#         if score >= threshold:
#             matched_list.append(match)
#         else:
#             matched_list.append(None)
#     return matched_list

# matched_names = fuzzy_match_lists(c2l_names, pd_drug_names, 70)

In [96]:
# for i in range(len(matched_names)):
#     print(matched_names[i] + ' === ' + pd_drug_names[i])

In [27]:
base_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

In [4]:
class ContrastiveClassifier(nn.Module):
    def __init__(self, embedding_dim, num_classes):
        super().__init__()

        self.encoder = timm.create_model('convnextv2_nano', pretrained=False)
    
        self.encoder.head.fc = nn.Identity()
        
        self.classifier_head = nn.Sequential(
            nn.Linear(640, embedding_dim),
            nn.LayerNorm(embedding_dim, eps=1e-05, elementwise_affine=True),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim, num_classes, bias=True)  # Output contrastive embedding
        )
        self.projection_head = nn.Sequential(
            nn.Linear(640, embedding_dim),
            nn.LayerNorm(embedding_dim, eps=1e-05, elementwise_affine=True),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim, embedding_dim, bias=True)  # Output contrastive embedding
        )

    def forward(self, x):
        features = self.encoder(x)
        embeddings = self.projection_head(features)
        logits = self.classifier_head(features)
        return embeddings, logits

In [5]:
OUT_FEATURES = 512
num_classes = 28
model = ContrastiveClassifier(OUT_FEATURES, num_classes)

In [6]:
PATH = 'pill_model/model_last.pt'
state_dict = torch.load(PATH, weights_only=False, map_location='cpu')
model.load_state_dict(state_dict['model_state_dict'])
model.eval()
print("Loaded")

Loaded


In [7]:
support_set = torch.load('pill_model/support_embeddings.pt', map_location='cpu')
mean_embeddings = support_set['mean_embeddings']
support_embeddings = torch.vstack(list(mean_embeddings.values()))
labels = list(mean_embeddings.keys())

  support_set = torch.load('pill_model/support_embeddings.pt', map_location='cpu')


In [8]:
embedding = torch.randn(1, 512)
scores = F.cosine_similarity(support_embeddings, embedding)
top_values, top_indices = scores.topk(3)
top_labels = [labels[i] for i in top_indices]
top_labels

[1, 16, 7]

In [54]:
torch.__version__

'2.4.0+cpu'

In [17]:
from dis_bg_remover import remove_background
import cv2
import numpy as np
from PIL import Image

In [127]:
bg_model_path = "pill_model/isnet_dis.onnx"
input_img_path = 'test_pills\LoratadineClaritin.jpg'

def preprocess_image(input_img_path):
    img, mask = remove_background(bg_model_path, input_img_path)
                    
    # Handle potential differences in dimensions
    if img.shape[2] == 4:  # BGRA image
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to BGR

    # Ensure mask is 2D
    if len(mask.shape) > 2:
        mask = mask[:,:,0]  # Take first channel if mask has multiple channels

    bg_color = 128
    # Create gray background
    h, w = img.shape[:2]
    gray_bg = np.ones((h, w, 3), dtype=np.uint8) * np.array(bg_color, dtype=np.uint8)

    # Convert binary mask to 3-channel mask for multiplication
    # Ensure mask has values between 0 and 1
    mask = mask.astype(np.float32) / 255.0 if np.max(mask) > 1 else mask.astype(np.float32)
    mask_3channel = np.stack([mask, mask, mask], axis=2)
    # Blend foreground with gray background using the mask
    result = (img * mask_3channel + gray_bg * (1 - mask_3channel)).astype(np.uint8)
    result = Image.fromarray(result)

    return result

@torch.no_grad()
def predict(image):
    preprocessed_image_tensor = base_transform(image).unsqueeze(0)
    pred_embedding, _ = model(preprocessed_image_tensor)

    scores = F.cosine_similarity(support_embeddings, pred_embedding)
    top_values, top_indices = scores.topk(3)
    top_labels = [labels[i] for i in top_indices]
    top_predictions = [label_to_class[i] for i in top_labels]
    return top_predictions, [round(i, 2) for i in top_values.numpy().tolist()]


In [140]:
image = preprocess_image('test_pills/Bewell-C.jpg')
predictions, scores = predict(image)
predictions

['Ascorbic_Acid_BewellC', 'MX3', 'Carbocisteine_Marluxyn']

In [110]:
import pandas as pd

fda_df = pd.read_csv('FDA_ALL.csv')
drug_list = pd.read_csv("Medications List Clean.csv")

In [148]:
list_entry = drug_list[drug_list['Name'] == 'MX3'].iloc[0]
registration_num = list_entry['FDA Link'].split("=")[-1]
fda_df[fda_df['Registration Number'] == registration_num]

Unnamed: 0,INDEX,Registration Number,Generic Name,Brand Name,Dosage Strength,Dosage Form,Classification,Packaging,Pharmacologic Category,Manufacturer,Country of Origin,Trader,Importer,Distributor,Application Type,Issuance Date,Expiry Date


In [151]:
def get_pill_info(drug_name):
    list_entry = drug_list[drug_list['Name'] == drug_name].iloc[0]
    if 'drug_products' not in list_entry['FDA Link']:
        return 'Not Drug'
    registration_num = list_entry['FDA Link'].split("=")[-1]
    pill_entry = fda_df[fda_df['Registration Number'] == registration_num].iloc[0]
    pill_entry = dict(pill_entry)
    return pill_entry


In [164]:
from medication_matching import match_with_rx
drugdata_df = pd.read_csv('drug_data.csv')

In [None]:
pill_pred_info = {'matches': []}
for p, s in zip(predictions, scores):
    pill_info = get_pill_info(p)
    if pill_info == 'Not Drug':
        continue
    pill_info['Score'] = s
    pill_pred_info['matches'].append(pill_info)
    matches, _ = match_with_rx(pill_info['Generic Name'])
    rx_info = dict(drugdata_df.iloc[matches[0]])
    pill_info['rx_info'] = rx_info

## API Testing

In [184]:
PIL.__version__

'10.4.0'

In [1]:
import requests, json

url = "https://fastapi-app-613987678533.asia-southeast1.run.app/upload-pill-image/"
files = {"file": open("test_pills/Bewell-C.jpg", "rb")}

response = requests.post(url, files=files)
response.json()

{'matches': [{'INDEX': 2351,
   'Registration Number': 'DRHR-1355',
   'Generic Name': 'Ascorbic Acid',
   'Brand Name': 'Bewell C',
   'Dosage Strength': '500 mg',
   'Dosage Form': 'Capsule',
   'Classification': 'Household Remedy (HR)',
   'Packaging': "Alu-Red PVC Blister Pack x 10's (Box of 100's and 200's)",
   'Pharmacologic Category': '-',
   'Manufacturer': 'Lejal Laboratories Inc.',
   'Country of Origin': 'Philippines',
   'Trader': 'Bewell Nutraceutical Corp.',
   'Importer': None,
   'Distributor': None,
   'Application Type': '-',
   'Issuance Date': '24-Jul-20',
   'Expiry Date': '28-Aug-25',
   'Score': 0.85,
   'rx_info': {'Name': 'Ascorbic Acid (Vitamin C)',
    'URL': 'https://www.rxlist.com/ascorbic-acid-drug.htm',
    'What is': '\nAscorbic Acid (vitamin C) is a water-soluble vitamin recommended for the prevention and treatment of scurvy. Ascorbic acid is available in generic form.',
    'What Are Side Effects': '\nCommon side effects of ascorbic acid include trans

In [177]:
import requests, json

url = "http://192.168.254.118:8000/upload-pill-image/"
files = {"file": open("test_pills/Bewell-C.jpg", "rb")}

response = requests.post(url, files=files)
response.json()

{'matches': [{'INDEX': 2351,
   'Registration Number': 'DRHR-1355',
   'Generic Name': 'Ascorbic Acid',
   'Brand Name': 'Bewell C',
   'Dosage Strength': '500 mg',
   'Dosage Form': 'Capsule',
   'Classification': 'Household Remedy (HR)',
   'Packaging': "Alu-Red PVC Blister Pack x 10's (Box of 100's and 200's)",
   'Pharmacologic Category': '-',
   'Manufacturer': 'Lejal Laboratories Inc.',
   'Country of Origin': 'Philippines',
   'Trader': 'Bewell Nutraceutical Corp.',
   'Importer': None,
   'Distributor': None,
   'Application Type': '-',
   'Issuance Date': '24-Jul-20',
   'Expiry Date': '28-Aug-25',
   'Score': 0.85,
   'rx_info': {'Name': 'Ascorbic Acid (Vitamin C)',
    'URL': 'https://www.rxlist.com/ascorbic-acid-drug.htm',
    'What is': '\nAscorbic Acid (vitamin C) is a water-soluble vitamin recommended for the prevention and treatment of scurvy. Ascorbic acid is available in generic form.',
    'What Are Side Effects': '\nCommon side effects of ascorbic acid include trans