In [30]:
import pandas as pd
import spacy
import sklearn_crfsuite
import os
import joblib # Zum Speichern der trainierten Modelle
from sklearn.model_selection import train_test_split
from sklearn_crfsuite import metrics
from ast import literal_eval
from tqdm import tqdm

In [31]:
# Directory holding the competition CSV files. The trailing slash matters:
# the loading cell appends file names to this string directly.
DATA_PATH = "../nbme-score-clinical-patient-notes/"

# Shared spaCy pipeline. It is used to tokenise the patient notes and supplies
# the POS / dependency attributes read when building token features.
nlp = spacy.load("en_core_web_sm")

In [None]:
# Raw competition tables, read from DATA_PATH.
features_df = pd.read_csv(f"{DATA_PATH}features.csv")
notes_df = pd.read_csv(f"{DATA_PATH}patient_notes.csv")
train_df_raw = pd.read_csv(f"{DATA_PATH}train.csv")

# One row per annotation: attach the note columns first (matched on note and
# case number), then the feature columns (matched on feature and case number).
df = (
    train_df_raw
    .merge(notes_df, on=["pn_num", "case_num"])
    .merge(features_df, on=["feature_num", "case_num"])
)

In [33]:
def get_bio_tags_multi_feature(text, locations, feature_name):
    """Tokenise ``text`` and give every token a feature-specific BIO tag.

    Args:
        text: Note text; it is tokenised with the module-level ``nlp`` pipeline.
        locations: String holding a Python list literal of location entries.
            Each entry contains one or more ``"start end"`` character spans
            separated by ``;``.
        feature_name: Feature label. Spaces and hyphens become underscores and
            the result is upper-cased to form the tag suffix.

    Returns:
        A list of ``(token_text, tag)`` tuples in token order. The tag is
        ``'B-<SUFFIX>'`` for the first token lying fully inside a span,
        ``'I-<SUFFIX>'`` for later tokens inside the same span, and ``'O'``
        for every other token.
    """
    doc = nlp(text)
    labels = ['O'] * len(doc)
    parsed_locations = literal_eval(locations)

    suffix = feature_name.replace(' ', '_').replace('-', '_').upper()
    begin_tag = f'B-{suffix}'    # e.g., 'B-CHEST_PRESSURE'
    inside_tag = f'I-{suffix}'   # e.g., 'I-CHEST_PRESSURE'

    for entry in parsed_locations:
        for chunk in entry.split(';'):
            try:
                span_from, span_to = [int(piece) for piece in chunk.split()]
            except ValueError:
                # Not exactly two integers: skip this span.
                continue

            span_opened = False
            for position, tok in enumerate(doc):
                tok_from = tok.idx
                tok_to = tok.idx + len(tok.text)

                # Only tokens lying completely inside the span get a tag.
                if tok_from < span_from or tok_to > span_to:
                    continue

                labels[position] = inside_tag if span_opened else begin_tag
                span_opened = True

    return [(tok.text, labels[position]) for position, tok in enumerate(doc)]

In [34]:
def word2features(doc, i):
    """Build the CRF feature dictionary for the token at position ``i``.

    The dictionary describes the token itself (lower-cased form, suffix,
    shape, alpha / digit / stop-word flags, POS tag, dependency label) plus
    the lower-cased form, stop-word flag and POS tag of each direct
    neighbour. A token with no left neighbour gets ``'BOS'`` instead, and a
    token with no right neighbour gets ``'EOS'``.

    Args:
        doc: Indexable sequence of tokens exposing the spaCy token attributes
            read below.
        i: Position of the token to describe.

    Returns:
        dict mapping feature names to values.
    """
    token = doc[i]
    last_position = len(doc) - 1

    features = {
        'bias': 1.0,
        'word.lower()': token.lower_,
        'word.suffix': token.suffix_,
        'word.shape': token.shape_,
        'word.is_alpha': token.is_alpha,
        'word.is_digit': token.is_digit,
        'word.is_stop': token.is_stop,
        'pos': token.pos_,
        'dep': token.dep_,
    }

    # Left context, or the beginning-of-sequence marker.
    if i <= 0:
        features['BOS'] = True
    else:
        left = doc[i - 1]
        features['-1:word.lower()'] = left.lower_
        features['-1:word.is_stop'] = left.is_stop
        features['-1:pos'] = left.pos_

    # Right context, or the end-of-sequence marker.
    if i >= last_position:
        features['EOS'] = True
    else:
        right = doc[i + 1]
        features['+1:word.lower()'] = right.lower_
        features['+1:word.is_stop'] = right.is_stop
        features['+1:pos'] = right.pos_

    return features

