In [20]:
import os
import re
import csv
import pandas as pd
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
import transformers
import torch
import logging
import warnings
from transformers import logging as transformers_logging
from tqdm import tqdm

# Suppress warnings and set logging levels (as before)
warnings.filterwarnings('ignore')
transformers_logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=RuntimeWarning, module="torch.nn.parallel")


In [27]:
def extract_text_from_pdf(pdf_path):
    doc = DocumentFile.from_pdf(pdf_path)
    predictor = ocr_predictor(pretrained=True)
    result = predictor(doc)
    json_output = result.export()
    
    extracted_text = ""
    for page in json_output['pages']:
        for block in page['blocks']:
            for line in block['lines']:
                for word in line['words']:
                    extracted_text += word['value'] + ' '
                extracted_text += '\n'
    
    return extracted_text.strip()

def extract_patient_info_from_filename(filename):
    match = re.match(r'(\d{4}-\d{2}-\d{2})\s+(\d+)-[RC]\.pdf', filename)
    if match:
        date = match.group(1)
        nhs_number = match.group(2)
        return date, nhs_number
    return None, None

def initialize_pipeline(gpu_id=2):
    model_id = "HPAI-BSC/Llama3-Aloe-8B-Alpha"
    try:
        if torch.cuda.is_available():
            device = f"cuda:{gpu_id}"
        else:
            print("CUDA is not available. Using CPU.")
            device = "cpu"
        
        pipeline = transformers.pipeline(
            "text-generation",
            model=model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device=device,
        )
        return pipeline
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

def extract_doctor_specialty(pipeline, texts):
    try:
        messages = [
            {"role": "system", "content": "Extract the doctor's specialty from the following clinic letter."}
        ] + [{"role": "user", "content": text} for text in texts]
        
        prompts = pipeline.tokenizer.apply_chat_template(
            messages, 
            tokenize=False, 
            add_generation_prompt=True
        )
        terminators = [
            pipeline.tokenizer.eos_token_id,
            pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
        outputs = pipeline(
            prompts,
            max_new_tokens=50,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            batch_size=len(texts)  # Process all texts in one batch
        )
        return [output["generated_text"].split("Doctor's specialty:")[-1].strip() for output in outputs]
    except Exception as e:
        return [f"An error occurred: {str(e)}"] * len(texts)


def process_patient_folder(folder_path, pipeline):
    referral_letter = None
    clinic_letters = []
    
    for file in os.listdir(folder_path):
        if file.endswith('-R.pdf'):
            referral_letter = file
        elif file.endswith('-C.pdf'):
            clinic_letters.append(file)
    
    if not referral_letter:
        print(f"Warning: No referral letter found in {folder_path}")
        return None
    
    date, nhs_number = extract_patient_info_from_filename(referral_letter)
    if date is None or nhs_number is None:
        print(f"Warning: Could not extract info from filename: {referral_letter}")
        return None
    
    referral_text = extract_text_from_pdf(os.path.join(folder_path, referral_letter))
    
    clinic_texts = [extract_text_from_pdf(os.path.join(folder_path, letter)) for letter in clinic_letters]
    specialties = extract_doctor_specialty(pipeline, clinic_texts)
    
    return {
        'NHS_Number': nhs_number,
        'Referral_Date': date,
        'Referral_Letter_Text': referral_text,
        'Doctor_Specialties': ', '.join(specialties),
        'Num_Clinic_Letters': len(clinic_letters)
    }

def main():
    gpu_id = 2  # This will use the third GPU
    pipeline = initialize_pipeline(gpu_id)
    
    base_folder = '/home/swleocresearch/Desktop/triage-ai/datasets/cleaned_dataset_ali'
    output_data = []
    
    # Get total number of patient folders
    total_folders = sum(1 for entry in os.listdir(base_folder) if os.path.isdir(os.path.join(base_folder, entry)))
    
    # Main progress bar for patient folders
    with tqdm(total=total_folders, desc="Processing patients", unit="patient") as pbar:
        for patient_folder in os.listdir(base_folder):
            folder_path = os.path.join(base_folder, patient_folder)
            if os.path.isdir(folder_path):
                result = process_patient_folder(folder_path, pipeline)
                if result:
                    output_data.append(result)
                pbar.update(1)
                pbar.set_postfix({"Letters Processed": sum(item['Num_Clinic_Letters'] for item in output_data)})

                break
    
    # Write output to CSV
    with open('output.csv', 'w', newline='', encoding='utf-8') as csvfile:
        fieldnames = ['NHS_Number', 'Referral_Date', 'Referral_Letter_Text', 'Doctor_Specialties', 'Num_Clinic_Letters']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        
        writer.writeheader()
        for row in output_data:
            writer.writerow(row)
    
    print("output.csv has been created successfully.")
    print(f"Total patients processed: {len(output_data)}")
    print(f"Total clinic letters processed: {sum(item['Num_Clinic_Letters'] for item in output_data)}")

In [28]:
if __name__ == "__main__":
    main()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Processing patients:   0%|▏                                 | 1/253 [00:44<3:08:03, 44.78s/patient, Letters Processed=8]

output.csv has been created successfully.
Total patients processed: 1
Total clinic letters processed: 8



