In [58]:
from huggingface_hub import login
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

def classify_clinical_cases_with_llm(df, text_column, batch_size=16):
    """
    Classify clinical cases as gendered or ungendered using a large language model with optimized batch processing.

    Args:
        df (pd.DataFrame): DataFrame containing clinical case text.
        text_column (str): Column name containing the clinical text.
        batch_size (int): Number of rows to process in a batch.

    Returns:
        pd.DataFrame: DataFrame with a new column 'case_type' indicating 'gendered' or 'ungendered'.
    """
    hf_token = "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"
    login(hf_token, add_to_git_credential=True)

    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    # Load model and tokenizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def classify_batch(texts):
        # Create prompts for the batch
        prompts = [
            f"Classify the following clinical case as 'gendered' or 'ungendered'. "
            f"If the text refers to gender-specific conditions or anatomy (e.g., genitalia, ovarian cancer, pregnancy), "
            f"it should be marked as 'gendered'. Otherwise, mark it as 'ungendered'.\n\n"
            f"Clinical text: {text}\n"
            f"Case type:"
            for text in texts
        ]

        # Tokenize and process the batch
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)
        outputs = model.generate(**inputs, max_new_tokens=50)

        # Decode responses
        responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        classifications = []
        for response in responses:
            if " gendered" in response.lower():
                classifications.append("gendered")
            elif " ungendered" in response.lower():
                classifications.append("ungendered")
            else:
                classifications.append("ungendered")
        return classifications

    # Process in batches
    texts = df[text_column].tolist()
    case_types = []
    for i in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[i:i + batch_size]
        case_types.extend(classify_batch(batch_texts))

    # Add results to DataFrame
    df['case_type'] = case_types
    return df