In [35]:
def sent2features(doc):
    """Return one ``word2features`` dictionary per token of ``doc``, in token order."""
    per_token = []
    for position in range(len(doc)):
        per_token.append(word2features(doc, position))
    return per_token

In [36]:
def get_spans_from_bio_tags(doc, bio_tags):
    """Convert a BIO tag sequence into ``"start end"`` character-span strings.

    Tags are matched on their prefix, so both the generic ``'B-LOC'`` /
    ``'I-LOC'`` labels and feature-specific labels such as
    ``'B-CHEST_PRESSURE'`` (as produced by ``get_bio_tags_multi_feature``)
    are decoded. The previous exact comparison against ``'B-LOC'`` /
    ``'I-LOC'`` silently returned no spans for feature-specific labels.

    Args:
        doc: Sequence of tokens exposing ``idx`` (character offset) and
            ``text``.
        bio_tags: One tag per token, aligned with ``doc``.

    Returns:
        List of ``"start end"`` strings, one per decoded span, in document
        order. ``start`` is the offset of the span's first token and ``end``
        is the offset just past its last token.
    """
    spans = []
    span_start = None   # character offset where the open span began, if any
    prev_end = None     # end offset of the previously visited token

    for token, tag in zip(doc, bio_tags):
        if tag.startswith('B-'):
            # A new span begins; close the one still open, if any.
            if span_start is not None:
                spans.append(f"{span_start} {prev_end}")
            span_start = token.idx
        elif tag.startswith('I-'):
            # An 'I-' tag without a preceding 'B-' still opens a span.
            if span_start is None:
                span_start = token.idx
        elif span_start is not None:
            # Outside tag: the open span ended on the previous token.
            spans.append(f"{span_start} {prev_end}")
            span_start = None

        prev_end = token.idx + len(token.text)

    # A span running to the end of the sequence is closed on the last token.
    if span_start is not None:
        spans.append(f"{span_start} {prev_end}")

    return spans

In [37]:
def calculate_jaccard_score(true_spans_list, pred_spans_list):
    """Jaccard similarity between two collections of ``"start end"`` spans.

    Each span string is parsed into an ``(int, int)`` pair. Strings that do
    not consist of exactly two integer fields are ignored. Spans are compared
    as whole pairs, so a span counts as shared only when both offsets match
    exactly.

    Args:
        true_spans_list: Iterable of reference span strings.
        pred_spans_list: Iterable of predicted span strings.

    Returns:
        ``1.0`` when both parsed sets are empty, ``0.0`` when exactly one of
        them is empty, otherwise ``|intersection| / |union|``.
    """
    def _to_pairs(raw_spans):
        pairs = set()
        for raw in raw_spans:
            fields = raw.split()
            if len(fields) != 2:
                continue
            try:
                pair = (int(fields[0]), int(fields[1]))
            except ValueError:
                continue
            pairs.add(pair)
        return pairs

    reference = _to_pairs(true_spans_list)
    predicted = _to_pairs(pred_spans_list)

    if not reference and not predicted:
        return 1.0
    if not (reference and predicted):
        return 0.0

    shared = reference & predicted
    combined = reference | predicted
    return len(shared) / len(combined)

In [42]:
# Train one CRF per feature, evaluate it on a held-out split, and save it.
MODEL_DIR = "crf_models"
os.makedirs(MODEL_DIR, exist_ok=True)

# Each model handles a single feature, so one generic label is enough. "LOC"
# yields the 'B-LOC' / 'I-LOC' tags that get_spans_from_bio_tags decodes.
SPAN_LABEL = "LOC"

results = []
all_jaccard_scores_flat = []
total_test_samples = 0

