<a href="https://colab.research.google.com/github/24p11/recode-with-mistral-finetune/blob/main/tutorials/prepare_data_for_generative_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Predict ICD codes with an instruct generative model

## Data preparation

## Objective

Build JSONL files to fine-tune the Mistral-Instruct 7B model.

#### Instruct format

Given the system context and the user question, the model tries to predict the assistant response.

- system : Vous êtes un modèle de langage en française spécialisé dans le codage des diagnostics selon la classification internationale des maladies version 10 (CIM-10) pour les résumés standardisés de sortie du programme de médicalisation des systèmes d'information français (PMSI). A partir des comptes rendus d'hospitalisation vous donnerez les codes diagnostics CIM-10 que l'on peut retenir pour le séjours en distiguant diagnostic principal, diagnostic relié et diagnostics associés.
- user : Générez le codage CIM-10 du résumé standardisé de sortie PMSI à partir du compte rendu d'hospitalisation suivant : début du compte rendu ... fin du compte rendu
- assistant: Codes CIM 10 retenus pour le résumé standardisé de sortie PMSI :
    - Diagnostic principal : Autres soins de contrôle chirurgicaux précisés (Z488)
    - Diagnostic relié : Aucun.
    - Diagnostics associés :
      * Autres complications d'un acte à visée diagnostique et thérapeutique, non classées ailleurs (T818),
      * Antécédents personnels d'intervention chirurgicale importante, non classée ailleurs (Z924),
      * Acidose (E872)
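A minimal sketch of one instruct-format record serialized as a JSONL line, assuming the `messages` layout that the writing loop at the end of this notebook uses. `make_instruct_record` and `to_jsonl_line` are hypothetical helper names, and the system prompt is shortened for the example.

```python
import json


def make_instruct_record(note: str, diagnostics: str, system_prompt: str) -> dict:
    """Hypothetical helper: one chat-style example (system / user / assistant)."""
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": "Générez le codage CIM-10 du résumé standardisé de sortie PMSI "
                "à partir du compte rendu d'hospitalisation suivant : " + note,
            },
            {
                "role": "assistant",
                "content": "Codes CIM 10 retenus pour le résumé standardisé de sortie PMSI : "
                + diagnostics,
            },
        ]
    }


def to_jsonl_line(record: dict) -> str:
    # One compact JSON object per line; ensure_ascii=False keeps the French accents as-is
    return json.dumps(record, separators=(",", ":"), ensure_ascii=False) + "\n"


record = make_instruct_record(
    note="début du compte rendu ... fin du compte rendu",
    diagnostics="Diagnostic principal : Autres soins de contrôle chirurgicaux précisés (Z488)",
    system_prompt="Vous êtes un modèle de langage ...",  # shortened for the example
)
line = to_jsonl_line(record)
print(line, end="")
```

Newlines inside the note or the coding are escaped by `json.dumps`, so each example always occupies exactly one line of the file.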
      
      
#### Text format
Next-token prediction: the model tries to predict the next token given the preceding context.

- text : 
  * début du compte rendu ... fin du compte rendu 
  * Diagnostics CIM-10 du résumé standardisé de sortie PMSI : 
    - Diagnostic principal : Autres soins de contrôle chirurgicaux précisés (Z488)
    - Diagnostic relié : Aucun.
    - Diagnostics associés :
      * Autres complications d'un acte à visée diagnostique et thérapeutique, non classées ailleurs (T818),
      * Antécédents personnels d'intervention chirurgicale importante, non classée ailleurs (Z924),
      * Acidose (E872)
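For the text format, each record is a single `text` field: the note, a fixed separator line, then the coding. A sketch under the same assumptions (`make_text_record` and `split_text_record` are hypothetical helpers); splitting on the separator is one possible way to rebuild a prompt from a record at inference time, not something this notebook does.

```python
import json
from typing import Tuple

# Separator written between the note and its coding in the text-format files
SEPARATOR = "\nDiagnostics CIM-10 du résumé standardisé de sortie PMSI : "


def make_text_record(note: str, diagnostics: str) -> dict:
    """Hypothetical helper: the note followed by its expected coding."""
    return {"text": note + SEPARATOR + diagnostics}


def split_text_record(record: dict) -> Tuple[str, str]:
    """Hypothetical helper: recover (note, diagnostics) from a text record."""
    note, _, diagnostics = record["text"].rpartition(SEPARATOR)
    return note, diagnostics


record = make_text_record(
    "début du compte rendu ... fin du compte rendu",
    "Diagnostic principal : Ascite (R18).",
)
print(json.dumps(record, ensure_ascii=False))
```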
      
      
### Clinical notes
For this experiment, the notes have been prepared in CSV format (variable names inherited from the OMOP format):
- ```visit_occurence_id``` : encounter (visit) number
- ```text``` : text of the clinical note (hospitalization report)

The data were initially prepared in 3 files:
- train.csv
- val.csv
- test.csv

### ICD-10 codes

The data are in CSV format (variable names inherited from the OMOP format).
ICD-10 diagnoses have been prepared in the table ```prep_raw_codes_sel```:
- ```visit_occurence_id``` : encounter number
- ```condition_status_source_value``` : position of the diagnosis: DP (principal diagnosis), DR (related diagnosis) or DAS (associated diagnosis)
- ```concept_id``` : ID in the ICD-10 dictionary
- ```concept_code``` : ICD-10 code
- ```concept_name``` : ICD-10 label


For each visit (```visit_occurence_id```), select:
- 1 principal diagnosis (DP). For multi-UMA visits (see NB), choose the DP that is most frequent in the overall population.
- 1 related diagnosis (DR), if present. For multi-UMA visits, choose the DR that is most frequent in the overall population.
- x associated diagnoses (DAS), ordered by how often they occur with this DP in the cohort.

NB: the EDS (health data warehouse) data do not keep the notion of "RUM principal" (main medical unit summary). For multi-UMA visits (stays spanning several medical units, hence several RUMs), the DP of the visit is therefore taken to be the most frequent DP.

1) Build frequency reference tables
- 1-1 General frequencies of DP, DR and DAS --> freq_DP, freq_DR, freq_DAS
- 1-2 Frequencies of the DP-DAS association --> freq_DP_DAS
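A toy illustration of step 1 with pandas, on a hypothetical three-visit extract that has the same columns as `codes.csv`. One `groupby(...).size()` over (position, code) gives freq_DP, freq_DR and freq_DAS in a single table; here the DP-DAS pair counts are obtained by joining each visit's DP with its DAS, whereas the notebook computes them later from its merged DP/DAS table.

```python
import pandas as pd

# Hypothetical toy extract with the columns of codes.csv (concept_name dropped)
codes = pd.DataFrame({
    "visit_occurence_id": ["v1", "v1", "v1", "v2", "v2", "v3"],
    "condition_status_source_value": ["DP", "DAS", "DAS", "DP", "DAS", "DP"],
    "concept_code": ["Z5101", "E43", "N185", "Z5101", "E43", "I253"],
})

# 1-1: frequency of each (position, code) pair -> freq_DP, freq_DR, freq_DAS in one table
freq_position = (
    codes.groupby(["condition_status_source_value", "concept_code"])
    .size()
    .to_frame("freq_position")
    .reset_index()
)

# 1-2: frequency of each (DP, DAS) pair, counted over visits
dp = codes[codes.condition_status_source_value == "DP"][["visit_occurence_id", "concept_code"]]
das = codes[codes.condition_status_source_value == "DAS"][["visit_occurence_id", "concept_code"]]
pairs = dp.merge(das, on="visit_occurence_id", suffixes=("_dp", "_das"))
freq_dp_das = (
    pairs.groupby(["concept_code_dp", "concept_code_das"])
    .size()
    .to_frame("freq_dp_das")
    .reset_index()
)
print(freq_position)
print(freq_dp_das)
```

Visit `v3` has no DAS, so the inner merge gives it no DP-DAS pair.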


2) Choose only one DP and one DR per multi-UMA visit, based on frequencies
- 2-1 Create table sel_DP
  - table
    * visit_occurence_id
    * diag DP
  - Merge the DPs with the DP frequency table
  - Choose the most frequent --> table sel_DP
- 2-2 Create table sel_DR
  - table
    * visit_occurence_id
    * diag DR
  - Merge the DRs with the DR frequency table
  - Choose the most frequent --> table sel_DR
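Step 2 boils down to "sort by frequency within each visit, keep the first row". A sketch on hypothetical data (`dp_rows` and `freq_dp` stand in for the DP rows and the freq_DP table); `drop_duplicates(..., keep="first")` after the sort plays the same role here as the `groupby(...).first()` used for sel_DR in the cells below.

```python
import pandas as pd

# Hypothetical multi-UMA visit v1 with two candidate DPs, and a single-UMA visit v2
dp_rows = pd.DataFrame({
    "visit_occurence_id": ["v1", "v1", "v2"],
    "dp": ["K630", "Z5101", "I253"],
})
# Hypothetical overall DP frequencies (the freq_DP table of step 1)
freq_dp = pd.DataFrame({"dp": ["K630", "Z5101", "I253"], "freq_position": [3, 12, 5]})

sel_dp = (
    dp_rows.merge(freq_dp, on="dp", how="left")
    # most frequent DP first within each visit
    .sort_values(["visit_occurence_id", "freq_position"], ascending=[True, False])
    # keep one row per visit: the first one after sorting
    .drop_duplicates("visit_occurence_id", keep="first")
    .reset_index(drop=True)
)
print(sel_dp)
```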
  
3) Order DAS
- 3-1 Select all diagnoses (DP, DR, DAS):
  * table DAS
    - visit_occurence_id
    - code
- 3-2 Merge with tables sel_DP and sel_DR
  * table DAS
    - visit_occurence_id
    - code
    - DP
    - DR
- 3-3 Remove diagnoses that are already selected as DP or DR
- 3-4 Merge with freq_DP_DAS and freq_DAS
- 3-5 Format the diagnosis variables as text: "Definition of the code (code)"
- 3-6 Order by visit_occurence_id, freq_DP_DAS, freq_DAS
- 3-7 Concatenate the DAS into a single variable: "Definition of the code 1 (code 1), Definition of the code 2 (code 2),..."
  * At the end, we have a dataframe with one row per visit and the columns:
     - visit_occurence_id
     - dp : "Definition of the code (code)"
     - dr : "Definition of the code (code)"
     - das : "Definition of the code 1 (code 1), Definition of the code 2 (code 2),..."
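Steps 3-6 and 3-7 can be sketched on hypothetical DAS rows that are already merged with their DP-DAS frequency and formatted as "label (code)". `groupby` keeps the row order within each group, so sorting first and then joining yields the DAS from most to least frequent.

```python
import pandas as pd

# Hypothetical DAS rows (steps 3-4 and 3-5 already done); v2 has no DAS
das = pd.DataFrame({
    "visit_occurence_id": ["v1", "v1", "v1", "v2"],
    "libelle_das": ["Ascite (R18)", "Acidose (E872)",
                    "Antécédents personnels d'irradiation (Z923)", ""],
    "freq_dp_das": [2, 9, 5, 1],
})

# 3-6: most frequent association with the DP first, within each visit
das = das.sort_values(["visit_occurence_id", "freq_dp_das"], ascending=[True, False])

# 3-7: one row per visit, labels joined as "label 1 (code 1),\nlabel 2 (code 2),..."
das_by_visit = (
    das.groupby("visit_occurence_id")["libelle_das"]
    .agg(lambda labels: ",\n".join(labels))
    .reset_index()
)
print(das_by_visit)
```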

4) Prepare the format for generative training

Final format:

text_diagnostics =
```
Diagnostic principal : Autres soins de contrôle chirurgicaux précisés (Z488)
Diagnostic relié : Aucun.
Diagnostics associés :
Autres complications d'un acte à visée diagnostique et thérapeutique, non classées ailleurs (T818),
Antécédents personnels d'intervention chirurgicale importante, non classée ailleurs (Z924),
Acidose (E872),
Antécédents personnels de chimiothérapie pour tumeur (Z926),
Fibrillation et flutter auriculaires, sans précision (I489),
Ablation d'un autre organe (partielle) (totale) à l'origine de réactions anormales du patient ou de complications ultérieures, sans mention d'accident au cours de l'intervention (Y836),
Ascite (R18),
Antécédents personnels d'irradiation (Z923),
Autres maladies précisées du pancréas (K868),
Autres atteintes non infectieuses précisées des vaisseaux et des ganglions lymphatiques (I898),
Pancréatites chroniques, autres et non précisées (K861+8),
Tumeur maligne d'autres parties du pancréas (C257).
```
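The nested `np.where` of step 4 handles four cases (with or without DR, with or without DAS). The same rules written as a plain Python function: a sketch that mirrors that cell's punctuation exactly, including the missing final period after "Aucun" when there is neither DR nor DAS. `format_diagnostics` is a hypothetical name; the notebook itself works column-wise on the dataframe.

```python
from typing import Optional, Sequence


def format_diagnostics(dp: str, dr: Optional[str], das: Sequence[str]) -> str:
    """Hypothetical helper reproducing the four cases of the np.where cell.

    dp, dr and each das item are already formatted as "label (code)";
    das is assumed to be sorted from most to least frequent.
    """
    das_text = ",\n".join(das)
    if dr is None:
        if not das_text:
            # No DR, no DAS
            return f"Diagnostic principal : {dp}.\nDiagnostic relié : Aucun.\nDiagnostics associés : Aucun"
        # No DR, with DAS (no period after the DP line, as in the notebook)
        return f"Diagnostic principal : {dp}\nDiagnostic relié : Aucun.\nDiagnostics associés :\n{das_text}."
    if not das_text:
        # DR, no DAS
        return f"Diagnostic principal : {dp}.\nDiagnostic relié : {dr}.\nDiagnostics associés : Aucun."
    # DR and DAS
    return f"Diagnostic principal : {dp}.\nDiagnostic relié : {dr}.\nDiagnostics associés :\n{das_text}."


print(format_diagnostics("Abcès de l'intestin (K630)", None,
                         ["Apnée du sommeil (G473)",
                          "Diabète sucré de type 2 insulinotraité, sans complication (E1190)"]))
```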

5) Merge with the medical notes and write the final JSONL files

For each clinical-note file (train, test, val), merge the notes with the diagnosis text.

```python
# libraries
import pandas as pd
import numpy as np
import re
```

#### Import the ICD-10 dictionary
- dataframe:
   * code
   * libelle (label)

```python
PATH_ICD = "../referentials/"
# ATIH 2019 ICD-10 referential (tab-separated, no header)
df_icd = pd.read_csv(
    PATH_ICD + "cim_10_atih_2019.tsv",
    sep="\t",
    header=None,
    names=["code", "aut_mco", "pos", "aut_ssr", "lib_court", "libelle"],
)
# Remove the spaces inside the codes
df_icd.code = df_icd.code.str.replace(" ", "", regex=False)
```

```python
PATH_DATA = "../sample_data/"
df_icd_codes = pd.read_csv(PATH_DATA + "codes.csv")
# Labels are taken from the ATIH referential instead
df_icd_codes.drop(columns="concept_name", inplace=True)
```

### 1) Build frequency reference tables

#### 1-1 Dataframes freq_DP, freq_DR, freq_DAS
One dataframe with the frequency of each (position, code) pair: [freq_DP, freq_DR, freq_DAS]

```python
# Frequency of each (position, code) pair
df_freq_types_codes = (
    df_icd_codes.groupby(["condition_status_source_value", "concept_code"])
    .size()
    .to_frame('freq_position')
    .reset_index()
)
```

```python
# Check that all codes are in the referential (expected: no rows)
df_freq_types_codes[~df_freq_types_codes.concept_code.isin(df_icd.code)]
```

```
Empty DataFrame
Columns: [condition_status_source_value, concept_code, freq_position]
Index: []
```

```python
nb_codes = df_freq_types_codes.concept_code.nunique()
print("Number of codes to predict: " + str(nb_codes))
```

```
Number of codes to predict: 139
```


### 2) Choose only one DP and one DR per multi-UMA visit, based on frequencies

#### 2-1 Create table sel_DP

```python
df_sel_dp = df_icd_codes[df_icd_codes.condition_status_source_value == 'DP']
```

```python
# Add the overall frequency of each DP code
df_sel_dp = df_sel_dp.merge(df_freq_types_codes[df_freq_types_codes.condition_status_source_value == 'DP'])
# Multi-UMA visits can have several DPs: keep the most frequent one per visit
df_sel_dp.sort_values(["visit_occurence_id", "freq_position"], ascending=[True, False], inplace=True)
df_sel_dp = df_sel_dp.groupby(["visit_occurence_id"]).first().reset_index()
df_sel_dp.rename(columns={'concept_code': 'dp'}, inplace=True)
```

```python
df_sel_dp.shape
```

```
(200, 4)
```

#### 2-2 Create table sel_DR

```python
df_sel_dr = df_icd_codes[df_icd_codes.condition_status_source_value == 'DR']
```

```python
# Add the overall frequency of each DR code
df_sel_dr = df_sel_dr.merge(df_freq_types_codes[df_freq_types_codes.condition_status_source_value == 'DR'])
# Keep the most frequent DR per visit
df_sel_dr.sort_values(["visit_occurence_id", "freq_position"], ascending=[True, False], inplace=True)
df_sel_dr = df_sel_dr.groupby(["visit_occurence_id"]).first().reset_index()
df_sel_dr.rename(columns={'concept_code': 'dr'}, inplace=True)
```

```python
# Add the ICD-10 label of the DR
df_sel_dr = df_sel_dr.merge(df_icd[["code", "libelle"]].rename(columns={'code': 'dr', 'libelle': 'libelle_dr'}))
```

### 3) Order DAS 
#### 3-1 Select the diagnoses (DP, DAS)

```python
# Keep DP rows as well as DAS rows so that visits without any DAS are not lost
df_sel_das = df_icd_codes[df_icd_codes.condition_status_source_value.isin(['DP', 'DAS'])].copy()
```

```python
# Visits that have at least one DAS
df_visit_id_das = df_icd_codes.visit_occurence_id[df_icd_codes.condition_status_source_value == "DAS"]
```

```python
# Just to check
df_sel_das.visit_occurence_id.nunique()
```

```
200
```

```python
df_sel_das = df_sel_das.rename(columns={'concept_code': 'das'})
```


####  3-2 Merge with tables sel_DP and sel_DR

```python
df_sel_das.columns
```

```
Index(['visit_occurence_id', 'condition_status_source_value', 'das'], dtype='object')
```

```python
df_sel_dp.columns
```

```
Index(['visit_occurence_id', 'condition_status_source_value', 'dp',
       'freq_position'],
      dtype='object')
```

```python
# Attach the selected DP to every DP/DAS row of the visit (merge on visit_occurence_id)
df_dp_das = df_sel_das.merge(df_sel_dp.drop(columns=["condition_status_source_value", "freq_position"]))
```

```python
df_dp_das.visit_occurence_id.nunique()
```

```
200
```

```python
# Because DP rows were also kept, the das column holds the DP code for visits
# that have no DAS: force das to "" for those visits
df_dp_das.loc[~df_dp_das.visit_occurence_id.isin(df_visit_id_das), "das"] = ""
```

#### 3-3 Remove diagnoses already selected as DP or DR

```python
# Remove rows where the DAS is the DP itself
df_dp_das = df_dp_das[df_dp_das.dp != df_dp_das.das]
```

```python
# Drop duplicates due to multi-UMA visits
df_dp_das = df_dp_das.drop_duplicates(["visit_occurence_id", "dp", "das"])
```

```python
df_dp_das.visit_occurence_id.nunique()
```

```
200
```
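Step 3-3 of the plan also removes associated diagnoses that repeat the DR, while the cells above only compare the DAS with the DP. A sketch of the DR check on hypothetical data (`dp_das` and `sel_dr` stand in for `df_dp_das` and `df_sel_dr`); this is an assumed extension, not part of the original pipeline.

```python
import pandas as pd

# Hypothetical DAS rows after the DP filter, and the selected DR per visit
dp_das = pd.DataFrame({
    "visit_occurence_id": ["v1", "v1", "v2"],
    "dp": ["Z5101", "Z5101", "K630"],
    "das": ["E43", "C257", "E1190"],
})
sel_dr = pd.DataFrame({"visit_occurence_id": ["v1"], "dr": ["C257"]})

# Left merge so that visits without a DR keep all their DAS (dr is NaN there)
dp_das = dp_das.merge(sel_dr, on="visit_occurence_id", how="left")
# Drop DAS rows equal to the visit's DR; a code compared with NaN is "different", so v2 is kept
dp_das = dp_das[dp_das.das != dp_das.dr].drop(columns="dr")
print(dp_das)
```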

#### 3-4 Merge with freq_DP_DAS and freq_DAS

```python
# Add the frequency of each (DP, DAS) pair
freq_dp_das = df_dp_das.groupby(["dp", "das"]).size().to_frame('freq_dp_das').reset_index()
```

```python
df_dp_das = df_dp_das.merge(freq_dp_das, how="left")
```

```python
df_dp_das[df_dp_das.visit_occurence_id == "cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd"]
```

| | visit_occurence_id | condition_status_source_value | das | dp | freq_dp_das |
|---|---|---|---|---|---|
| 671 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | E1190 | Z5101 | 3 |
| 672 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | E1198 | Z5101 | 6 |
| 673 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | E8718 | Z5101 | 7 |
| 674 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | I482 | Z5101 | 7 |
| 675 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | N185 | Z5101 | 8 |
| 676 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | E43 | Z5101 | 8 |
| 677 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | R410 | Z5101 | 7 |


#### 3-5 Format the diagnosis variables as text: "Definition of the code (code)"

```python
# Add the ICD-10 label of the DP
df_dp_das = df_dp_das.merge(df_icd[["code", "libelle"]].rename(columns={'code': 'dp', 'libelle': 'libelle_dp'}), how='left')
```

```python
# Add the ICD-10 label of the DAS
df_dp_das = df_dp_das.merge(df_icd[["code", "libelle"]].rename(columns={'code': 'das', 'libelle': 'libelle_das'}), how='left')
```

```python
# Build the format: definition (code)
df_dp_das.libelle_dp = df_dp_das.libelle_dp + ' (' + df_dp_das.dp + ')'
```

```python
# Visits without DAS keep an empty string instead of a label
df_dp_das.libelle_das = np.where(df_dp_das.das != "",
                                 df_dp_das.libelle_das + ' (' + df_dp_das.das + ')',
                                 df_dp_das.das)
```

```python
"cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd" in df_dp_das.visit_occurence_id.unique()
```

```
True
```

```python
df_dp_das[df_dp_das.visit_occurence_id == "cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd"]
```

| | visit_occurence_id | condition_status_source_value | das | dp | freq_dp_das | libelle_dp | libelle_das |
|---|---|---|---|---|---|---|---|
| 671 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | E1190 | Z5101 | 3 | Séance d'irradiation (Z5101) | Diabète sucré de type 2 insulinotraité, sans c... |
| 672 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | E1198 | Z5101 | 6 | Séance d'irradiation (Z5101) | Diabète sucré de type 2 non insulinotraité ou ... |
| 673 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | E8718 | Z5101 | 7 | Séance d'irradiation (Z5101) | Hypoosmolarités et hyponatrémies, autres et sa... |
| 674 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | I482 | Z5101 | 7 | Séance d'irradiation (Z5101) | Fibrillation auriculaire chronique [permanente... |
| 675 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | N185 | Z5101 | 8 | Séance d'irradiation (Z5101) | Maladie rénale chronique, stade 5 (N185) |
| 676 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | E43 | Z5101 | 8 | Séance d'irradiation (Z5101) | Malnutrition protéino-énergétique grave, sans ... |
| 677 | cd372bd9dfd0e15f1281bfb1496ff3dd3346e4dd | DAS | R410 | Z5101 | 7 | Séance d'irradiation (Z5101) | Désorientation, sans précision (R410) |


#### 3-6 Order by visit_occurence_id, freq_DP_DAS, freq_DAS



```python
# Concatenate all DAS in one line:
# first arrange the rows from the most to the least frequent DP-DAS association
df_dp_das.sort_values(["visit_occurence_id", "freq_dp_das"], ascending=[True, False], inplace=True)
```

#### 3-7 Concatenate the DAS into a single variable: "Definition of the code 1 (code 1), Definition of the code 2 (code 2),..."

```python
# One row per visit; the DAS labels are joined with ",\n" (one per line)
df_codes = df_dp_das.groupby(["visit_occurence_id", "libelle_dp"])['libelle_das'].apply(lambda x: ',\n'.join(x)).reset_index()
```

### 4) Prepare the format for generative training

Add the DR

```python
# Prepare df_sel_dr to merge with the other codes
df_sel_dr2 = df_sel_dr.copy()
```

```python
df_sel_dr2.libelle_dr = df_sel_dr2.libelle_dr + ' (' + df_sel_dr2.dr + ')'
```

```python
# Left merge: visits without a DR get NaN in libelle_dr
df_codes = df_codes.merge(df_sel_dr2[['visit_occurence_id', 'libelle_dr']], how='left')
```

Merge all diagnoses into one column

```python
# Four cases: with or without DR, with or without DAS
df_codes["diag"] = np.where(
    df_codes.libelle_dr.isna(),
    # No DR
    np.where(df_codes.libelle_das == "",
             "Diagnostic principal : " + df_codes.libelle_dp + ".\nDiagnostic relié : Aucun.\nDiagnostics associés : Aucun",
             "Diagnostic principal : " + df_codes.libelle_dp + "\nDiagnostic relié : Aucun.\nDiagnostics associés :\n" + df_codes.libelle_das + "."),
    # With a DR
    np.where(df_codes.libelle_das == "",
             "Diagnostic principal : " + df_codes.libelle_dp + ".\nDiagnostic relié : " + df_codes.libelle_dr + ".\nDiagnostics associés : Aucun.",
             "Diagnostic principal : " + df_codes.libelle_dp + ".\nDiagnostic relié : " + df_codes.libelle_dr + ".\nDiagnostics associés :\n" + df_codes.libelle_das + "."),
)
```



```python
# Example: no DR and no DAS
test = df_codes.loc[(df_codes.libelle_dr.isna()) & (df_codes.libelle_das == ""), "diag"].reset_index()
print(test.loc[0, "diag"])
```

```
Diagnostic principal : Anévrisme du coeur (I253).
Diagnostic relié : Aucun.
Diagnostics associés : Aucun
```

```python
# Example: no DR, with DAS
test = df_codes.loc[(df_codes.libelle_dr.isna()) & (df_codes.libelle_das != ""), "diag"].reset_index()
print(test.loc[0, "diag"])
```

```
Diagnostic principal : Abcès de l'intestin (K630)
Diagnostic relié : Aucun.
Diagnostics associés :
Apnée du sommeil (G473),
Diabète sucré de type 2 insulinotraité, sans complication (E1190).
```

```python
df_codes = df_codes[["visit_occurence_id", "diag"]]
```

```python
def get_dataset_partitions(df, train_split=0.8, val_split=0.1, test_split=0.1):
    """Shuffle df and split it into train / validation / test dataframes."""
    assert np.isclose(train_split + val_split + test_split, 1.0)

    # Only allows for equal validation and test splits
    assert val_split == test_split

    # Fixed seed so that the split is the same between runs
    df_sample = df.sample(frac=1, random_state=12)
    train_i = int(train_split * len(df))
    val_i = int(val_split * len(df)) + train_i

    # Positional slicing of the shuffled rows (same cut points as np.split)
    train_ds = df_sample.iloc[:train_i]
    val_ds = df_sample.iloc[train_i:val_i]
    test_ds = df_sample.iloc[val_i:]

    return train_ds, val_ds, test_ds
```

```python
df_train, df_val, df_test = get_dataset_partitions(df_codes)
```
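A quick way to make sure the three splits are disjoint and cover every visit. `check_partition` is a hypothetical helper, shown here on a toy dataframe of 10 visits cut with the same positional rule (80/10/10).

```python
import pandas as pd


def check_partition(full: pd.DataFrame, parts: dict, key: str = "visit_occurence_id") -> dict:
    """Hypothetical helper: assert the splits are disjoint and cover `full`; return their sizes."""
    seen = set()
    for name, part in parts.items():
        ids = set(part[key])
        assert not (ids & seen), f"{name} overlaps another split"
        seen |= ids
    assert seen == set(full[key]), "some visits are missing from the splits"
    return {name: len(part) for name, part in parts.items()}


# Toy example: 10 visits, shuffled with a fixed seed, cut at 80% and 90%
toy = pd.DataFrame({"visit_occurence_id": [f"v{i}" for i in range(10)]})
shuffled = toy.sample(frac=1, random_state=12)
train_i = int(0.8 * len(toy))
val_i = train_i + int(0.1 * len(toy))
sizes = check_partition(toy, {
    "train": shuffled.iloc[:train_i],
    "val": shuffled.iloc[train_i:val_i],
    "test": shuffled.iloc[val_i:],
})
print(sizes)
```

On the 200 sample visits, `check_partition(df_codes, {"train": df_train, "val": df_val, "test": df_test})` should report 160 / 20 / 20.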

```python
import csv
import json
import os

PATH_DATA = "../sample_data/"
```


```python
# Write the text and instruct JSONL files for each split (train, val, test)
# from the single clinical-note file text.csv

limit = None   # optional cap on the total number of notes written (None = no cap)
MAX_VAL = 100  # cap on the number of validation notes

input_text_file = PATH_DATA + "text.csv"

SYSTEM_PROMPT = """Vous êtes un modèle de langage en française spécialisé dans le codage des diagnostics selon la classification internationale des maladies version 10 (CIM-10) pour les résumés standardisés de sortie du programme de médicalisation des systèmes d'information français (PMSI). A partir des comptes rendus d'hospitalisation vous donnerez les codes diagnostics CIM-10 que l'on peut retenir pour le séjours en distiguant diagnostic principal, diagnostic relié et diagnostics associés."""
USER_PREFIX = "Générez le codage CIM-10 du résumé standardisé de sortie PMSI à partir du compte rendu d'hospitalisation suivant : "
ASSISTANT_PREFIX = "Codes CIM 10 retenus pour le résumé standardisé de sortie PMSI : "
TEXT_SEPARATOR = "\nDiagnostics CIM-10 du résumé standardisé de sortie PMSI : "

# Lookups: visit -> diagnosis text, visit -> split name
diag_by_visit = dict(zip(df_codes.visit_occurence_id, df_codes.diag))
split_by_visit = {}
for name, df_split in (("train", df_train), ("val", df_val), ("test", df_test)):
    split_by_visit.update(dict.fromkeys(df_split.visit_occurence_id, name))

# Delete existing output files, otherwise re-running the cell appends duplicates
for name in ("train", "val", "test"):
    for fmt in ("text", "instruct"):
        try:
            os.remove(PATH_DATA + name + "_" + fmt + ".jsonl")
        except FileNotFoundError:
            pass

counts = {"train": 0, "val": 0, "test": 0}

# Open the note file
with open(input_text_file, 'r', encoding='utf-8') as infile:
    reader = csv.DictReader(infile)

    for row in reader:
        if limit is not None and sum(counts.values()) >= limit:
            break

        visit_occurence_id = row['visit_occurence_id']

        # Specific HIVE export = concatenation of multiple files (repeated header rows)
        if visit_occurence_id == 'visit_occurence_id':
            continue

        # Despite preprocessing, there are still cases where the ICD coding is missing
        diags = diag_by_visit.get(visit_occurence_id)
        if diags is None:
            continue

        name = split_by_visit[visit_occurence_id]
        if name == "val" and counts["val"] >= MAX_VAL:
            continue
        counts[name] += 1

        text = row['text']

        # Text format: note followed by its coding (next-token prediction)
        with open(PATH_DATA + name + "_text.jsonl", 'a', encoding="utf-8") as outfile:
            data = {"text": text + TEXT_SEPARATOR + diags}
            json.dump(data, outfile, separators=(',', ':'), ensure_ascii=False)
            outfile.write('\n')

        # Instruct format: system / user / assistant messages
        with open(PATH_DATA + name + "_instruct.jsonl", 'a', encoding="utf-8") as outfile:
            data = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": USER_PREFIX + text},
                    {"role": "assistant", "content": ASSISTANT_PREFIX + diags},
                ]
            }
            json.dump(data, outfile, separators=(',', ':'), ensure_ascii=False)
            outfile.write('\n')
```
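Before running `utils.validate_data`, a lighter standalone check can catch malformed lines early. A sketch, assuming only the two record shapes written above (`{"text": ...}` and `{"messages": [system, user, assistant]}`); `check_jsonl` is a hypothetical helper, demonstrated on a temporary one-line file.

```python
import json
import tempfile
from pathlib import Path


def check_jsonl(path) -> int:
    """Hypothetical check: each line parses as JSON in the text or instruct format.

    Returns the number of records; raises AssertionError on the first bad line.
    """
    n_records = 0
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            record = json.loads(line)
            if "messages" in record:
                roles = [message["role"] for message in record["messages"]]
                assert roles == ["system", "user", "assistant"], f"{path}:{lineno}: roles {roles}"
            else:
                assert isinstance(record.get("text"), str), f"{path}:{lineno}: no 'text' field"
            n_records += 1
    return n_records


# Demo on a temporary one-line instruct file
demo = Path(tempfile.mkdtemp()) / "demo_instruct.jsonl"
demo.write_text(json.dumps({"messages": [
    {"role": "system", "content": "..."},
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."},
]}) + "\n", encoding="utf-8")
n_demo = check_jsonl(demo)
print(n_demo)
```

In the notebook it could be run over the generated files with `for p in sorted(Path(PATH_DATA).glob("*.jsonl")): print(p.name, check_jsonl(p))`.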

# Train model

## Model download

```python
# !pip install huggingface_hub
```

```python
# Hugging Face login
from huggingface_hub import notebook_login

notebook_login()
```


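```python
from pathlib import Path

from huggingface_hub import snapshot_download

# Base directory for the models (adapt to your environment)
model_path = "/export/home/cse170020/models"

mistral_models_path = Path(f'{model_path}/mistral_models', '7B-v0.3')
mistral_models_path.mkdir(parents=True, exist_ok=True)

print(mistral_models_path)

# Restricted-access error? The repository may be gated: accept its access
# conditions on the Hub with the logged-in account first.
# snapshot_download(repo_id="mistralai/Mistral-7B-v0.3", allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"], local_dir=mistral_models_path)
```

```
/export/home/cse170020/models/mistral_models/7B-v0.3
```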


```python
# Alternatively, you can download the model from Mistral

## Does not work (or takes an extremely long time...)?
# !wget -P /export/home/cse170020/models/mistral_models/7B-v0.3 https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar
```

```python
# !DIR=/export/home/cse170020/models/mistral_models/7B-v0.3 && mkdir -p $DIR && tar -xf $DIR/mistral-7B-v0.3.tar -C $DIR
```

```python
!ls {model_path+'/mistral_models/7B-v0.3'}
```

```
consolidated.safetensors  params.json  tokenizer.model.v3
```


### Validate the dataset

```python
! python3 -m utils.validate_data --train_yaml example/instruct_icd_v1.yaml
```

### Train model

```python
! torchrun --nproc-per-node 1 -m train example/instruct_icd_v1.yaml
```