In [59]:
df = pd.read_csv('craft-md/data/usmle_and_derm_dataset.csv')
result_df = classify_clinical_cases_with_llm(df, 'case_vignette', batch_size=2)
print(result_df)


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.97it/s]
  0%|          | 0/1002 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 1/1002 [00:01<23:53,  1.43s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 2/1002 [00:02<23:46,  1.43s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 3/1002 [00:04<23:45,  1.43s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 4/1002 [00:05<23:43,  1.43s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 5/1002 [00:07<23:32,  1.42s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  1%|          | 6/1002 [00:08<23:24,  1.41s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  1%|          | 7/1002 [00:09<23:20,  1.41s/it]Setting `pad_token_id` to `eos_token_id`

      Unnamed: 0                                      case_vignette  \
0              0  A 22-year-old man presented with complaints of...   
1              1  A 51-year-old white man presents to the clinic...   
2              2  An 8-year-old male child is brought to the cli...   
3              3  A 24-year-old white female graduate student pr...   
4              4  A 20-year-old woman presents to the clinic wit...   
...          ...                                                ...   
1999        3879  A previously healthy 28-year-old man presents ...   
2000        3893  A 21-year-old male presents to the ED with a s...   
2001        3896  A 4-year-old boy presents to the opthalmologis...   
2002        3897  A 55-year-old female is hospitalized following...   
2003        3912  A 76-year-old woman presents to her primary ca...   

                                               choice_1  \
0                              Lymphogranuloma venereum   
1                  Acrodermat




In [60]:
result_df

Unnamed: 0.1,Unnamed: 0,case_vignette,choice_1,choice_2,choice_3,choice_4,answer,category,dataset,case_id,case_type
0,0,A 22-year-old man presented with complaints of...,Lymphogranuloma venereum,Herpes,Chancroid,Syphilis,Lymphogranuloma venereum,Dermatology,dermatology_public,case_0,gendered
1,1,A 51-year-old white man presents to the clinic...,Acrodermatitis continua of Hallopeau,Herpetic whitlow,Paronychia,Acute contact dermatitis,Acute contact dermatitis,Dermatology,dermatology_public,case_1,ungendered
2,2,An 8-year-old male child is brought to the cli...,Varicella,Gianotti-Crosti syndrome,Langerhans cell histiocytosis,Pityriasis lichenoides et varioliformis acuta,Pityriasis lichenoides et varioliformis acuta,Dermatology,dermatology_public,case_2,ungendered
3,3,A 24-year-old white female graduate student pr...,Perioral dermatitis,Acne vulgaris,Allergic contact dermatitis,Facial demodicosis,Perioral dermatitis,Dermatology,dermatology_public,case_3,ungendered
4,4,A 20-year-old woman presents to the clinic wit...,Halo nevus,Melanoma,Vitiligo,Dysplastic nevus,Halo nevus,Dermatology,dermatology_public,case_4,ungendered
...,...,...,...,...,...,...,...,...,...,...,...
1999,3879,A previously healthy 28-year-old man presents ...,Figure A,Figure C,Figure B,Figure E,Figure B,Gastrointestinal System,MedQA_USMLE,case_3879,ungendered
2000,3893,A 21-year-old male presents to the ED with a s...,"Right-sided tactile, vibration, and propriocep...","Left-sided tactile, vibration, and propriocept...","Right-sided tactile, vibration, and propriocep...","Right-sided tactile, vibration, and propriocep...","Right-sided tactile, vibration, and propriocep...",Other,MedQA_USMLE,case_3893,ungendered
2001,3896,A 4-year-old boy presents to the opthalmologis...,Homocystinuria,Marfan syndrome,Maple syrup disease,Phenylketonuria,Homocystinuria,Other,MedQA_USMLE,case_3896,gendered
2002,3897,A 55-year-old female is hospitalized following...,Gram-positive bacterial infection,Viral infection,Adrenal insufficiency,Gram-negative bacterial infection,Gram-negative bacterial infection,Other,MedQA_USMLE,case_3897,ungendered


In [2]:
import pandas as pd

df = pd.read_csv('craft-md/data_augmented/baseline.csv')
result_df = df[df['case_type']=='ungendered']

result_df.to_csv('craft-md/data_augmented/baseline_ungendered.csv')

# Gender Swapping

In [5]:
import os
from openai import AzureOpenAI
import pandas as pd 
import numpy as np
import argparse
import random
import json
from tqdm import tqdm
import datetime
import string
import math
import sys 

def random_perturbations(text, type_pert, prob):

  random.seed(0)
  assert type(text) == str
  list_chars = list(text)

  if type_pert=="lowercase":
    for i in range(len(list_chars)):
      char_value = random.choices([list_chars[i].lower(), list_chars[i]], weights=[prob, 1-prob], k=1)[0]
      list_chars[i] = char_value

  if type_pert=="uppercase":
    for i in range(len(list_chars)):
      char_value = random.choices([list_chars[i].upper(), list_chars[i]], weights=[prob, 1-prob], k=1)[0]
      list_chars[i] = char_value

  if type_pert=="exclamation":
    indices = [i for i, letter in enumerate(list_chars) if letter == "."]
    for i in indices:
      char_value = random.choices(["!", "."], weights=[prob, 1-prob], k=1)[0]
      list_chars[i] = char_value
  
  if type_pert == "typo":
    # Get the indices of all non-space characters
    nonspace_indices = [i for i, char in enumerate(list_chars) if not char.isspace()]

    # Calculate the number of indices to flip based on the probability
    num_indices = math.floor(len(nonspace_indices) * prob)
    
    # Randomly select indices to flip
    flipping_indices = random.sample(nonspace_indices, k=num_indices)
    
    # Perform the flips
    for i in flipping_indices:
      # Randomly choose a replacement character from the alphabet
      list_chars[i] = random.choice(string.ascii_letters)

  if type_pert == "whitespace":
    new_text = []
    for char in list_chars:
        # Randomly add whitespace before the character
        add_space = random.choices([True, False], weights=[prob, 1-prob], k=1)[0]
        if add_space:
            whitespace = " " * random.randint(1, 3)  # Add 1 to 3 spaces
            new_text.append(whitespace)
        new_text.append(char)
    list_chars = new_text

  return "".join(list_chars)

def regex_perturb(df, attribute, probability):
  """
  gender swap using regex 
  """
  perturbed_messages = []
  for i, row in df.iterrows(): 
    perturbed_text = random_perturbations(row['case_vignette'], attribute, probability)
    perturbed_messages.append(perturbed_text)

  df['case_vignette'] = [r for r in perturbed_messages]
  return df


In [10]:
baseline_df = pd.read_csv('/data/healthy-ml/scratch/abinitha/craft-md/craft-md/data_augmented/baseline.csv')
lowercase_df = regex_perturb(baseline_df, 'typo', 0.03)
lowercase_df.to_csv('/data/healthy-ml/scratch/abinitha/craft-md/craft-md/data_augmented/typo.csv')