In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# # Load the pre-trained CORe model
# tokenizer = AutoTokenizer.from_pretrained("DATEXIS/CORe-clinical-diagnosis-prediction")
# model = AutoModelForSequenceClassification.from_pretrained("DATEXIS/CORe-clinical-diagnosis-prediction")
# model.to(device)

tokenizer = AutoTokenizer.from_pretrained("ashishkgpian/biobert_icd9_classifier_ehr")
model = AutoModelForSequenceClassification.from_pretrained("ashishkgpian/biobert_icd9_classifier_ehr", ignore_mismatched_sizes=True)
model.to(device)


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [2]:
input = "CHIEF COMPLAINT: Headaches\n\nPRESENT ILLNESS: 58yo man w/ hx of hypertension, AFib on coumadin presented to ED with the worst headache of his life."

In [3]:
tokenized_input = tokenizer(input, return_tensors="pt")
tokenized_input = {k: v.to(device) for k, v in tokenized_input.items()}

In [4]:
output = model(**tokenized_input)
output

SequenceClassifierOutput(loss=None, logits=tensor([[ -8.1184,  -3.9650,  -8.6491,  ..., -11.6154, -10.9156, -16.0345]],
       device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [5]:
import torch

predictions = torch.sigmoid(output.logits)
predicted_labels = [model.config.id2label[_id] for _id in (predictions > 0.2).nonzero()[:, 1].tolist()]
predicted_labels

['401', 'V586', '427', '431']

In [6]:
classes = str('403 486 582 585 425 276 710 724 458 287 285 275 583 558 327 228 338 789 790 V451 531 410 414 725 191 331 530 411 482 272 305 194 197 255 424 584 998 682 511 599 428 349 401 V100 V453 V586 041 251 E932 V300 V053 V290 571 070 250 570 572 286 518 038 280 263 995 303 244 112 881 903 955 E956 745 762 441 496 447 440 997 274 427 V104 V101 V120 V090 569 560 491 V458 433 436 493 996 416 V310 765 769 774 770 747 776 772 362 198 V103 746 766 V293 853 780 E888 730 357 430 293 443 V158 396 365 135 311 E935 721 214 437 242 600 189 304 711 800 E814 873 781 378 951 767 431 294 042 V141 V071 764 775 969 295 E950 266 779 355 553 965 E850 E853 426 804 E916 202 V502 398 707 348 787 564 V428 238 300 788 332 V107 V433 E879 861 423 E966 200 555 771 270 335 723 079 851 807 864 865 860 413 782 V108 507 512 752 162 783 778 333 785 136 799 E931 157 574 568 E878 722 719 V125 296 478 V170 805 596 E880 822 733 578 459 438 008 V098 185 967 225 V457 389 412 593 345 201 515 E933 278 492 715 415 V105 535 608 E870 V058 513 709 E821 V173 824 911 913 E812 576 203 281 580 V450 216 V340 579 693 351 088 714 E849 307 421 786 E942 959 E928 588 364 V642 V025 252 283 784 611 622 289 446 729 V498 V456 795 E854 V667 155 V130 882 852 957 E815 466 792 434 342 153 E934 481 910 456 453 867 273 532 806 V422 V541 556 394 444 924 E960 514 763 218 359 340 999 451 324 E939 537 737 455 E884 V427 591 592 577 557 575 356 368 552 500 750 253 292 E937 211 288 773 314 V652 432 379 435 E930 199 V641 494 966 758 E855 741 918 V436 078 562 820 801 839 E881 V584 731 E885 812 156 567 696 501 712 V707 215 754 753 508 876 720 V442 871 958 802 847 397 196 346 E968 510 404 360 376 370 V026 904 928 821 823 150 573 850 V497 E938 V533 V556 728 870 V874 V153 V644 V600 521 301 164 054 344 464 442 V150 282 V08 891 808 866 902 117 484 760 V048 691 519 528 320 369 685 V625 794 793 318 V441 761 936 E915 457 395 053 V113 V632 386 623 290 204 271 E819 811 813 884 E813 751 366 297 V440 473 E910 V420 057 536 152 970 485 235 372 E882 127 160 170 V880 595 909 V443 490 343 319 130 698 E823 246 854 868 872 982 151 V853 980 E980 291 517 268 487 E866 796 V452 036 354 648 701 V063 V038 227 614 533 736 942 E924 240 921 V454 977 759 768 923 E816 681 138 358 950 922 205 990 009 619 417 279 257 E860 755 991 E957 241 810 920 V461 V127 261 429 550 874 756 935 831 718 962 E858 803 480 674 277 880 879 377 529 047 083 835 462 336 E947 V160 420 317 454 E883 840 V550 960 586 933 597 350 E911 742 V614 298 V551 620 716 V462 V180 706 565 452 825 322 154 040 110 605 607 461 704 713 945 052 948 323 325 934 516 039 975 971 994 666 V111 907 E929 566 603 405 049 237 V161 V553 262 743 422 337 625 757 527 309 815 V163 402 869 E912 188 590 V852 V446 E852 886 E919 183 862 875 877 890 E944 E936 V444 598 V552 226 E818 617 E958 V123 748 968 V298 465 972 E826 905 E969 744 E829 V301 388 V146 V151 887 375 334 E848 E918 284 E876 260 987 E890 834 522 692 V588 310 863 E834 192 035 V174 171 738 220 477 212 172 V548 726 526 V099 777 749 E922 952 V320 901 542 449 V011 963 E822 524 V052 V539 144 445 321 380 604 383 587 137 845 695 V496 180 618 V102 540 525 916 174 V628 892 816 V171 520 708 176 791 V854 E906 V714 V554 V435 883 927 V434 007 581 V202 140 642 644 654 V270 V252 193 V838 V555 139 V195 V068 601 826 694 626 956 245 919 299 727 684 647 E941 V850 665 391 308 633 639 V230 V061 223 269 V183 046 534 361 673 643 986 005 034 382 239 232 V169 E901 908 634 836 616 E917 734 V698 133 E887 V445 V155 E949 142 E987 236 470 463 E940 229 448 702 182 E825 V851 814 V881 259 906 161 E891 830 E953 195 093 472 914 E988 930 543 686 900 075 705 939 381 V311 V168 018 004 917 483 656 641 217 V291 V164 E943 134 635 659 E920 506 E869 111 096 094 123 158 141 243 690 097 632 989 964 027 V596 373 V017 254 932 187 353 669 V504 602 843 912 374 983 E864 031 210 114 646 077 V018 670 615 V638 V135 938 V580 680 878 E965 471 652 663 658 V272 213 032 148 V643 V148 V062 E989 E927 131 233 V040 V066 125 V503 V581 V292 V192 700 703 209 V029 208 697 E871 184 015 146 V140 V154 992 249 149 V142 844 175 V542 363 V152 V106 V688 V265 012 885 E955 V530 385 V124 V741 390 474 627 817 230 E817 V198 E862 258 V463 735 V024 V640 976 E861 V765 V023 V626 E828 V188 341 V560 798 V448 893 495 084 523 V653 953 V549 V095 V182 621 475 V425 058 306 V165 551 E831 V136 V109 256 219 221 961 985 828 671 E820 897 V840 926 V421 048 594 896 082 E986 541 145 267 683 V097 732 265 011 E801 V185 664 V620 E840 V166 V468 629 115 V587 E908 120 V708 098 V469 V694 E824 E970 121 838 832 460 013 V239 944 V189 946 118 326 E945 645 352 159 E967 V618 147 V908 941 312 624 V186 V145 661 010 E865 091 E886 649 E905 E962 V612 E959 502 V438 V222 163 947 V162 E946 V716 315 367 V540 846 717 V561 V175 842 V138 V703 V583 841 672 062 488 347 339 E841 086 V400 E985 655 974 V289 V604 V074 V728 371 190 V126 090 143 943 V611 V331 085 V172 E835 668 740 V167 V558 E851 E811 V430 837 V072 V431 302 E923 V110 E900 V562 E963 E964 V118 V624 E800 988 833 023 V020 021 003 V660 E806 313 E954 V860 660 V449 231 V602 186 E863 E874 V721 V181 651 033 V654 E804 330 610 384 E838 E001 973 819 014 132 E899 925 207 V861 E002 E030 E000 894 E873 E999 E976 E003 V016 E805 045 V610 V078 V510 E029 848 E006 V403 122 V536 E013 E019 173 E913 677 E008 V568 V143 V091 V872 066 V601 116 V882 V065 538 V655 316 E007 E016 E921 V902 206 V254 099 V489 V870 E977 628 V250 E982 V486 539 V073 937 V812 030 V271 589 V672 V671 E926 E925 E857 V537 954 E827 657 V910 V789 V037 E975 V045 V848 393 V426 179 387 V903 E856 V901 915').split(' ')

print(classes)

['403', '486', '582', '585', '425', '276', '710', '724', '458', '287', '285', '275', '583', '558', '327', '228', '338', '789', '790', 'V451', '531', '410', '414', '725', '191', '331', '530', '411', '482', '272', '305', '194', '197', '255', '424', '584', '998', '682', '511', '599', '428', '349', '401', 'V100', 'V453', 'V586', '041', '251', 'E932', 'V300', 'V053', 'V290', '571', '070', '250', '570', '572', '286', '518', '038', '280', '263', '995', '303', '244', '112', '881', '903', '955', 'E956', '745', '762', '441', '496', '447', '440', '997', '274', '427', 'V104', 'V101', 'V120', 'V090', '569', '560', '491', 'V458', '433', '436', '493', '996', '416', 'V310', '765', '769', '774', '770', '747', '776', '772', '362', '198', 'V103', '746', '766', 'V293', '853', '780', 'E888', '730', '357', '430', '293', '443', 'V158', '396', '365', '135', '311', 'E935', '721', '214', '437', '242', '600', '189', '304', '711', '800', 'E814', '873', '781', '378', '951', '767', '431', '294', '042', 'V141', 'V07

In [7]:
import pandas as pd

icd_df = pd.read_csv('D_ICD_DIAGNOSES (1).csv')
def preprocessing(test_df) : 
    test_df.loc[
        test_df['ICD9_CODE'].str.startswith("V"), 'ICD9_CODE'] = test_df.ICD9_CODE.apply(
        lambda x: x[:4])
    test_df.loc[
        test_df['ICD9_CODE'].str.startswith("E"), 'ICD9_CODE'] = test_df.ICD9_CODE.apply(
        lambda x: x[:4])
    test_df.loc[(~test_df.ICD9_CODE.str.startswith("E")) & (
        ~test_df.ICD9_CODE.str.startswith("V")), 'ICD9_CODE'] = test_df.ICD9_CODE.apply(
        lambda x: x[:3])
    return test_df

icd_df = preprocessing(icd_df)
icd_df.head()

Unnamed: 0,ROW_ID,ICD9_CODE,SHORT_TITLE,LONG_TITLE
0,174,11,TB pneumonia-oth test,"Tuberculous pneumonia [any form], tubercle bac..."
1,175,11,TB pneumothorax-unspec,"Tuberculous pneumothorax, unspecified"
2,176,11,TB pneumothorax-no exam,"Tuberculous pneumothorax, bacteriological or h..."
3,177,11,TB pneumothorx-exam unkn,"Tuberculous pneumothorax, bacteriological or h..."
4,178,11,TB pneumothorax-micro dx,"Tuberculous pneumothorax, tubercle bacilli fou..."


In [8]:
icd_df_titles = list(icd_df[icd_df.ICD9_CODE.isin(classes)].SHORT_TITLE)


In [9]:
def get_predictions(input_text, threshold = 0.8) : 
    """give the EHR/Symptoms as input, and 
    get the disease codes (matching in the classes under consideration) 
    as the output"""

    tokenized_input =  tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,  # BERT's maximum sequence length
        padding='max_length'
    )
    tokenized_input = {k: v.to(device) for k, v in tokenized_input.items()}
    output = model(**tokenized_input)
    predictions = torch.sigmoid(output.logits)
    predicted_labels = [model.config.id2label[_id] for _id in (predictions > threshold).nonzero()[:, 1].tolist()]
    classes = str('403 486 582 585 425 276 710 724 458 287 285 275 583 558 327 228 338 789 790 V451 531 410 414 725 191 331 530 411 482 272 305 194 197 255 424 584 998 682 511 599 428 349 401 V100 V453 V586 041 251 E932 V300 V053 V290 571 070 250 570 572 286 518 038 280 263 995 303 244 112 881 903 955 E956 745 762 441 496 447 440 997 274 427 V104 V101 V120 V090 569 560 491 V458 433 436 493 996 416 V310 765 769 774 770 747 776 772 362 198 V103 746 766 V293 853 780 E888 730 357 430 293 443 V158 396 365 135 311 E935 721 214 437 242 600 189 304 711 800 E814 873 781 378 951 767 431 294 042 V141 V071 764 775 969 295 E950 266 779 355 553 965 E850 E853 426 804 E916 202 V502 398 707 348 787 564 V428 238 300 788 332 V107 V433 E879 861 423 E966 200 555 771 270 335 723 079 851 807 864 865 860 413 782 V108 507 512 752 162 783 778 333 785 136 799 E931 157 574 568 E878 722 719 V125 296 478 V170 805 596 E880 822 733 578 459 438 008 V098 185 967 225 V457 389 412 593 345 201 515 E933 278 492 715 415 V105 535 608 E870 V058 513 709 E821 V173 824 911 913 E812 576 203 281 580 V450 216 V340 579 693 351 088 714 E849 307 421 786 E942 959 E928 588 364 V642 V025 252 283 784 611 622 289 446 729 V498 V456 795 E854 V667 155 V130 882 852 957 E815 466 792 434 342 153 E934 481 910 456 453 867 273 532 806 V422 V541 556 394 444 924 E960 514 763 218 359 340 999 451 324 E939 537 737 455 E884 V427 591 592 577 557 575 356 368 552 500 750 253 292 E937 211 288 773 314 V652 432 379 435 E930 199 V641 494 966 758 E855 741 918 V436 078 562 820 801 839 E881 V584 731 E885 812 156 567 696 501 712 V707 215 754 753 508 876 720 V442 871 958 802 847 397 196 346 E968 510 404 360 376 370 V026 904 928 821 823 150 573 850 V497 E938 V533 V556 728 870 V874 V153 V644 V600 521 301 164 054 344 464 442 V150 282 V08 891 808 866 902 117 484 760 V048 691 519 528 320 369 685 V625 794 793 318 V441 761 936 E915 457 395 053 V113 V632 386 623 290 204 271 E819 811 813 884 E813 751 366 297 V440 473 E910 V420 057 536 152 970 485 235 372 E882 127 160 170 V880 595 909 V443 490 343 319 130 698 E823 246 854 868 872 982 151 V853 980 E980 291 517 268 487 E866 796 V452 036 354 648 701 V063 V038 227 614 533 736 942 E924 240 921 V454 977 759 768 923 E816 681 138 358 950 922 205 990 009 619 417 279 257 E860 755 991 E957 241 810 920 V461 V127 261 429 550 874 756 935 831 718 962 E858 803 480 674 277 880 879 377 529 047 083 835 462 336 E947 V160 420 317 454 E883 840 V550 960 586 933 597 350 E911 742 V614 298 V551 620 716 V462 V180 706 565 452 825 322 154 040 110 605 607 461 704 713 945 052 948 323 325 934 516 039 975 971 994 666 V111 907 E929 566 603 405 049 237 V161 V553 262 743 422 337 625 757 527 309 815 V163 402 869 E912 188 590 V852 V446 E852 886 E919 183 862 875 877 890 E944 E936 V444 598 V552 226 E818 617 E958 V123 748 968 V298 465 972 E826 905 E969 744 E829 V301 388 V146 V151 887 375 334 E848 E918 284 E876 260 987 E890 834 522 692 V588 310 863 E834 192 035 V174 171 738 220 477 212 172 V548 726 526 V099 777 749 E922 952 V320 901 542 449 V011 963 E822 524 V052 V539 144 445 321 380 604 383 587 137 845 695 V496 180 618 V102 540 525 916 174 V628 892 816 V171 520 708 176 791 V854 E906 V714 V554 V435 883 927 V434 007 581 V202 140 642 644 654 V270 V252 193 V838 V555 139 V195 V068 601 826 694 626 956 245 919 299 727 684 647 E941 V850 665 391 308 633 639 V230 V061 223 269 V183 046 534 361 673 643 986 005 034 382 239 232 V169 E901 908 634 836 616 E917 734 V698 133 E887 V445 V155 E949 142 E987 236 470 463 E940 229 448 702 182 E825 V851 814 V881 259 906 161 E891 830 E953 195 093 472 914 E988 930 543 686 900 075 705 939 381 V311 V168 018 004 917 483 656 641 217 V291 V164 E943 134 635 659 E920 506 E869 111 096 094 123 158 141 243 690 097 632 989 964 027 V596 373 V017 254 932 187 353 669 V504 602 843 912 374 983 E864 031 210 114 646 077 V018 670 615 V638 V135 938 V580 680 878 E965 471 652 663 658 V272 213 032 148 V643 V148 V062 E989 E927 131 233 V040 V066 125 V503 V581 V292 V192 700 703 209 V029 208 697 E871 184 015 146 V140 V154 992 249 149 V142 844 175 V542 363 V152 V106 V688 V265 012 885 E955 V530 385 V124 V741 390 474 627 817 230 E817 V198 E862 258 V463 735 V024 V640 976 E861 V765 V023 V626 E828 V188 341 V560 798 V448 893 495 084 523 V653 953 V549 V095 V182 621 475 V425 058 306 V165 551 E831 V136 V109 256 219 221 961 985 828 671 E820 897 V840 926 V421 048 594 896 082 E986 541 145 267 683 V097 732 265 011 E801 V185 664 V620 E840 V166 V468 629 115 V587 E908 120 V708 098 V469 V694 E824 E970 121 838 832 460 013 V239 944 V189 946 118 326 E945 645 352 159 E967 V618 147 V908 941 312 624 V186 V145 661 010 E865 091 E886 649 E905 E962 V612 E959 502 V438 V222 163 947 V162 E946 V716 315 367 V540 846 717 V561 V175 842 V138 V703 V583 841 672 062 488 347 339 E841 086 V400 E985 655 974 V289 V604 V074 V728 371 190 V126 090 143 943 V611 V331 085 V172 E835 668 740 V167 V558 E851 E811 V430 837 V072 V431 302 E923 V110 E900 V562 E963 E964 V118 V624 E800 988 833 023 V020 021 003 V660 E806 313 E954 V860 660 V449 231 V602 186 E863 E874 V721 V181 651 033 V654 E804 330 610 384 E838 E001 973 819 014 132 E899 925 207 V861 E002 E030 E000 894 E873 E999 E976 E003 V016 E805 045 V610 V078 V510 E029 848 E006 V403 122 V536 E013 E019 173 E913 677 E008 V568 V143 V091 V872 066 V601 116 V882 V065 538 V655 316 E007 E016 E921 V902 206 V254 099 V489 V870 E977 628 V250 E982 V486 539 V073 937 V812 030 V271 589 V672 V671 E926 E925 E857 V537 954 E827 657 V910 V789 V037 E975 V045 V848 393 V426 179 387 V903 E856 V901 915').split(' ')

    predicted_labels = list(set(classes).intersection(set(predicted_labels)))

    return predicted_labels

    

    

In [10]:
import os
import json
import pandas as pd

directory_path = 'samples_latest'
data = []

for filename in os.listdir(directory_path):
    if filename.endswith('.json'):
        file_path = os.path.join(directory_path, filename)
        
        with open(file_path, 'r') as file:
            file_data = json.load(file)
            data.append(file_data)

In [11]:
symptoms_df = pd.DataFrame(data)
raw_test_df = pd.read_csv('new_split/test_split.csv')
# print(raw_test_df.columns)
unique_ids = list(raw_test_df.id)
raw_test_df = raw_test_df.drop('Unnamed: 0',axis =1)
raw_test_df = pd.merge(symptoms_df, raw_test_df, how='inner', on='id')
raw_test_df.head(3)

Unnamed: 0,Symptoms,Diseases,id,icd_9_desc,text,long_texts,short_texts,discharge_summary,short_codes
0,"[fatigued, not herself, febrile, hypotensive, ...","[CVA, peptic ulcer disease, reflux esophagitis...",196005,"Subendocardial infarction, initial episode of ...","CHIEF COMPLAINT: ""off"", status post fall at [*...","Subendocardial infarction, initial episode of ...","Subendo infarct, initial,Atrial fibrillation,U...",Admission Date: [**2117-5-11**] ...,"410,427,599,428,486,799,276,530,285,780,V125"
1,"[melena, vague abdominal pain, bloody diarrhea...","[polycystic kidney disease, Caroli's disease, ...",194492,"Blood in stool,Portal hypertension,Other anoma...",CHIEF COMPLAINT: melena\n\nPRESENT ILLNESS: Ms...,"Blood in stool,Portal hypertension,Other anoma...","Blood in stool,Portal hypertension,Biliary & l...",Admission Date: [**2129-9-16**] ...,"578,572,751,V420,588,753,456,535,274,996,585,2..."
2,"[Nausea, Vomiting, Hypertension]","[Gastroparesis, Diabetes Mellitus Type 1, Hype...",103789,"Hypertensive chronic kidney disease, malignant...","CHIEF COMPLAINT: Nausea, vomiting and hyperten...","Hypertensive chronic kidney disease, malignant...","Mal hy kid w cr kid I-IV,DMI ketoacd uncontrol...",Admission Date: [**2155-9-3**] D...,"403,250,276,V586,536,272,724"


In [12]:
raw_test_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8733 entries, 0 to 8732
Data columns (total 9 columns):
 #   Column             Non-Null Count  Dtype 
---  ------             --------------  ----- 
 0   Symptoms           8733 non-null   object
 1   Diseases           8733 non-null   object
 2   id                 8733 non-null   int64 
 3   icd_9_desc         8733 non-null   object
 4   text               8733 non-null   object
 5   long_texts         8733 non-null   object
 6   short_texts        8733 non-null   object
 7   discharge_summary  8733 non-null   object
 8   short_codes        8733 non-null   object
dtypes: int64(1), object(8)
memory usage: 614.2+ KB


In [13]:
def create_binary_matrix(labels, classes):
    """
    Creates a binary matrix from a list of labels and a list of classes.
    
    Parameters:
    labels (list): A list of lists, where each inner list contains the labels for a single data point.
    classes (list): A list of class names.
    
    Returns:
    numpy.ndarray: A binary matrix where each row represents a data point and each column represents a class.
    """
    binary_matrix = np.zeros((len(labels), len(classes)), dtype=int)
    for i, label_list in enumerate(labels):
        for label in label_list:
            if label in classes:
                idx = classes.index(label)
                binary_matrix[i, idx] = 1       
    return binary_matrix

In [14]:
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
import numpy as np

def roc_auc_f1_precision_recall(probs, labels, multilabel=False, average='macro', multi_class='ovo'):
    """
    Calculates ROC AUC, F1, precision, and recall scores (macro and micro) for multilabel or multiclass classification.
    
    Parameters:
    probs (numpy.ndarray): A 2D array of probabilities, where each row represents a data point and each column represents a class.
    labels (numpy.ndarray or list): The true labels for the data points.
    multilabel (bool, optional): Whether the problem is a multilabel classification task. Defaults to False.
    average (str, optional): The method to use for averaging scores across classes. Defaults to 'macro'.
    multi_class (str, optional): The strategy to use for multi-class ROC AUC calculation. Defaults to 'ovo' (one-vs-one).
    
    Returns:
    dict: A dictionary with ROC AUC, F1, precision, and recall scores (macro and micro).
    """
    if isinstance(labels, list):
        labels = np.array(labels, dtype=int)
    else:
        labels = labels.astype(int)

    # Filter relevant columns if multilabel is True
    y_score = probs
    if multilabel:
        # Identify classes with at least one positive label
        present_classes = np.any(labels == 1, axis=0)
        labels = labels[:, present_classes]
        y_score = np.array(probs)[:, present_classes]

    # Calculate ROC AUC score
    roc_auc = roc_auc_score(y_true=labels, y_score=y_score, average=average, multi_class=multi_class)

    # Binarize the predictions for F1, precision, and recall calculation
    preds = (y_score >= 0.5).astype(int)
    
    # Calculate F1 scores
    f1_macro = f1_score(labels, preds, average='macro')
    f1_micro = f1_score(labels, preds, average='micro')
    
    # Calculate precision scores
    precision_macro = precision_score(labels, preds, average='macro', zero_division=0)
    precision_micro = precision_score(labels, preds, average='micro', zero_division=0)
    
    # Calculate recall scores
    recall_macro = recall_score(labels, preds, average='macro', zero_division=0)
    recall_micro = recall_score(labels, preds, average='micro', zero_division=0)
    
    return {
        'roc_auc': roc_auc,
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'precision_macro': precision_macro,
        'precision_micro': precision_micro,
        'recall_macro': recall_macro,
        'recall_micro': recall_micro
    }


In [15]:
str(raw_test_df.iloc[0].Symptoms)[2:-2]

"fatigued', 'not herself', 'febrile', 'hypotensive', 'elevated WBC', 'positive cardiac enzymes', 'atrial fibrillation', 'shortness of breath', 'congestive heart failure', 'mental status changes"

In [16]:
print(get_predictions(str(raw_test_df.iloc[0].Symptoms)[2:-2]), raw_test_df.iloc[0].short_codes.split(','))

['428', '427'] ['410', '427', '599', '428', '486', '799', '276', '530', '285', '780', 'V125']


In [22]:
from tqdm import tqdm
true_labels = []
predicted_labels = []
# Wrap raw_test_df.iterrows() with tqdm
for i, j in tqdm(raw_test_df.iterrows(), total=len(raw_test_df)):
    pred_label = get_predictions(str(j.text) + ' ' + str(j.Symptoms), 0.0643745388060965)
    true_label = j.short_codes.split(',')
    true_labels.append(true_label)
    predicted_labels.append(pred_label)



100%|██████████| 8733/8733 [07:23<00:00, 19.67it/s]


In [23]:
len(true_labels), len(predicted_labels)

(8733, 8733)

In [24]:
print(predicted_labels[3])

['437']


In [25]:
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score

# Define batch size
batch_size = 200 # Adjust based on memory limits

metrics_list = []

# Process batches directly from the labels, converting each batch to binary format on-the-fly
for i in tqdm(range(0, len(true_labels), batch_size), desc="Processing Batches"):
    batch_true_labels = true_labels[i:i + batch_size]
    batch_pred_labels = predicted_labels[i:i + batch_size]


    # Create binary matrices for each batch
    batch_true_matrix = create_binary_matrix(batch_true_labels, classes)
    batch_pred_matrix = create_binary_matrix(batch_pred_labels, classes)
    
    # Calculate metrics for the current batch
    batch_metrics = roc_auc_f1_precision_recall(batch_pred_matrix, batch_true_matrix, multilabel=True)
    metrics_list.append(batch_metrics)


# Average metrics across all batches
final_metrics = {k: sum(d[k] for d in metrics_list) / len(metrics_list) for k in metrics_list[0]}
print(final_metrics)


Processing Batches: 100%|██████████| 44/44 [00:38<00:00,  1.14it/s]

{'roc_auc': 0.6513005144308636, 'f1_macro': 0.20159803401340123, 'f1_micro': 0.36805457656967383, 'precision_macro': 0.15559283007563474, 'precision_micro': 0.2577387634438119, 'recall_macro': 0.36141784323089454, 'recall_micro': 0.6438525739848772}





In [None]:
true_binary_matrix = create_binary_matrix(true_labels[100:200], classes)  ## Note the list wrapping
pred_binary_matrix = create_binary_matrix(predicted_labels[100:200], classes)

metrics = roc_auc_f1_precision_recall(pred_binary_matrix, true_binary_matrix, multilabel=True)
print(metrics)

{'roc_auc': 0.6510764089568829, 'f1_macro': 0.20027328276086817, 'f1_micro': 0.33839373163565134, 'precision_macro': 0.15197598560781933, 'precision_micro': 0.23087203474774473, 'recall_macro': 0.38999073165256876, 'recall_micro': 0.6333638863428047}


In [26]:
import pandas as pd

rows = []
for true, pred in zip(true_labels, predicted_labels):
    # Convert to sets for efficient matching
    true_set = set(true)
    pred_set = set(pred)
    matching_words = true_set & pred_set
    
    rows.append({
        'True labels': true,  # Keep as list
        'Predicted Labels': pred,  # Keep as list
        'Matching Words': list(matching_words),  # Matching words as list
        'Match Count': len(matching_words)  # Add count of matches
    })

df = pd.DataFrame(rows)
# Convert lists to strings for better display (optional)
df['True labels'] = df['True labels'].apply(str)
df['Predicted Labels'] = df['Predicted Labels'].apply(str)
df['Matching Words'] = df['Matching Words'].apply(str)

df.head()

Unnamed: 0,True labels,Predicted Labels,Matching Words,Match Count
0,"['410', '427', '599', '428', '486', '799', '27...","['038', '785', '584', '401', '276', '414', '51...","['276', '599', '427', '486', '428', '285', '410']",7
1,"['578', '572', '751', 'V420', '588', '753', '4...","['403', '710', '530', 'V451', '250', '996', '5...","['585', 'E878', '572', '276', '996', '535', 'V...",9
2,"['403', '250', '276', 'V586', '536', '272', '7...","['403', '564', '414', '311', '530', 'V458', '2...","['403', '272', '276', '250', '724', '536', 'V5...",7
3,"['437', '305']",['437'],['437'],1
4,"['430', '276', '780', 'V103', '401', '250']","['401', '431', '250', '430', '285', '780']","['780', '401', '430', '250']",4


In [30]:
import optuna
import numpy as np
from sklearn.metrics import f1_score


def objective(trial):
    threshold = trial.suggest_float('threshold', 0.01, 0.5)
    
    true_labels = []
    predicted_labels = []
    
    for i, j in raw_test_df.iterrows():
        pred_label = get_predictions(j.text, threshold)
        true_label = j.short_codes.split(',')
        true_labels.append(true_label)
        predicted_labels.append(pred_label)
        if i == 1000:
            break
    
    true_binary_matrix = create_binary_matrix(true_labels, classes)
    pred_binary_matrix = create_binary_matrix(predicted_labels, classes)
    
    roc_auc =  roc_auc_f1_precision_recall(pred_binary_matrix, true_binary_matrix, multilabel=True)['f1_macro']
    
    return roc_auc

def find_optimal_threshold(n_trials=100):
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=n_trials)
    
    print('Number of finished trials:', len(study.trials))
    print('Best trial:')
    trial = study.best_trial
    
    print('  Value:', trial.value)
    print('  Params:')
    print('    threshold:', trial.params['threshold'])
    
    return trial.params['threshold']

optimal_threshold = find_optimal_threshold(n_trials=30) 

true_labels = []
predicted_labels = []
for i, j in raw_test_df.iterrows():
    pred_label = get_predictions(j.text, optimal_threshold)
    true_label = j.short_codes.split(',')
    true_labels.append(true_label)
    predicted_labels.append(pred_label)
    if i == 100:
        break

true_binary_matrix = create_binary_matrix(true_labels, classes)
pred_binary_matrix = create_binary_matrix(predicted_labels, classes)

final_metrics = roc_auc_f1_precision_recall(pred_binary_matrix, true_binary_matrix, multilabel=True)
print("\nFinal metrics with optimal threshold:")
print(final_metrics)

[I 2024-11-15 12:18:19,573] A new study created in memory with name: no-name-0ea8eed5-4ef3-452f-8129-253c5c847213


[I 2024-11-15 12:19:30,673] Trial 0 finished with value: 0.0948748286149849 and parameters: {'threshold': 0.3117835546194627}. Best is trial 0 with value: 0.0948748286149849.
[I 2024-11-15 12:20:45,716] Trial 1 finished with value: 0.1030823673750047 and parameters: {'threshold': 0.015647536761748158}. Best is trial 1 with value: 0.1030823673750047.
[I 2024-11-15 12:22:00,513] Trial 2 finished with value: 0.14009522403217045 and parameters: {'threshold': 0.0685779359873888}. Best is trial 2 with value: 0.14009522403217045.
[I 2024-11-15 12:23:18,854] Trial 3 finished with value: 0.1356693635464132 and parameters: {'threshold': 0.039378210596882415}. Best is trial 2 with value: 0.14009522403217045.
[I 2024-11-15 12:24:42,218] Trial 4 finished with value: 0.06868484167998071 and parameters: {'threshold': 0.49377808229476455}. Best is trial 2 with value: 0.14009522403217045.
[I 2024-11-15 12:26:08,820] Trial 5 finished with value: 0.09163015933119781 and parameters: {'threshold': 0.336543

Number of finished trials: 30
Best trial:
  Value: 0.14236974954841775
  Params:
    threshold: 0.0643745388060965

Final metrics with optimal threshold:
{'roc_auc': 0.6655015347600818, 'f1_macro': 0.2432331354267383, 'f1_micro': 0.3853672158614912, 'precision_macro': 0.19573195389133513, 'precision_micro': 0.27957860615883307, 'recall_macro': 0.4035620689419316, 'recall_micro': 0.6199460916442049}
