# Few-shot Classification of SDoH

## 0. Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [15]:
import pandas as pd
import os
from pathlib import Path
import sys
from IPython.display import display, HTML

# Add the project root to the Python path to import the modules
project_root = Path().absolute().parent
sys.path.append(str(project_root))

In [16]:
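# Use the shared cache. huggingface_hub and transformers read these variables
# when they are first imported, so set them before importing transformers.
os.environ['HF_HOME'] = '/data/resource/huggingface'
os.environ['TRANSFORMERS_OFFLINE'] = '1'  # Force offline mode

import torch
import transformers

# Hub cache directory scanned below to list the available models
cache_dir = "/data/resource/huggingface/hub"
available_models = []

# Suppress warnings from transformers
transformers.logging.set_verbosity_error()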

In [3]:
# Load cleaned data
brc_referrals_cleaned = pd.read_csv("../data/processed/brc-cleaned/referrals_cleaned.csv")

## 1. Few-shot classification of SDoH

### 1.1 Loading the models

In [17]:
if os.path.exists(cache_dir):
    for item in os.listdir(cache_dir):
        if item.startswith("models--"):
            # Convert models--org--name to org/name format
            model_name = item.replace("models--", "").replace("--", "/")
            available_models.append(model_name)

print("Available cached models:")
for model in sorted(available_models):
    print(f"  {model}")

Available cached models:
  CohereForAI/aya-23-35B
  CohereForAI/aya-23-8B
  CohereForAI/aya-vision-8b
  HuggingFaceTB/SmolLM-135M-Instruct
  LLaMAX/LLaMAX3-8B-Alpaca
  Qwen/Qwen1.5-4B
  Qwen/Qwen2-7B
  Qwen/Qwen2.5-1.5B
  Qwen/Qwen2.5-3B
  Qwen/Qwen2.5-72B-Instruct
  Qwen/Qwen2.5-7B
  Qwen/Qwen2.5-7B-Instruct
  Qwen/Qwen2.5-7B-instruct
  Qwen/Qwen2.5-VL-7B-Instruct
  Qwen/Qwen3-0.6B
  Qwen/Qwen3-8B
  Unbabel/wmt20-comet-qe-da
  Unbabel/wmt22-comet-da
  bert-base-uncased
  bert-large-uncased
  cardiffnlp/twitter-roberta-base-sentiment
  clairebarale/refugee_cases_ner
  cross-encoder/stsb-roberta-base
  deepseek-ai/DeepSeek-R1-Distill-Llama-70B
  deepseek-ai/DeepSeek-R1-Distill-Llama-8B
  deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
  deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
  deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
  facebook/nllb-200-3.3B
  facebook/nllb-200-distilled-1.3B
  facebook/nllb-200-distilled-600M
  google/gemma-3-1b-it
  google/gemma-3-27b-it
  google/gemma-3-27b-it-qat-q4_0-ggu
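The loop above turns hub cache folder names such as `models--Qwen--Qwen2.5-7B-Instruct` into repository IDs. The reverse mapping is a quick way to check that a model is cached before trying to load it in offline mode. This is a minimal sketch that assumes the `models--{org}--{name}` folder layout seen in the listing. `cache_folder_name` and `is_model_cached` are hypothetical helpers and are not part of the project code. The check only confirms that the folder exists. It does not verify that every weight file was downloaded completely.

```python
from pathlib import Path


def cache_folder_name(model_name: str) -> str:
    """Map an 'org/name' repository ID to its hub cache folder name (assumed layout)."""
    return "models--" + model_name.replace("/", "--")


def is_model_cached(model_name: str, cache_dir: str) -> bool:
    """Return True if a cache folder for model_name exists under cache_dir."""
    return (Path(cache_dir) / cache_folder_name(model_name)).is_dir()
```

For example, `is_model_cached("meta-llama/Llama-3.1-8B-Instruct", cache_dir)` can be called before `load_instruction_model`. Folder names are case-sensitive on Linux, so `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-7B-instruct` count as different entries.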

In [4]:
print(f"Transformers version: {transformers.__version__}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

Transformers version: 4.52.3
PyTorch version: 2.6.0
CUDA available: True


In [4]:
# Load one of the instruction-tuned models
# Qwen/Qwen2.5-7B-Instruct
# meta-llama/Llama-3.1-8B-Instruct
# microsoft/Phi-4-mini-instruct
# mistralai/Mistral-7B-Instruct-v0.3

from src.classification.model_helpers import load_instruction_model

model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer, model = None, None

tokenizer, model = load_instruction_model(model_name)



Loading meta-llama/Llama-3.1-8B-Instruct...



✓ meta-llama/Llama-3.1-8B-Instruct loaded successfully!


### 1.2 Extraction from one note

In [5]:
# Load a specific note: Case Reference = CAS-467812
sample_note = brc_referrals_cleaned[brc_referrals_cleaned['Case Reference'] == 'CAS-467812'].iloc[0]['Referral Notes (depersonalised)']

In [6]:
from src.classification.prompt_creation_helpers import create_automated_prompt

prompt_example_basic = create_automated_prompt("This is a sentence", tokenizer=tokenizer, prompt_type="five_shot_basic")
print("=" * 50)
print("Example Prompt (Five Shot Basic):")
print("=" * 50)
print(prompt_example_basic)

Example Prompt (Five Shot Basic):
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are analyzing a referral note sentence to identify Social Determinants of Health, and classifying them as Adverse or Protective.

Given a sentence, output all SDoH factors that can be inferred from that sentence from the following list: 
Loneliness, Housing, Finances, FoodAccess, Digital, Employment, EnglishProficiency.

Each SDoH must be classified as either "Adverse" or "Protective". 
If the sentence does NOT mention any of the above categories, output <LIST>NoSDoH</LIST>.

Your response must be a comma-separated list of SDoH-Polarity pairs embedded in <LIST> and </LIST> tags.

**STRICT RULES**:
- DO NOT generate any other text, explanations, or new SDoH labels.
- A sentence CAN be labeled with one or more SDoH factors.
- The only accepted format is <LIST>...</LIST>.

EXAMPLES:
Input: "She is unemployed and struggles to pay 
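The prompt restricts the answer to a `<LIST>...</LIST>` block of `Factor-Polarity` pairs. The sketch below shows one way to parse such a response and check it against the vocabulary listed in the prompt. `parse_list_response` is a hypothetical helper and may not match how `SDoHExtractor` actually parses responses. Validating each pair like this would discard malformed items such as `Finances-NoSDoH`, which shows up in the sample output of the next cell.

```python
import re

# Factor and polarity names as listed in the five_shot_basic prompt
ALLOWED_FACTORS = {
    "Loneliness", "Housing", "Finances", "FoodAccess",
    "Digital", "Employment", "EnglishProficiency",
}
ALLOWED_POLARITIES = {"Adverse", "Protective"}
LIST_PATTERN = re.compile(r"<LIST>(.*?)</LIST>", flags=re.DOTALL)


def parse_list_response(raw_response: str) -> list[str]:
    """Extract valid 'Factor-Polarity' labels from the first <LIST>...</LIST> block.

    Returns ["NoSDoH"] when there is no <LIST> block or no valid label inside it.
    Labels are de-duplicated and sorted.
    """
    match = LIST_PATTERN.search(raw_response)
    if match is None:
        return ["NoSDoH"]

    labels = set()
    for item in match.group(1).split(","):
        factor, sep, polarity = item.strip().partition("-")
        # Keep only well-formed pairs drawn from the allowed vocabularies
        if sep and factor in ALLOWED_FACTORS and polarity in ALLOWED_POLARITIES:
            labels.add(f"{factor}-{polarity}")
    return sorted(labels) or ["NoSDoH"]
```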

In [7]:
from src.classification.SDoH_classification_helpers import SDoHExtractor

# Initialize the SDoH extractor
extractor = SDoHExtractor(
    model=model,
    tokenizer=tokenizer,
    prompt_type="five_shot_basic",
    debug=True,
)

# Extract SDoH factors
results = extractor.extract_from_note(sample_note)
results_df = extractor.results_to_dataframe(results, note_id="sample")

print("\nExtracted SDoH Factors:")
display(results_df)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Extracted SDoH Factors:


| | note_id | sentence_number | sentence | has_sdoh | sdoh_factors | num_sdoh_factors |
|---|---|---|---|---|---|---|
| 0 | sample | 1 | Lives with husband for whom patient is carer | True | Housing-Protective, Employment-Adverse | 2 |
| 1 | sample | 2 | Living on ready meals at present | True | FoodAccess-Adverse | 1 |
| 2 | sample | 3 | [PERSON] concerned that they may not be eating... | True | FoodAccess-Adverse | 1 |
| 3 | sample | 4 | Carers in [REDACTED] times daily for patient t... | True | Housing-Protective, Employment-Adverse | 2 |
| 4 | sample | 5 | Depending on side - effects of radiotherapy , ... | True | FoodAccess-Adverse | 1 |
| 5 | sample | 6 | Patient feeling slightly overwhelmed by everyt... | True | Loneliness-Adverse | 1 |
| 6 | sample | 7 | FPOC and Carers Support Shropshire numbers giv... | True | Loneliness-Adverse, Finances-NoSDoH, Housing-N... | 6 |
| 7 | sample | 8 | Very supportive daughter who lives in [PERSON] | True | Loneliness-Protective, Housing-Adverse | 2 |
| 8 | sample | 9 | The patient is due to start radiotherapy on [R... | False | NoSDoH | 0 |
| 9 | sample | 10 | Due to start radiotherapy on [REDACTED] at SAT... | False | NoSDoH | 0 |


In [9]:
results_df.head()

| | note_id | sentence_number | sentence | has_sdoh | sdoh_factors | num_sdoh_factors |
|---|---|---|---|---|---|---|
| 0 | sample | 1 | Lives with husband for whom patient is carer | True | Housing-Protective, Employment-Adverse | 2 |
| 1 | sample | 2 | Living on ready meals at present | True | FoodAccess-Adverse | 1 |
| 2 | sample | 3 | [PERSON] concerned that they may not be eating... | True | FoodAccess-Adverse | 1 |
| 3 | sample | 4 | Carers in [REDACTED] times daily for patient t... | True | Housing-Protective, Employment-Adverse | 2 |
| 4 | sample | 5 | Depending on side - effects of radiotherapy , ... | True | FoodAccess-Adverse | 1 |


In [None]:
# Debugging: inspect the prompt and raw model response for the second sentence (index 1)
print("Prompt: \n")
print(results['sentences'][1]['debug']['prompt'])

print("Raw response: \n")
print(results['sentences'][1]['debug']['raw_response'])

### 1.3 Extracting from multiple notes (batch processing) and evaluating few-shot extraction

In [10]:
# Set desired model and prompt config
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
PROMPT_TYPE = "five_shot_basic"

# Load model and tokenizer
from src.classification.model_helpers import load_instruction_model
from src.classification.SDoH_classification_helpers import SDoHExtractor

tokenizer, model = load_instruction_model(MODEL_NAME)

# Confirm it's loaded
if tokenizer is None or model is None:
    raise ValueError(f"Failed to load model: {MODEL_NAME}")

# Create the extractor with the same constructor as in Section 1.2
extractor = SDoHExtractor(
    model=model,
    tokenizer=tokenizer,
    prompt_type=PROMPT_TYPE,
    debug=False
)

Loading meta-llama/Llama-3.1-8B-Instruct...



✓ meta-llama/Llama-3.1-8B-Instruct loaded successfully!


In [22]:
import ast

from sklearn.metrics import classification_report
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm

# === Step 1: Load test set ===
test_df = pd.read_csv("../data/processed/train-test/test_set.csv")
# label_pair is stored as a stringified Python list; parse it safely instead of using eval()
test_df["label_pair"] = test_df["label_pair"].apply(ast.literal_eval)

In [20]:
# === Step 2: Run model inference using extractor ===
y_true = []
y_pred = []
sentences = []

for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    sentence = row["Sentence"]
    gold = sorted(row["label_pair"])
    
    result = extractor.extract_from_sentence(sentence)
    pred = sorted(result["sdoh_factors"])  # list of predicted labels
    
    y_true.append(gold)
    y_pred.append(pred)
    sentences.append(sentence)

# === Step 3: Binarize for multilabel metrics ===
mlb = MultiLabelBinarizer()
y_true_bin = mlb.fit_transform(y_true)
y_pred_bin = mlb.transform(y_pred)  # must not refit — only transform

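# === Step 4: Print per-label precision, recall and F1 ===
print("Few-Shot Classification Report:\n")
# zero_division=0 reports 0.0 (without a warning) for labels that are never predicted
print(classification_report(y_true_bin, y_pred_bin, target_names=mlb.classes_, zero_division=0))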

# === Step 5: Save CSV for manual inspection ===
eval_results_df = pd.DataFrame({
    "Sentence": sentences,
    "Gold Labels": [", ".join(lbls) for lbls in y_true],
    "Predicted Labels": [", ".join(lbls) for lbls in y_pred],
    "Exact Match": [set(t) == set(p) for t, p in zip(y_true, y_pred)]
})

eval_results_df.to_csv("../results/eval/few_shot_eval_30_06.csv", index=False)
print("\nSaved evaluation results to: ../results/eval/few_shot_eval_30_06.csv")

100%|██████████| 243/243 [02:17<00:00,  1.76it/s]

Few-Shot Classification Report:

                       precision    recall  f1-score   support

      Digital-Adverse       0.40      0.67      0.50         6
   Digital-Protective       0.00      0.00      0.00         1
   Employment-Adverse       0.17      1.00      0.29         3
Employment-Protective       0.00      0.00      0.00         1
      English-Adverse       0.00      0.00      0.00         2
     Finances-Adverse       0.38      0.82      0.52        17
  Finances-Protective       0.00      0.00      0.00         1
         Food-Adverse       0.00      0.00      0.00        20
      Food-Protective       0.00      0.00      0.00         1
      Housing-Adverse       0.44      0.71      0.55        28
   Housing-Protective       0.12      1.00      0.22         1
   Loneliness-Adverse       0.57      0.69      0.63        39
Loneliness-Protective       0.50      0.57      0.53         7
               NoSDoH       0.86      0.80      0.83       144

            micro av




In [21]:
# Generate report as dict
report_dict = classification_report(
    y_true_bin,
    y_pred_bin,
    target_names=mlb.classes_,
    output_dict=True,
    zero_division=0,  # same values as the printed report, without warnings
)

# Convert to DataFrame
report_df = pd.DataFrame(report_dict).transpose()

# Save to CSV
report_path = "../results/eval/few_shot_eval_report_30_06.csv"
report_df.to_csv(report_path)
print(f"\nSaved classification report to: {report_path}")


Saved classification report to: ../results/eval/few_shot_eval_report_30_06.csv
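There are two mismatches in the evaluation above. The prompt asks the model for `FoodAccess` and `EnglishProficiency`, while the test-set labels use `Food` and `English`. In addition, `MultiLabelBinarizer` is fitted on `y_true` only. scikit-learn's `transform` ignores classes it was not fitted on and emits a warning, so predictions such as `FoodAccess-Adverse` never reach the report. That would be consistent with `Food-Adverse` scoring 0.00 despite a support of 20.

The sketch below shows one possible fix. It assumes the two naming schemes are meant to be equivalent. `PROMPT_TO_GOLD`, `normalise_label`, and `binarize_pair` are hypothetical names.

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Assumed mapping from the factor names used in the prompt to those in the test set
PROMPT_TO_GOLD = {"FoodAccess": "Food", "EnglishProficiency": "English"}


def normalise_label(label):
    """Rename the factor part of a 'Factor-Polarity' label; 'NoSDoH' passes through."""
    factor, sep, polarity = label.partition("-")
    return PROMPT_TO_GOLD.get(factor, factor) + sep + polarity


def binarize_pair(y_true, y_pred):
    """Binarize gold and predicted label lists over one shared label space."""
    y_pred = [sorted({normalise_label(label) for label in labels}) for labels in y_pred]
    mlb = MultiLabelBinarizer()
    # Fit on gold AND predictions so labels that only appear in predictions
    # become (false-positive) columns instead of being dropped by transform()
    mlb.fit(y_true + y_pred)
    return mlb, mlb.transform(y_true), mlb.transform(y_pred)
```

To use it, call `mlb, y_true_bin, y_pred_bin = binarize_pair(y_true, y_pred)` before `classification_report`. Labels that appear only in predictions then show up in the report with a support of 0.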




The script `scripts/batch_process_notes.py` applies the SDoH extractor used on a single note above to a batch of notes. Its configurable options include:

- model name
- prompt type
- level of classification (1 for mention of SDoH, 2 for adverse vs. protective mention)
- batch size and start index


To run it, enter the following command in the terminal, after activating the conda environment and adjusting the options:

```console
python scripts/batch_process_notes.py --model_name "meta-llama/Llama-3.1-8B-Instruct" \
                                 --prompt_type "five_shot_basic" \
                                 --batch_size 10 \
                                 --start_index 0
```
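The `--batch_size` and `--start_index` options suggest that the notes are processed in consecutive slices. The sketch below illustrates that partitioning. It assumes the flags in the command map onto an `argparse` parser in this way. The parser and `iter_batches` are illustrative and are not taken from `scripts/batch_process_notes.py`. The classification-level option is left out because its flag name is not shown above.

```python
import argparse


def parse_args(argv=None):
    """Parse the options shown in the example command (illustrative only)."""
    parser = argparse.ArgumentParser(description="Batch SDoH extraction (sketch)")
    parser.add_argument("--model_name", default="meta-llama/Llama-3.1-8B-Instruct")
    parser.add_argument("--prompt_type", default="five_shot_basic")
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--start_index", type=int, default=0)
    return parser.parse_args(argv)


def iter_batches(items, batch_size, start_index=0):
    """Yield (offset, batch) pairs covering items[start_index:] in order."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for offset in range(start_index, len(items), batch_size):
        yield offset, items[offset:offset + batch_size]
```

With this layout, restarting from a later `--start_index` skips notes that were already processed.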

We can also evaluate models on the annotation dataset using a separate script:

```console
python scripts/evaluate_on_annotations.py --model_name "meta-llama/Llama-3.1-8B-Instruct" \
                                  --prompt_type "five_shot_basic" \
                                  --annotation_data "data/raw/BRC-Data/annotated_BRC_referrals.csv" \
                                  --sample_size 5
```

The evaluation system has two main components:

- **Main evaluation script** (`evaluate_on_annotations.py`): the orchestrator that runs the whole evaluation;
- **Supporting utilities** (`utils/evaluation_helpers_lvl1.py`): helpers for specific tasks such as model loading, SDoH extraction, and metric calculation.

The evaluation script does the following for each sentence:

1. **Extract**: Run the sentence through the SDoHExtractor
2. **Format**: Convert the model's list output to a comparable string
3. **Record**: Store both human and model labels plus metadata
4. **Structure**: Build a DataFrame ready for multi-label metrics calculation

Demonstration:

In [None]:
from utils.SDoH_classification_helpers import SDoHExtractor
from utils.model_helpers import load_instruction_model

# Load annotated data (first few rows for demo)
annotated_df = pd.read_csv("../data/raw/BRC-Data/annotated_BRC_referrals.csv")
sample_df = annotated_df.head(5)  # Just 5 sentences for demo
sample_df.columns = ['CAS', 'Sentence', 'Label', 'Adverse', 'Comments']  # Standardise column names

print("Sample annotated data:")
print(sample_df[['Sentence', 'Label']])

# Load model
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer, model = load_instruction_model(model_name)

# Create extractor
extractor = SDoHExtractor(
    model=model,
    tokenizer=tokenizer,
    prompt_type="five_shot_basic",
    level=1,
    debug=True 
)

# Process each sentence and build results
results = []
for idx, row in sample_df.iterrows():
    sentence = str(row['Sentence']).strip()
    
    # Extract SDoH from sentence
    extraction_result = extractor.extract_from_sentence(sentence)
    factors = extraction_result["sdoh_factors"]
    
    # Convert to comparison format
    model_prediction = ", ".join(sorted(factors)) if factors and factors != ["NoSDoH"] else "NoSDoH"
    
    # Create result record (same structure as in the evaluation script)
    result = {
        'cas': row['CAS'],
        'sentence_number': idx + 1,
        'original_sentence': sentence,
        'original_label': row['Label'],
        'model_prediction': model_prediction,
        'model_factors_list': factors,
        'model_has_sdoh': factors != ["NoSDoH"] and bool(factors),
        'num_model_factors': len(factors) if factors != ["NoSDoH"] else 0,
    }
    
    results.append(result)
    
    # Show what happened
    print(f"\n--- Sentence {idx + 1} ---")
    print(f"Text: {sentence[:60]}...")
    print(f"Human labeled: {row['Label']}")
    print(f"Model predicted: {model_prediction}")
    if extraction_result.get("debug"):
        print(f"Raw model response: {extraction_result['debug']['raw_response']}")

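# Convert to DataFrame (ready for metrics calculation)
results_df = pd.DataFrame(results)

print("\nResults DataFrame:")
print(results_df[['original_label', 'model_prediction', 'model_has_sdoh']])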

Sample annotated data:
                                            Sentence  \
0  She needs help with food , toiletry and some cash   
1  Mr PERSON was having support of a friend who h...   
2  Isolated , housing concern impacting MH SU pre...   
3      Equipment delivery to ensure safer discharge.   
4  XXXX shopping FBG Patient is no longer driving...   

                                Label  
0  FoodInsecurity, FinancialSituation  
1                       SocialSupport  
2              SocialSupport, Housing  
3                              NoSDoH  
4                      Transportation  
Loading meta-llama/Llama-3.1-8B-Instruct...



The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


✓ meta-llama/Llama-3.1-8B-Instruct loaded successfully!
Using chat template for meta-llama/llama-3.1-8b-instruct





--- Sentence 1 ---
Text: She needs help with food , toiletry and some cash...
Human labeled: FoodInsecurity, FinancialSituation
Model predicted: FinancialSituation, FoodInsecurity
Raw model response: <LIST>FinancialSituation, FoodInsecurity</LIST>
Using chat template for meta-llama/llama-3.1-8b-instruct





--- Sentence 2 ---
Text: Mr PERSON was having support of a friend who had a car accid...
Human labeled: SocialSupport
Model predicted: SocialSupport, Transportation
Raw model response: <LIST>SocialSupport, Transportation</LIST>
Using chat template for meta-llama/llama-3.1-8b-instruct





--- Sentence 3 ---
Text: Isolated , housing concern impacting MH SU previously suppor...
Human labeled: SocialSupport, Housing
Model predicted: Housing, SocialSupport, SubstanceUse
Raw model response: <LIST>Housing, SocialSupport, SubstanceUse</LIST>
Using chat template for meta-llama/llama-3.1-8b-instruct





--- Sentence 4 ---
Text: Equipment delivery to ensure safer discharge....
Human labeled: NoSDoH
Model predicted: NoSDoH
Raw model response: <LIST>NoSDoH</LIST>
Using chat template for meta-llama/llama-3.1-8b-instruct

--- Sentence 5 ---
Text: XXXX shopping FBG Patient is no longer driving , therefore n...
Human labeled: Transportation
Model predicted: FinancialSituation, SocialSupport, Transportation
Raw model response: <LIST>Transportation, FinancialSituation, SocialSupport</LIST>

Results DataFrame:
                       original_label  \
0  FoodInsecurity, FinancialSituation   
1                       SocialSupport   
2              SocialSupport, Housing   
3                              NoSDoH   
4                      Transportation   

                                    model_prediction  model_has_sdoh  
0                 FinancialSituation, FoodInsecurity            True  
1                      SocialSupport, Transportation            True  
2               Housing, SocialS

In [7]:
results_df

| | cas | sentence_number | original_sentence | original_label | model_prediction | model_factors_list | model_has_sdoh | num_model_factors |
|---|---|---|---|---|---|---|---|---|
| 0 | CAS-548353 | 1 | She needs help with food , toiletry and some cash | FoodInsecurity, FinancialSituation | FinancialSituation, FoodInsecurity | [FinancialSituation, FoodInsecurity] | True | 2 |
| 1 | CAS-548411 | 2 | Mr PERSON was having support of a friend who h... | SocialSupport | SocialSupport, Transportation | [SocialSupport, Transportation] | True | 2 |
| 2 | CAS-548427 | 3 | Isolated , housing concern impacting MH SU pre... | SocialSupport, Housing | Housing, SocialSupport, SubstanceUse | [Housing, SocialSupport, SubstanceUse] | True | 3 |
| 3 | CAS-548530 | 4 | Equipment delivery to ensure safer discharge. | NoSDoH | NoSDoH | [NoSDoH] | False | 0 |
| 4 | CAS-548590 | 5 | XXXX shopping FBG Patient is no longer driving... | Transportation | FinancialSituation, SocialSupport, Transportation | [Transportation, FinancialSituation, SocialSup... | True | 3 |


After the annotated sentences have been classified, the `calculate_multilabel_metrics` function from `utils/evaluation_helpers.py` performs three main steps:

1. Label parsing & preparation
2. Binary matrix conversion

    Example: If the labels are ["Housing", "Employment", "Social Support"]:
    - ["Housing", "Employment"] becomes [1, 1, 0]
    - ["Housing"] becomes [1, 0, 0]
    - ["NoSDoH"] becomes [0, 0, 0]

3. Return multi-label metrics

    The function calculates four types of metrics:
    - Example-based: How well does the model predict the exact set of labels for each sentence?
    - Label-based: How well does the model perform on each individual label?
    - Per-label: Performance breakdown for each SDoH factor
    - Statistics: Overall dataset characteristics
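Steps 1 and 2 can be sketched with scikit-learn's `MultiLabelBinarizer`. The sketch assumes annotations are stored as comma-separated strings, as in the `Label` column above, and that `NoSDoH` should become an all-zero row. `parse_labels` and `to_binary` are illustrative names, not the functions in `utils/evaluation_helpers.py`. Passing a fixed `classes` list keeps the column order stable across runs.

```python
from sklearn.preprocessing import MultiLabelBinarizer

LABELS = ["Housing", "Employment", "Social Support"]


def parse_labels(label_string):
    """Split 'Housing, Employment' into ['Housing', 'Employment']; 'NoSDoH' gives []."""
    labels = [part.strip() for part in str(label_string).split(",")]
    return [label for label in labels if label and label != "NoSDoH"]


def to_binary(label_strings, classes=LABELS):
    """Convert label strings to a 0/1 matrix with one column per class, in `classes` order."""
    mlb = MultiLabelBinarizer(classes=classes)
    return mlb.fit_transform([parse_labels(s) for s in label_strings])
```

Calling `to_binary(["Housing, Employment", "Housing", "NoSDoH"])` reproduces the three rows in the example above.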

Each group of multi-label metrics is described in more detail below.

1. **Example-Based Metrics (averaged across sentences)**. These look at how well the model predicts the complete set of labels for each sentence:

    - Exact Match Ratio (Subset Accuracy): Percentage where model prediction exactly matches human annotation. ["Housing", "Employment"] == ["Housing", "Employment"] ✓ = 1.0; ["Housing"] vs ["Housing", "Employment"] ✗ = 0.0
    - Hamming Loss: Fraction of label positions predicted incorrectly (lower is better). It counts individual label mistakes across all sentence-label positions; for example, with 3 possible labels and 1 wrong: hamming_loss = 1/3 ≈ 0.33.
    - Additional metrics: Example-based Precision/Recall/F1; Jaccard Index
2. **Label-Based Metrics (averaged across labels)**. These treat each SDoH factor as a separate binary classification problem:

    - Macro-averaged (treats all labels equally): calculate precision/recall/F1 for each label separately, then take the unweighted mean, so rare labels get the same weight as common ones.
    - Micro-averaged (weighted by frequency): pool true positives, false positives, and false negatives across all labels before computing precision/recall/F1, so common labels have more influence.

3. **Per-Label Performance**. Individual precision/recall/F1 for each SDoH factor.
4. **Dataset Statistics**

    - Label Cardinality: Average number of labels per sentence
    - Label Density: Cardinality divided by total possible labels
    - Coverage: How many different labels appear at least once
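Most of these metrics are available in scikit-learn for 0/1 indicator matrices. The sketch below assumes `y_true_bin` and `y_pred_bin` are matrices like those from step 2. `multilabel_summary` is a hypothetical helper and may differ from `calculate_multilabel_metrics` in details. One such detail is how all-zero (NoSDoH) rows are scored: with `zero_division=0`, a sentence whose gold and predicted rows are both empty gets an example-based F1 and Jaccard of 0. Setting `zero_division=1` would count that agreement as correct.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, hamming_loss, jaccard_score


def multilabel_summary(y_true_bin, y_pred_bin):
    """Compute example-based, label-based, per-label and dataset-level metrics."""
    y_true_bin = np.asarray(y_true_bin)
    y_pred_bin = np.asarray(y_pred_bin)
    cardinality = y_true_bin.sum(axis=1).mean()  # average gold labels per sentence

    return {
        # Example-based (averaged across sentences)
        "exact_match": accuracy_score(y_true_bin, y_pred_bin),  # subset accuracy
        "hamming_loss": hamming_loss(y_true_bin, y_pred_bin),
        "example_f1": f1_score(y_true_bin, y_pred_bin, average="samples", zero_division=0),
        "example_jaccard": jaccard_score(y_true_bin, y_pred_bin, average="samples", zero_division=0),
        # Label-based (averaged across labels)
        "macro_f1": f1_score(y_true_bin, y_pred_bin, average="macro", zero_division=0),
        "micro_f1": f1_score(y_true_bin, y_pred_bin, average="micro", zero_division=0),
        # Per-label F1, one value per column
        "per_label_f1": f1_score(y_true_bin, y_pred_bin, average=None, zero_division=0).tolist(),
        # Dataset statistics (computed on the gold labels)
        "label_cardinality": float(cardinality),
        "label_density": float(cardinality / y_true_bin.shape[1]),
        "coverage": int((y_true_bin.sum(axis=0) > 0).sum()),
    }
```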