In [66]:
import torch
# from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import sklearn
import sklearn.model_selection
# import re
# import threading
# from multiprocessing import Pool
import os
# import xgboost as xgb
# from sklearn.model_selection import GridSearchCV
import random
from tqdm import tqdm

from sentence_transformers import SentenceTransformer, util

In [2]:
def set_random_seed(seed=42):
    """
    Sets the random seed for Python, NumPy, and PyTorch to ensure reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    
    # If using a GPU (CUDA backend), ensure deterministic behavior if needed
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Usage
set_random_seed(42)

<br>
<br>
<br>
<br>

<h1>Dataset:<h1>

In [3]:
path_patient_data_csv = "../../synthea/silver/updated_patient_data.csv"

path_providers_data_csv = "../../synthea/providers.csv"

temp_train_df = pd.read_csv(path_patient_data_csv)

kktemp_train_df = temp_train_df.drop(["SPECIALITY"], axis = 1)

In [108]:
tmp = pd.read_csv(path_patient_data_csv)
# tmp = tmp["SPECIALITY"]

# Select the 'SPECIALITY' column and drop duplicates
unique_specialities = tmp["condition"].drop_duplicates()

# Save the unique specialities to a new CSV file
unique_specialities.to_csv("unique_conditions.csv", index=False)


In [5]:
temp = temp_train_df[["condition", "SPECIALTY"]]
temp.to_csv("temp.csv", index = False)

In [6]:
print(len(temp_train_df))

8395


In [7]:
temp_train_df.head()

Unnamed: 0,patient_id,BIRTHDATE,GENDER,RACE,ETHNICITY,ADDRESS,CITY,STATE,COUNTY,ZIP,condition_start,condition_stop,encounter,condition,encounter_type,reason_id,reason,provider_id,SPECIALTY
0,1d604da9-9a81-4ba9-80c2-de3375d59b40,1989-05-25,M,white,hispanic,575 BEECH STREET,HOLYOKE,MA,Hampden County,1040,2019-03-20,2019-04-10,4e595f0c-f50f-461b-a04e-13b4e492350e,Viral sinusitis (disorder),Encounter for symptom,444814009.0,Viral sinusitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Otolaryngologist
1,1d604da9-9a81-4ba9-80c2-de3375d59b40,1989-05-25,M,white,hispanic,575 BEECH STREET,HOLYOKE,MA,Hampden County,1040,2011-12-08,2011-12-22,792fae81-a007-44b0-8221-46953737b089,Viral sinusitis (disorder),Encounter for symptom,444814009.0,Viral sinusitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Otolaryngologist
2,1d604da9-9a81-4ba9-80c2-de3375d59b40,1989-05-25,M,white,hispanic,575 BEECH STREET,HOLYOKE,MA,Hampden County,1040,2001-05-01,,8f104aa7-4ca9-4473-885a-bba2437df588,Chronic sinusitis (disorder),Encounter for symptom,36971009.0,Sinusitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Otolaryngologist
3,034e9e3b-2def-4559-bb2a-7850888ae060,1983-11-14,F,white,nonhispanic,1493 CAMBRIDGE STREET,CAMBRIDGE,MA,Middlesex County,2138,2016-12-29,2017-01-05,3b639086-5fbc-4720-8c31-e8c8c0f1d660,Acute bronchitis (disorder),Encounter for symptom,10509002.0,Acute bronchitis (disorder),e6283e46-fd81-3611-9459-0edb1c3da357,Pulmonologist
4,10339b10-3cd1-4ac3-ac13-ec26728cb592,1992-06-02,M,white,nonhispanic,575 BEECH STREET,HOLYOKE,MA,Hampden County,1040,2019-04-23,2019-05-07,27ff7518-6d93-4308-8a1d-d2dfb02c0c58,Acute bronchitis (disorder),Encounter for symptom,10509002.0,Acute bronchitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Pulmonologist


In [8]:
temp_train_df = temp_train_df.drop(columns=["BIRTHDATE", "ADDRESS", "CITY", "STATE", "COUNTY", "ZIP", "condition_start", "condition_stop", "encounter", "encounter_type", "reason_id"], axis = 1)
temp_train_df.head()

Unnamed: 0,patient_id,GENDER,RACE,ETHNICITY,condition,reason,provider_id,SPECIALTY
0,1d604da9-9a81-4ba9-80c2-de3375d59b40,M,white,hispanic,Viral sinusitis (disorder),Viral sinusitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Otolaryngologist
1,1d604da9-9a81-4ba9-80c2-de3375d59b40,M,white,hispanic,Viral sinusitis (disorder),Viral sinusitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Otolaryngologist
2,1d604da9-9a81-4ba9-80c2-de3375d59b40,M,white,hispanic,Chronic sinusitis (disorder),Sinusitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Otolaryngologist
3,034e9e3b-2def-4559-bb2a-7850888ae060,F,white,nonhispanic,Acute bronchitis (disorder),Acute bronchitis (disorder),e6283e46-fd81-3611-9459-0edb1c3da357,Pulmonologist
4,10339b10-3cd1-4ac3-ac13-ec26728cb592,M,white,nonhispanic,Acute bronchitis (disorder),Acute bronchitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Pulmonologist


In [9]:
# temp_train_df = pd.get_dummies(temp_train_df, columns = ["patient_id"])
# temp_train_df.head()

In [10]:
counts = temp_train_df.groupby('SPECIALTY').size()
print(counts)

SPECIALTY
Addiction Specialist            21
Allergist                       94
Burn Specialist                 39
Cardiologist                   800
Dentist                         57
Dermatologist                   25
Endocrinologist               1054
Gastroenterologist             113
General Surgeon                259
Hematologist                   300
Nephrologist                    85
Neurologist                    366
Obstetrician                   889
Oncologist                      75
Ophthalmologist                 22
Orthopedic Specialist          542
Otolaryngologist              2625
Pain Management Specialist      55
Plastic Surgeon                 36
Psychiatrist                    25
Pulmonologist                  634
Rheumatologist                   9
Toxicologist                    52
Trauma Surgeon                   9
Urologist                      148
dtype: int64


In [11]:
unique_count = temp_train_df["patient_id"].unique()
len(unique_count)

1171

In [12]:
temp_train_df.head()

Unnamed: 0,patient_id,GENDER,RACE,ETHNICITY,condition,reason,provider_id,SPECIALTY
0,1d604da9-9a81-4ba9-80c2-de3375d59b40,M,white,hispanic,Viral sinusitis (disorder),Viral sinusitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Otolaryngologist
1,1d604da9-9a81-4ba9-80c2-de3375d59b40,M,white,hispanic,Viral sinusitis (disorder),Viral sinusitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Otolaryngologist
2,1d604da9-9a81-4ba9-80c2-de3375d59b40,M,white,hispanic,Chronic sinusitis (disorder),Sinusitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Otolaryngologist
3,034e9e3b-2def-4559-bb2a-7850888ae060,F,white,nonhispanic,Acute bronchitis (disorder),Acute bronchitis (disorder),e6283e46-fd81-3611-9459-0edb1c3da357,Pulmonologist
4,10339b10-3cd1-4ac3-ac13-ec26728cb592,M,white,nonhispanic,Acute bronchitis (disorder),Acute bronchitis (disorder),af01a385-31d3-3c77-8fdb-2867fe88df2f,Pulmonologist


In [13]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path):
        super().__init__()
        
        # Load dataset
        df = pd.read_csv(dataset_path)

        # Convert specified columns to fixed-length strings and flatten
        train = np.array(df[["condition"]].values, dtype="U10").flatten()
        train_labels = np.array(df[["SPECIALTY"]].values, dtype="U10").flatten()

        # Split into train and test
        self.train, self.test, self.train_labels, self.test_labels = sklearn.model_selection.train_test_split(
            train, train_labels, test_size=0.20
        )

    def __len__(self):
        return len(self.train)

    def __getitem__(self, idx):
        # Return raw strings for samples and labels
        return self.train[idx], self.train_labels[idx]

In [14]:
# class MakeDatasetLoader(torch.utils.data.Dataset):
#     def __init__(self, samples, labels, non_trainable_model):
#         super().__init__()

#         self.samples = samples
#         self.labels = labels
#         # self.samples = torch.from_numpy(samples.astype(np.float32)).to(device=torch.device("cpu"), dtype=torch.float32)
#         # self.labels = torch.from_numpy(labels.astype(np.float32)).to(device=torch.device("cpu"), dtype=torch.float32)

#         # print(samples[0])

#         # self.labels = np.array(labels).tobytes()
#         # self.samples = np.array(samples).tobytes()
#         # print(labels)

#         # Store the labels as tensors
#         # self.labels = torch.tensor(labels, dtype=torch.float32).to(torch.device("cpu"))
        
#         # Convert each sample into embeddings using the non_trainable_model
#         # Assuming `samples` is a 1-column dataframe with text in each row
#         # self.samples = [non_trainable_model.encode(sample[0], convert_to_tensor=True).to(torch.device("cpu")) for sample in samples.values]
    
#     def __len__(self):
#         return len(self.samples)
    
#     def __getitem__(self, idx):
#         return self.samples[idx], self.labels[idx]

<br>
<br>
<br>
<br>

In [15]:
non_trainable_model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-dot-v1')

In [16]:
# def custom_collate_fn(batch):
#     # Separate samples and labels
#     samples, labels = zip(*batch)
    
#     # Convert samples to embeddings
#     sample_embeddings = [non_trainable_model.encode(sample, convert_to_tensor=True) for sample in samples]
#     sample_embeddings = torch.stack(sample_embeddings)  # Stack embeddings into a single tensor
    
#     return sample_embeddings, list(labels)  # Return embeddings and raw labels


# dataset = CustomDataset(dataset_path = path_patient_data_csv)

# train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)

In [27]:
# Define a custom collate function
def collate_fn(batch):
    # batch is a list of tuples (data, label)
    data, labels = zip(*batch)
    # Convert tuples to lists
    data = list(data)
    labels = list(labels)
    return data, labels

# Create an instance of your dataset
dataset = CustomDataset(path_patient_data_csv)

# Create a DataLoader with the custom collate function
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,        # Adjust batch_size as needed
    shuffle=False,
    collate_fn=collate_fn
)

print(dataset.__getitem__(0))

# Now you can iterate over the DataLoader
for batch_data, batch_labels in dataloader:
    print(batch_data)    # List of strings (conditions)
    print(batch_labels)  # List of strings (specialties)
    break
    # You can now process the batch_data and batch_labels as needed

(np.str_('Normal pre'), np.str_('Obstetrici'))
[np.str_('Normal pre'), np.str_('Acute bron'), np.str_('Escherichi'), np.str_('Body mass '), np.str_('Anemia (di'), np.str_('Prediabete'), np.str_('Chronic pa'), np.str_('Acute vira'), np.str_('Malignant '), np.str_('Concussion'), np.str_('Stroke'), np.str_('Acute vira'), np.str_('Perennial '), np.str_('Acute bron'), np.str_('Fracture o'), np.str_('Viral sinu'), np.str_('Familial A'), np.str_('Viral sinu'), np.str_('Tubal preg'), np.str_('Viral sinu'), np.str_('Sprain of '), np.str_('Idiopathic'), np.str_('Viral sinu'), np.str_('Osteoarthr'), np.str_('Viral sinu'), np.str_('Drug overd'), np.str_('Non-small '), np.str_('Viral sinu'), np.str_('Anemia (di'), np.str_('Laceration'), np.str_('Fetus with'), np.str_('Prediabete')]
[np.str_('Obstetrici'), np.str_('Pulmonolog'), np.str_('Urologist'), np.str_('Endocrinol'), np.str_('Hematologi'), np.str_('Endocrinol'), np.str_('Pain Manag'), np.str_('Otolaryngo'), np.str_('Oncologist'), np.str_('Neur

<br>

In [382]:




# train_data = MakeDatasetLoader(train_samples, train_labels, non_trainable_model = non_trainable_model)
# test_data = MakeDatasetLoader(test_samples, test_labels, non_trainable_model = non_trainable_model)

# train_dataloader = torch.utils.data.DataLoader(dataset=train_data, shuffle = True,batch_size = 2, collate_fn = custom_collate_fn)
# test_dataloader = torch.utils.data.DataLoader(dataset=test_data, shuffle = False,batch_size = 2, collate_fn = custom_collate_fn)

<br>
<br>

In [383]:
# model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-dot-v1')

# query = "Hello I am Ovidiu"

# query_emb = model.encode(query)

# print(query_emb)

In [104]:
# Model Definition
class SentenceTransformerWithHead(torch.nn.Module):
    def __init__(self, curr_device=torch.device("cpu")):
        super(SentenceTransformerWithHead, self).__init__()
        
        # Load the pretrained sentence transformer
        self.sentence_transformer = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-dot-v1')
        
        self.head = torch.nn.Linear(768, 768)  # Adjust size if model’s output size is different

        self.device = curr_device
        self.to(self.device)
    
    def forward(self, input_text):
        # input_text is a list of strings
        # Obtain embeddings without computing gradients for the embedding model
        embeddings = self.sentence_transformer.encode(
            input_text,
            convert_to_tensor=True,
            device=self.device
        )
        # embeddings is a tensor on the correct device with requires_grad=False
        
        # Pass embeddings through the head
        output = self.head(embeddings)
        return output

# Checkpoint Saving Function
def save_checkpoint(epoch, model_trainable, optimizer, loss, checkpoint_path):
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)  # Create directory if it doesn't exist

    # Save model and optimizer state dictionaries, epoch, and loss
    torch.save({
        'epoch': epoch,
        'model_state_dict': model_trainable.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

# Training Function
def train_SentenceWithHead(model_non_trainable, model_trainable, dataloader, num_epochs=5, learning_rate=0.001, checkpoint_dir="./checkpoint", device=torch.device("cpu")):
    optimizer = torch.optim.Adam(model_trainable.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()  # Mean Squared Error for Euclidean distance

    model_trainable.to(device)
    model_non_trainable.to(device)

    for epoch in range(num_epochs):
        model_trainable.train()
        total_loss = 0.0

        with tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch') as tepoch:
            for samples, labels in tepoch:
                # Ensure samples and labels are lists of strings
                samples = [str(sample) for sample in samples]
                labels = [str(label) for label in labels]

                # Move data to device (not necessary for strings)
                # Zero the gradients
                optimizer.zero_grad()

                # samples = np.array(samples, dtype = np.str_)
                # print(samples)
                # print(samples.shape)

                # Forward pass through the trainable model
                outputs = model_trainable(samples)  # Outputs shape: [batch_size, embedding_dim]

                # Generate target embeddings using the non-trainable model for labels
                with torch.no_grad():
                    target_embeddings = model_non_trainable.encode(labels)  # Shape: [batch_size, embedding_dim]
                    target_embeddings = torch.tensor(target_embeddings).to(device)

                # Compute loss
                loss = criterion(outputs, target_embeddings)

                # Backward pass
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}")

        # Save checkpoint at the end of each epoch
        # checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pth")
        # save_checkpoint(epoch + 1, model_trainable, optimizer, avg_loss, checkpoint_path)
    
    torch.save(model_trainable.state_dict(), "./model-saves/trainable_state_dict.pth")
    model_non_trainable.save("./model-saves/non_trainable_model")
    


<br>
<br>
<br>
<br>
<h2>Train<h2>

In [105]:

trainable_model = SentenceTransformerWithHead(torch.device("mps"))
train_SentenceWithHead(model_non_trainable = non_trainable_model, model_trainable = trainable_model, dataloader = dataloader, checkpoint_dir = "./checkpoint", device = torch.device("mps"))

Epoch 1/5: 100%|██████████| 210/210 [00:19<00:00, 10.51batch/s]


Epoch 1/5, Loss: 0.00544450869825336


Epoch 2/5: 100%|██████████| 210/210 [00:19<00:00, 10.96batch/s]


Epoch 2/5, Loss: 0.0011862881670822389


Epoch 3/5: 100%|██████████| 210/210 [00:18<00:00, 11.08batch/s]


Epoch 3/5, Loss: 0.0009890937444911937


Epoch 4/5: 100%|██████████| 210/210 [00:19<00:00, 11.01batch/s]


Epoch 4/5, Loss: 0.0009526315573090133


Epoch 5/5: 100%|██████████| 210/210 [00:19<00:00, 10.83batch/s]


Epoch 5/5, Loss: 0.0009496307073623895


<br>
<br>
<br>

In [89]:
specialty_unique_doctors = pd.read_csv("./unique_specialty.csv")

specialty_unique_doctors.head()

specialty_unique = np.array(specialty_unique_doctors["SPECIALTY"].values, dtype="U10").flatten()
specialty_unique = [str(sp) for sp in specialty_unique]

In [96]:
def knn(text, device, k):

    non_trainable_model.eval()
    non_trainable_model.to(device)

    

    # Compute embeddings for all specialties/doctors
    with torch.no_grad():
        doctor_embeddings = non_trainable_model.encode(specialty_unique)  # Shape: [num_doctors, embedding_dim]
        doctor_embeddings = torch.tensor(doctor_embeddings).to(device)


    trainable_model.eval()
    trainable_model.to(device)
    
    # Compute the embedding for the input text
    with torch.no_grad():
        text_embedding = trainable_model([text])  # Shape: [1, embedding_dim]
        text_embedding = torch.tensor(text_embedding).to(device)
    

    # Calculate the differences and compute the norms
    distances = torch.norm(doctor_embeddings - text_embedding, dim=1)  # Shape: [num_doctors]
    
    # Find the indices of the k smallest distances
    nearest_indices = torch.topk(distances, k, largest=False).indices
    nearest_distances = distances[nearest_indices]
    
    # Retrieve the corresponding doctors/specialties
    nearest_doctors = [specialty_unique[i] for i in nearest_indices.cpu().numpy()]

    return list(zip(nearest_doctors, nearest_distances.cpu().numpy()))

In [103]:
query = "Acute bronchitis (disorder)"

nearest_doctors = knn(text = query, device = torch.device("mps"), k = 5)

# Print the results
print("Nearest doctors/specialties:")
for doctor, distance in nearest_doctors:
    print(f"{doctor}: {distance}")

Nearest doctors/specialties:
AUDIOLOGIS: 5.4151930809021
OTOLARYNGO: 6.006369113922119
NEPHROLOGY: 6.435783863067627
CARDIOVASC: 6.480925559997559
GYNECOLOGI: 6.584884166717529


  text_embedding = torch.tensor(text_embedding).to(device)