unique_feature_nums = df['feature_num'].unique()
for feature_num in tqdm(unique_feature_nums, desc="Trainiere CRF-Modelle für Features"):

    # Filter once and reuse the subset for both the name and the rows.
    feature_df = df[df['feature_num'] == feature_num].copy()
    feature_name = feature_df['feature_text'].iloc[0]

    X = []
    y_bio = []
    original_docs = []
    original_locations_str = []

    for _, row in feature_df.iterrows():
        text = row['pn_history']
        locations = row['location']

        doc = nlp(text)

        # The notebook defines no `get_bio_tags`; calling it raised NameError
        # on a fresh kernel. The tagger tokenises `text` with the same `nlp`
        # pipeline, so its labels line up with the tokens of `doc`.
        tagged_tokens = get_bio_tags_multi_feature(text, locations, SPAN_LABEL)
        labels = [label for _, label in tagged_tokens]

        X.append(sent2features(doc))
        y_bio.append(labels)
        original_docs.append(doc)
        original_locations_str.append(locations)

    # Fewer than two notes: skip this feature and record placeholders.
    if len(X) < 2:
        results.append({
            'Feature_Num': feature_num,
            'Feature_Name': feature_name,
            'F1_Score': 'N/A',
            'Jaccard_Score': 'N/A'
        })
        continue

    # Split all four parallel lists together so they stay aligned.
    X_train, X_test, y_bio_train, y_bio_test, \
    docs_train, docs_test, locations_train_str, locations_test_str = \
        train_test_split(X, y_bio, original_docs, original_locations_str, test_size=0.2, random_state=42)

    crf = sklearn_crfsuite.CRF(
        algorithm='lbfgs',
        c1=0.1,
        c2=0.1,
        max_iterations=100,
        all_possible_transitions=True
    )
    crf.fit(X_train, y_bio_train)

    # Evaluation: token-level F1 over every label except 'O'.
    labels_for_metric = list(crf.classes_)
    if 'O' in labels_for_metric:
        labels_for_metric.remove('O')

    y_pred_bio = crf.predict(X_test)

    f1_score = metrics.flat_f1_score(y_bio_test, y_pred_bio, average='weighted', labels=labels_for_metric)

    # Jaccard score: compare annotated spans with spans decoded from the
    # predicted tags, one test note at a time.
    jaccard_scores_for_feature = []

    for i in range(len(X_test)):
        true_spans = []
        try:
            loc_str_list = literal_eval(locations_test_str[i])
            for loc_entry in loc_str_list:
                true_spans.extend(loc_entry.split(';'))
            true_spans = [s.strip() for s in true_spans if s.strip()]
        except (ValueError, SyntaxError):
            # Unparseable annotation: treat the note as having no true spans.
            true_spans = []

        pred_spans = get_spans_from_bio_tags(docs_test[i], y_pred_bio[i])

        score = calculate_jaccard_score(true_spans, pred_spans)
        jaccard_scores_for_feature.append(score)
        all_jaccard_scores_flat.append(score)

    avg_jaccard_score = sum(jaccard_scores_for_feature) / len(jaccard_scores_for_feature) if jaccard_scores_for_feature else 0.0
    total_test_samples += len(X_test)

    results.append({
        'Feature_Num': feature_num,
        'Feature_Name': feature_name,
        'F1_Score': f1_score,
        'Jaccard_Score': avg_jaccard_score
    })

    # Save the trained model under a name derived from the feature number.
    model_filename = os.path.join(MODEL_DIR, f"crf_model_feature_{feature_num}.joblib")
    joblib.dump(crf, model_filename)

print("\n Training ended.")

# Calculate global average Jaccard Score (mean over every test note of every feature)
global_avg_jaccard_score = sum(all_jaccard_scores_flat) / len(all_jaccard_scores_flat) if all_jaccard_scores_flat else 0.0


print("\n---")
print("Summary of Results:")
print("---\n")
results_df = pd.DataFrame(results)
print(results_df.to_string())

print("\n---")
print(f"Global Average Jaccard Score: {global_avg_jaccard_score:.4f}")
print("---\n")

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Trainiere CRF-Modelle für Features: 100%|██████████| 143/143 [13:11<00:00,  5.54s/it]


Alle Trainingsprozesse abgeschlossen.

---
Zusammenfassung der Modellleistung pro Feature:
---

     Feature_Num                                                          Feature_Name  F1_Score  Jaccard_Score
0              0       Family-history-of-MI-OR-Family-history-of-myocardial-infarction  0.806908       0.600000
1              1                                    Family-history-of-thyroid-disorder  0.818671       0.675000
2              2                                                        Chest-pressure  0.869560       0.850000
3              3                                                 Intermittent-symptoms  0.512668       0.291667
4              4                                                           Lightheaded  0.405714       0.550000
5              5      No-hair-changes-OR-no-nail-changes-OR-no-temperature-intolerance  0.603733       0.583333
6              6                                                          Adderall-use  0.745192       0.675000
7      


