In [8]:
from folder import StandardFolder
from polars_utils import *

import polars as pl
from pathlib import Path

prev_hn = pl.read_csv('D:/Prut/Warehouses/output/Dec23/n/Stroke/stroke_n_updatedSep2023_28112023.csv').to_series().to_list()
print(f'{len(prev_hn) = }')

len(prev_hn) = 27735


In [132]:
class StrokeIdentify(StandardFolder):
    def __init__(self, folder: str, streaming: bool = True) -> None:
        super().__init__(folder)
        self.streaming = streaming
        self.export_folder = Path('../output/Dec23/wh/complete')
        self.select_dx = [f'I6{i}' for i in range(10)] + ['G45']
        self.select_dx_re = '^' + '|^'.join(self.select_dx)
        # self.prev_hn = pl.read_csv('D:/Prut/Warehouses/output/Dec23/n/Stroke/stroke_n_updatedSep2023_28112023.csv').to_series().to_list()
        self.ran_all = False

    def get_dx(self, select: list = None):
        folder_path = self.dx
        to_concat = []
        for path in folder_path.iterdir():
            file = (
                scan_file(path)
                .select(pl.col(['ENC_HN', 'D001KEY', 'D035KEY']))
                .pipe(parse_dates, 'D001KEY')
            )
            if select is not None:
                file = file.filter(pl.col('D035KEY').str.contains(self.select_dx_re))
            file = file.group_by(pl.col(['ENC_HN', 'D001KEY'])).agg(pl.col('D035KEY')).with_columns(pl.col('D035KEY').list.unique().list.sort().list.join(', '))
            to_concat.append(file.collect(streaming=self.streaming))
        self.dx_df = pl.concat(to_concat).unique()


    def get_demo(self):
        folder_path = self.demo
        cols = ['ENC_HN', 'D020AT3', 'H2L1KEY', 'H6L1KEY', 'H6L1DES']
        new_col_names = ['ENC_HN', 'DOB', 'Sex', 'Province_ID', 'Province_Thai']
        to_concat = []
        for path in folder_path.iterdir():
            file = scan_file(path)
            if set(cols).issubset(set(file.columns)):
                file = file.select(cols).collect(streaming=self.streaming).pipe(parse_dates, 'D020AT3') # New bug: only works in dataframes, so must collect first
                to_concat.append(file)
        self.demo_df = pl.concat(to_concat).unique()
        self.demo_df = self.demo_df.rename(dict(zip(cols, new_col_names)))

    def run_all(self):
        self.get_dx(select=self.select_dx)
        print('dx')
        self.get_demo()
        print('demo')
        
        self.ran_all = True

    def merge(self):
        if not self.ran_all:
            raise Exception('Please run all first.')
        
        self.merged_df = (
            self.dx_df
            .join(self.demo_df, on=['ENC_HN'], how='left')
            .unique()
            
        )


        # Remove previous
            # .pipe(self.remove_previous_hn)
            # .pipe(print_n)
            # Clip dates
            # .pipe(clip_dates, date_col='D001KEY', start_month=7, start_year=2023, end_month=12, end_year=2023)
            # .pipe(print_n)



In [133]:
s = StrokeIdentify(folder='D:/Datalake/Data/20231231_fu_nc')
s.run_all()
s.merge()

readme not included.
dx
demo


Due to rules set by others on how to draw the flowchart, the following functions cannot be incorporated above.

In [158]:
def remove_previous_hn(lf: pl.LazyFrame, prev_hn=prev_hn) -> pl.LazyFrame:
    return lf.filter(~pl.col('ENC_HN').is_in(prev_hn))

def print_n(df: pl.DataFrame) -> pl.DataFrame:
        print(df['ENC_HN'].n_unique())
        return df

def remove_previous_stroke(df: pl.DataFrame) -> pl.DataFrame:
    # Remove patients who are diagnosed with stroke before entering the cohort
    return df.sort('D001KEY').group_by('ENC_HN', maintain_order=True).first().filter(~pl.col('D035KEY').map_elements(lambda x: 'I69' in x))


In [176]:
# Flow box 1
s.merged_df.pipe(clip_dates, date_col='D001KEY', start_month=7, start_year=2023, end_month=12, end_year=2023).pipe(remove_previous_hn).pipe(print_n)
# Flow box 2
s1 = s.merged_df.pipe(remove_previous_stroke).pipe(clip_dates, date_col='D001KEY', start_month=7, start_year=2023, end_month=12, end_year=2023)
s1.pipe(remove_previous_hn).pipe(print_n)
# Flow box 3
s2 = s1.pipe(remove_previous_hn).filter((pl.col('D001KEY') - pl.col('DOB')) >= pl.duration(days=365*18))
s2.pipe(print_n)
# Save output
prev = pl.read_csv('D:/Prut/Warehouses/output/Dec23/n/Stroke/stroke_n_updatedSep2023_28112023.csv').pipe(parse_dates, 'Date')
new = s2[['ENC_HN', 'D001KEY', 'D035KEY']].rename({'D001KEY': 'Date', 'D035KEY': 'ICD10'})
output = pl.concat([new, prev])
output.pipe(print_n)

output.write_csv('D:/Prut/Warehouses/output/Dec23/n/Stroke/stroke_n_updated_05032024.csv')

2788
1095
1075
28810


Expr.map_elements is significantly slower than the native expressions API.
Only use if you absolutely CANNOT implement your logic otherwise.
Replace this expression...
  - pl.col("D035KEY").map_elements(lambda x: ...)
with this one instead:
  + 'I69'.is_in(pl.col("D035KEY"))

  return df.sort('D001KEY').group_by('ENC_HN', maintain_order=True).first().filter(~pl.col('D035KEY').map_elements(lambda x: 'I69' in x))
