In [1]:
import pandas as pd
import numpy as np
import re

import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity
from scipy.cluster.hierarchy import linkage, fcluster, dendrogram
from sentence_transformers import SentenceTransformer

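### Filling the `packaging_form` column

Missing `packaging_form` values are inferred from dosage units and keywords in `salts` (e.g. `mg` → strip, `mg/ml` → bottle, `cream` → tube), with a random fallback when nothing matches.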

In [2]:
# (compiled regex, candidate packaging labels), checked in order against the
# lower-cased salt text. The first pattern that matches AND has a candidate
# present in the valid-label list decides the packaging form.
#
# Unit patterns use the look-behind `(?<![a-z])` instead of a leading `\b`:
# digits are word characters, so `\bmg` can never match the common "500mg"
# spelling. The look-behind accepts "500mg" and "500 mg" but still rejects a
# unit embedded in a longer word (e.g. the "g" in "omega").
_PACKAGING_PATTERNS = [
    (re.compile(r'(?<![a-z])(mg|g|mcg)\b(?!.*/|\d+ml)'),
     ['STRIP', 'Strip | Tablet', 'Strip TABLET', 'Strip | Capsule']),
    (re.compile(r'(?<![a-z])(mg|g|mcg)/tablet\b'),
     ['Strip | Tablet', 'STRIP']),
    # Concentrations such as "10mg/ml" or "125mg/5ml".
    (re.compile(r'(?<![a-z])(mg|g|mcg)/\s*[\d.]*\s*ml\b'),
     ['BOTTLE', 'Bottle | Suspension', 'Bottle SYRUP', 'Bottle | Syrup']),
    # Percentages such as "2%", "1 % w/w". A trailing `\b` after "%" (as in an
    # earlier form of this pattern) fails at end of string or before ")".
    (re.compile(r'\d\s*%'),
     ['TUBE', 'BOTTLE', 'Bottle | Lotion']),
    (re.compile(r'(?<![a-z])(iu|units?)/ml?\b'),
     ['VIAL', 'AMPOULE', 'Vial | Injection']),
    (re.compile(r'(?<![a-z])(ml|litre)\b'),
     ['BOTTLE', 'Bottle | Solution', 'Bottle | Liquid']),
    (re.compile(r'\b(injection|inj)\b'),
     ['AMPOULE', 'Vial | Injection', 'INJECTION']),
    (re.compile(r'\b(cream|ointment)\b'),
     ['TUBE', 'Tube | Cream']),
    (re.compile(r'\b(drops|solution)\b'),
     ['Bottle | Drops', 'Bottle | Solution']),
    (re.compile(r'\b(gel|shampoo)\b'),
     ['Tube | Gel', 'Bottle | Shampoo']),
    (re.compile(r'\b(powder|sachet)\b'),
     ['SACHET', 'Packet | Suspension']),
]


def determine_packaging(salt, valid_options=None, rng=None):
    """Guess a packaging form from the dosage information in a salt string.

    Args:
        salt: salt/composition text (any value; converted with ``str()``).
        valid_options: container of allowed packaging labels. Defaults to the
            module-level ``valid_packaging`` list, looked up at call time.
        rng: optional ``numpy.random.Generator`` (or anything with a
            compatible ``choice``) used for the random fallback; when None
            the global ``np.random`` state is used, so results can still be
            reproduced with ``np.random.seed``.

    Returns:
        ``np.nan`` if ``salt`` is missing; otherwise the first candidate of
        the first matching pattern that is in ``valid_options``; otherwise a
        random draw from STRIP/BOTTLE/TUBE with probabilities 0.5/0.3/0.2.
    """
    if pd.isna(salt):
        return np.nan

    if valid_options is None:
        valid_options = valid_packaging

    text = str(salt).lower()
    for pattern, candidates in _PACKAGING_PATTERNS:
        if pattern.search(text):
            # A pattern whose candidates are all unknown labels falls through
            # to the next pattern rather than ending the search.
            for option in candidates:
                if option in valid_options:
                    return option

    chooser = np.random if rng is None else rng
    return chooser.choice(['STRIP', 'BOTTLE', 'TUBE'], p=[0.5, 0.3, 0.2])


# Packaging labels that determine_packaging() is allowed to return.
# Labels are kept exactly as they appear in the dataset (mixed casing and
# separators such as 'STRIP', 'Strip', 'Strip | Tablet' are distinct values),
# and the order is preserved in case anything iterates over the list.
valid_packaging = [
    'STRIP',
    'Strip | Tablet',
    'Bottle | Suspension',
    'BOTTLE',
    'Bottle SYRUP',
    'Bottle | Syrup',
    'Strip TABLET',
    'AMPOULE',
    'Bottle | Infusion',
    'Bottle | Oral Drops',
    'Bottle | Spray',
    'INJECTION',
    'TUBE',
    'Vial | Injection',
    'VIAL',
    'Tube | Cream',
    'Bottle | Nasal Spray',
    'Bottle | Nasal Drops',
    'Tube | Gel',
    'Box | Injection',
    'Tube | Mouth Gel',
    'Tube GEL',
    'Bottle | Eye Drops',
    'Vial INJECTION',
    'Strip',
    'Bottle | Oral Suspension',
    'Strip | Capsule',
    'Tube',
    'Ampoule | Injection',
    'Tube | Ointment',
    'Packet | Injection',
    'Bottle',
    'Bottle | Drops',
    'Bottle | Solution',
    'DRY VIAL',
    'Bottle | Oral Solution',
    'Bottle | Eye/Ear Drops',
    'Bottle SUSPENSION',
    'SACHET',
    'Packet | Suspension',
    'DEVICE',
    'BOX',
    'Pack | Soap',
    'Bottle | Dusting Powder',
    'DROPS',
    'Tube | Shampoo',
    'Bottle | Shampoo',
    'CONTAINER',
    'Bottle | Lotion',
    'PACK',
    'PACKET',
    'Bottle | Liquid',
    'Bottle | Tablet',
    'Vial | Infusion',
    'Bottle | Injection',
]

def fill_packaging_forms(df):
    """Fill missing ``packaging_form`` values in place and return the frame.

    Side effects on ``df`` (the same object is returned):

    * ``salts`` is normalised to lower-case strings; missing or non-string
      entries become ``''``.
    * Rows whose ``packaging_form`` is missing are filled:
        - rows with salt text go through ``determine_packaging``;
        - rows without salt text get a form sampled (via ``np.random``)
          from the distribution of packaging forms already present in
          ``df``, falling back to ``determine_packaging`` if no form is
          known at all.

    Args:
        df: DataFrame with ``salts`` and ``packaging_form`` columns.

    Returns:
        The same DataFrame, modified in place.
    """
    # Same result as ``.str.lower().fillna('')`` for object columns, but also
    # works when the column holds no strings at all (e.g. all-NaN floats),
    # where the ``.str`` accessor would raise.
    df['salts'] = df['salts'].map(lambda s: s.lower() if isinstance(s, str) else '')

    mask = df['packaging_form'].isna()
    if not mask.any():
        return df

    # Empirical distribution of the packaging forms that are already known.
    packaging_dist = df['packaging_form'].value_counts(normalize=True)
    known_forms = packaging_dist.index.tolist()
    known_probs = packaging_dist.to_numpy()

    def _fill(salt):
        # Missingness must be judged on the *normalised* value: after the
        # step above every entry is a string, so "no salt" means empty text,
        # not NaN (a NaN test here would never fire).
        if salt.strip() or not known_forms:
            return determine_packaging(salt)
        return np.random.choice(known_forms, p=known_probs)

    df.loc[mask, 'packaging_form'] = df.loc[mask, 'salts'].apply(_fill)
    return df

In [None]:
class FormulationClusterer:
    """Cluster salt/formulation strings by the similarity of their text.

    Strings are normalised (``preprocess``), embedded with a
    SentenceTransformer model (``get_embeddings``) and grouped with
    average-linkage hierarchical clustering on cosine distance
    (``cluster_formulations``).
    """

    def __init__(self, model_name='paraphrase-MiniLM-L6-v2'):
        """Load the sentence-embedding model.

        Args:
            model_name: SentenceTransformer model identifier. NOTE(review):
                the default is a general "paraphrase" MiniLM model, not a
                biomedical-specific one; pass a domain model name if needed.
        """
        self.model = SentenceTransformer(model_name)

    def preprocess(self, texts):
        """Normalise formulation strings before embedding.

        Each item is converted with ``str()``, lower-cased, '%' is spelled
        out as ' percent' and '/' as ' per ', every other non-alphanumeric
        character becomes a space, and whitespace runs are collapsed.

        Returns:
            list[str]: normalised strings in input order.
        """
        processed = []
        for text in texts:
            text = str(text).lower().replace('%', ' percent').replace('/', ' per ')
            text = re.sub(r'[^a-z0-9\s]', ' ', text)
            text = re.sub(r'\s+', ' ', text).strip()
            processed.append(text)
        return processed

    def get_embeddings(self, texts):
        """Encode ``texts`` with the model and return a NumPy array (one row per text)."""
        return self.model.encode(texts, convert_to_tensor=True).cpu().numpy()

    def cluster_formulations(self, formulations, threshold=0.3):
        """Hierarchically cluster formulation strings.

        Args:
            formulations: iterable of formulation strings.
            threshold: passed to ``fcluster(criterion='distance')``; items in
                one flat cluster have cophenetic cosine distance <= threshold.

        Returns:
            (clusters, linkage_matrix): 1-based cluster labels, one per input,
            and the SciPy linkage matrix of shape (n - 1, 4). With fewer than
            two inputs every item gets label 1 and the linkage matrix is empty.
        """
        formulations = list(formulations)
        n_items = len(formulations)
        if n_items < 2:
            # linkage() needs at least two observations.
            return np.ones(n_items, dtype=np.int32), np.empty((0, 4))

        cleaned = self.preprocess(formulations)
        embeddings = self.get_embeddings(cleaned)

        # Give linkage() the raw embeddings and let it compute the condensed
        # cosine distances. Passing a square distance matrix instead makes
        # SciPy treat each row as an observation vector and cluster on
        # Euclidean distances between rows, which is not cosine distance.
        linkage_matrix = linkage(embeddings, method='average', metric='cosine')

        clusters = fcluster(linkage_matrix, threshold, criterion='distance')

        return clusters, linkage_matrix

    def print_clusters(self, formulations, clusters):
        """Print each cluster's members, clusters in order of first appearance."""
        cluster_dict = defaultdict(list)
        for formulation, cluster_id in zip(formulations, clusters):
            cluster_dict[cluster_id].append(formulation)

        for cluster_id, members in cluster_dict.items():
            print(f"\nCluster {cluster_id}:")
            print("-" * 40)
            print("\n".join(members))
            
# Relative path (the same file the packaging cell reads) instead of a
# machine-specific absolute path, so the notebook runs on any checkout.
DATA_PATH = 'cleaned_medlr_dataset.csv'

df = pd.read_csv(DATA_PATH)

# Keep only rows with salt information: drop missing values and the literal
# placeholder "Unknown". .copy() so adding the "cluster" column below writes
# to an independent frame instead of a filtered view (SettingWithCopyWarning).
df = df[df["salts"].notna() & (df["salts"] != "Unknown")].copy()

# Cluster each distinct salt string once, then broadcast labels to all rows.
sample_salts = df["salts"].unique().tolist()

clusterer = FormulationClusterer()
clusters, linkage_matrix = clusterer.cluster_formulations(sample_salts)

cluster_mapping = dict(zip(sample_salts, clusters))
df["cluster"] = df["salts"].map(cluster_mapping)

clusterer.print_clusters(sample_salts, clusters)

In [3]:
# Reload the raw dataset and fill in missing packaging forms in one chain
# (fill_packaging_forms also lower-cases the salts column).
df = pd.read_csv('cleaned_medlr_dataset.csv').pipe(fill_packaging_forms)

# Every row's (normalised) salt string, duplicates included.
sample_salts = df['salts'].tolist()

In [5]:
# Invariant check: the packaging fill step should leave no missing values.
assert df['packaging_form'].notna().all(), "packaging_form still has missing values"
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1448 entries, 0 to 1447
Data columns (total 10 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   id                     1448 non-null   object 
 1   name                   1448 non-null   object 
 2   source                 1448 non-null   object 
 3   prescription_required  1448 non-null   bool   
 4   retail_price           1367 non-null   float64
 5   discounted_price       1384 non-null   float64
 6   manufacturer           1446 non-null   object 
 7   quantity               1184 non-null   object 
 8   packaging_form         1448 non-null   object 
 9   salts                  1448 non-null   object 
dtypes: bool(1), float64(2), object(7)
memory usage: 103.4+ KB
