From a216d8ee7c49731b9d5faedc4657a7740f546c74 Mon Sep 17 00:00:00 2001 From: mart-r Date: Thu, 2 Oct 2025 17:38:09 +0100 Subject: [PATCH 1/9] CU-869aprnhg: Fix internal structures for peft models --- .../components/addons/meta_cat/models.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/models.py b/medcat-v2/medcat/components/addons/meta_cat/models.py index 5a7ff6600..98a2eda0c 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/models.py +++ b/medcat-v2/medcat/components/addons/meta_cat/models.py @@ -2,7 +2,7 @@ from collections import OrderedDict from typing import Optional, Any, Iterable from torch import nn, Tensor -from transformers import BertModel, AutoConfig +from transformers import BertModel, AutoConfig, PretrainedConfig from medcat.config.config_meta_cat import ConfigMetaCAT import logging logger = logging.getLogger(__name__) @@ -95,6 +95,17 @@ def forward(self, return x6 +class MetaCATHFConfig(PretrainedConfig): + """This class provides a gap between ConfigMetaCAT and the HF config. + + Some parts expects a HG config that has dict-like capabilities. + But the local ones we use here are just pydantic models now. + So this class just wraps the values into a HF-specific config. + """ + def __init__(self, meta_config: ConfigMetaCAT): + super().__init__(**meta_config.model_dump()) + + class BertForMetaAnnotation(nn.Module): _keys_to_ignore_on_load_unexpected: list[str] = [r"pooler"] # type: ignore @@ -143,7 +154,8 @@ def __init__(self, config: ConfigMetaCAT, "DO NOT use this model without loading the model state!", exc_info=e) - self.config = config + self._config = config + self.config = MetaCATHFConfig(config) self.bert = bert self.bert_config = _bertconfig self.num_labels = config.model.nclasses @@ -252,14 +264,14 @@ def forward( x = self.fc1(x) x = self.relu(x) - if self.config.model.model_architecture_config is not None: - if self.config.model.model_architecture_config['fc2'] is True: + if self._config.model.model_architecture_config is not None: + if self._config.model.model_architecture_config['fc2'] is True: # fc2 x = self.fc2(x) x = self.relu(x) x = self.dropout(x) - if self.config.model.model_architecture_config['fc3'] is True: + if self._config.model.model_architecture_config['fc3'] is True: # fc3 x = self.fc3(x) x = self.relu(x) From 91719af0b9da4860ca3559f0d35b49bdb22bd53a Mon Sep 17 00:00:00 2001 From: mart-r Date: Thu, 2 Oct 2025 17:39:41 +0100 Subject: [PATCH 2/9] CU-869aprnhg: Add more tests from 148 --- .../addons/meta_cat/test_meta_cat2.py | 80 ++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/medcat-v2/tests/components/addons/meta_cat/test_meta_cat2.py b/medcat-v2/tests/components/addons/meta_cat/test_meta_cat2.py index 93d89a876..bad3f8ee2 100644 --- a/medcat-v2/tests/components/addons/meta_cat/test_meta_cat2.py +++ b/medcat-v2/tests/components/addons/meta_cat/test_meta_cat2.py @@ -2,14 +2,18 @@ import shutil import unittest from typing import cast +import tempfile from transformers import AutoTokenizer -from medcat.components.addons.meta_cat import MetaCAT +from medcat.components.addons.meta_cat import MetaCAT, MetaCATAddon from medcat.config.config_meta_cat import ConfigMetaCAT from medcat.components.addons.meta_cat.mctokenizers.bert_tokenizer import ( TokenizerWrapperBERT) from medcat.storage.serialisers import deserialise, serialise +from medcat.cat import CAT +from medcat.cdb import CDB +from medcat.vocab import Vocab import spacy from spacy.tokens import Span @@ -115,5 +119,79 @@ def test_two_phase(self): self.meta_cat.config.model.phase_number = 0 +class MetaCATInCATTests(unittest.TestCase): + META_CAT_JSON_PATH = os.path.join( + RESOURCES_PATH, "mct_export_for_meta_cat_full_text.json") + + @classmethod + def _get_meta_cat(cls, meta_cat_dir: str): + config = ConfigMetaCAT() + config.general.category_name = "Status" + config.general.category_value2id = {'Other': 0, 'Confirmed': 1} + config.train.auto_save_model = False + config.model.model_name = 'bert' + config.model.model_freeze_layers = False + config.model.num_layers = 10 + config.train.lr = 0.001 + config.train.nepochs = 20 + config.train.class_weights = [0.75, 0.3] + config.train.metric['base'] = 'macro avg' + + meta_cat = MetaCAT(tokenizer=TokenizerWrapperBERT( + AutoTokenizer.from_pretrained("bert-base-uncased")), + embeddings=None, + config=config) + os.makedirs(meta_cat_dir, exist_ok=True) + json_path = cls.META_CAT_JSON_PATH + meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir) + return meta_cat + + @classmethod + def setUpClass(cls) -> None: + cls.cdb = CDB.load(os.path.join(RESOURCES_PATH, "cdb_meta.zip")) + cls.vocab = Vocab.load(os.path.join(RESOURCES_PATH, "vocab_meta.zip")) + cls.vocab.init_cumsums() + cls._temp_logs_folder = tempfile.TemporaryDirectory() + cls.temp_dir = tempfile.TemporaryDirectory() + cls.cdb.config.general.nlp.modelname = "en_core_web_md" + cls.cdb.config.components.ner.min_name_len = 2 + cls.cdb.config.components.ner.upper_case_limit_len = 3 + cls.cdb.config.general.spell_check = True + cls.cdb.config.components.linking.train_count_threshold = 10 + cls.cdb.config.components.linking.similarity_threshold = 0.3 + cls.cdb.config.components.linking.train = True + cls.cdb.config.components.linking.disamb_length_limit = 5 + cls.cdb.config.general.full_unlink = True + cls.cdb.config.general.usage_monitor.enabled = True + cls.cdb.config.general.usage_monitor.log_folder = ( + cls._temp_logs_folder.name) + cls.meta_cat_dir = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "tmp") + cls.meta_cat = cls._get_meta_cat(cls.meta_cat_dir) + cls.cat = CAT(cdb=cls.cdb, config=cls.cdb.config, vocab=cls.vocab) + addon = MetaCATAddon( + cls.meta_cat.config, cls.cat.pipe.tokenizer, cls.meta_cat) + cls.cat.add_addon(addon) + + @classmethod + def tearDownClass(cls) -> None: + if os.path.exists(cls.meta_cat_dir): + shutil.rmtree(cls.meta_cat_dir) + cls._temp_logs_folder.cleanup() + cls.temp_dir.cleanup() + + def test_meta_cat_through_cat(self): + text = ("This information is just to add text. The patient " + "denied history of heartburn and/or gastroesophageal " + "reflux disorder. He recently had a stroke in the last week.") + entities = self.cat.get_entities(text) + meta_status_values = [] + for en in entities['entities']: + meta_status_values.append(entities['entities'][en][ + 'meta_anns']['Status']['value']) + + self.assertEqual(meta_status_values, ['Other', 'Other', 'Confirmed']) + + if __name__ == '__main__': unittest.main() From b8cd25c869fb0cb8ed5e0931a1fd4989f6d81aaf Mon Sep 17 00:00:00 2001 From: mart-r Date: Thu, 2 Oct 2025 17:40:06 +0100 Subject: [PATCH 3/9] CU-869aprnhg: Add relevant resources --- medcat-v2/tests/resources/cdb_meta.csv | 5 + medcat-v2/tests/resources/cdb_meta.zip | Bin 0 -> 13607 bytes .../mct_export_for_meta_cat_full_text.json | 10671 ++++++++++++++++ medcat-v2/tests/resources/vocab_meta.zip | Bin 0 -> 988 bytes 4 files changed, 10676 insertions(+) create mode 100644 medcat-v2/tests/resources/cdb_meta.csv create mode 100644 medcat-v2/tests/resources/cdb_meta.zip create mode 100644 medcat-v2/tests/resources/mct_export_for_meta_cat_full_text.json create mode 100644 medcat-v2/tests/resources/vocab_meta.zip diff --git a/medcat-v2/tests/resources/cdb_meta.csv b/medcat-v2/tests/resources/cdb_meta.csv new file mode 100644 index 000000000..e0752c5d3 --- /dev/null +++ b/medcat-v2/tests/resources/cdb_meta.csv @@ -0,0 +1,5 @@ +cui,name,ontologies,name_status,type_ids,description +C0000039,"gastroesophageal reflux",,,T234, +C0000239,"heartburn",,,, +C0000339,"hypertension",,,, +C0000439,"stroke",,,, \ No newline at end of file diff --git a/medcat-v2/tests/resources/cdb_meta.zip b/medcat-v2/tests/resources/cdb_meta.zip new file mode 100644 index 0000000000000000000000000000000000000000..5805013a1e2152c847fd373ee1ec3f1e85205571 GIT binary patch literal 13607 zcmcIq2{_bU)E6@?N|v&uL>ODPk_y?%z6{2aWM3QGh*4;kO&6pW~`kwE*Jr~c@|NQPbzjN-n_nv$8wdo++Y4Gjw>7yrq|FQIA z1Mu4p)&u_u0s2;;i+eeEB5hoeUJmwFw!YdeZ7u6)mPSsOHLbwrmv?VCku@2LDr0(VC*NTeOg+A^- zAI6*M>3;Lgz#^l5KbP$>vbz7#W_6k#EmV+x1V+ADaw3>5y9IM5fA8ZX=wGJ^Of z*YoVmn~d0u#*B3XzWS9<(#CYYbV1%fpUoMBvZIwnlAe1rh&7%o)bG0Xt+*8$w?3H{ z!ErCPnL*@`6f$&=nt}L_Tf=vyATP`J92b>BgwmXwi4ZoE40;Ofz&uL+_3Flz!oBhZ zBB9b8_x(ukIPYFuB2D(=wn7$fUdgD=bS`usMBX{*_NDXnB@36coz%+2 zH%&YH&^0L~x$|$>5+`w3Y-np|80XDS+N~A~**&)(N*BW#FYahG2urYr@EI%)n^+8o(ipy6_!^sN&mZ_+e)7Ls4YJ^55s-5eRQ*V-J_)y{h zi?8HD_MEC92Oo2BKtxlpoLmnF=T4Xg|MeKvdQACiKlz0dJHx*5e~WT-n_zr04Mon0 zgrB@NoF^z`5-%zx&wxWWc6+d?ULuHZT{5o zwoqIk_bGb+R&QhG1s1cDvuekOjmG`Pi_s;XbPvKu%`rRc(bY=QEddc(b(&>XeP67W zhh55<*?1)&+*E*$4?A%$yHgHsHsa1+DEFHtB5<+CI&CeYaNSPbc1=-nlM;mfV_jWM zxM-&~T$9$$iCqxJ1S9qk43JtOgZCkJ;2 zPa9W|kMU$Jnc(JtwgLH)mH4HHr-O$l%Fe;d3+WD)GDt*L&m3Yr5H~qs3{{|SBaXFp zPquexx9VInXt&nVR96<3k||B;10{)Wdm^{$yfxs++96Dh~OeBM-cG#vT<^- zaznWz(I}!TDRU?l(bb8`wWQ1Z42-F=j)rC@<=CVdR`<#iEjneU5C(`X3!^fxCG>=; z#Ftxlm?l}n&Qz4jP5jvR%x*Hn$LNFILR~wAam>))*hu=P^sfsk7HK2L%q%$-g@RG1 z?CXb(bAqq1X?&|g>V?roGLGFt{z|IOEl+0I0u7OH;irjiJYKCHd0|7@-jvDxHUS%l zvL`1DeXz8g*;UNdnsvS}Ts~w8yoSvWJi!$BJ9sJv7+4iF zTouOiqO-XCPcTtBcn_t2@WHgYlbth*gL47l0ibhJ2O)K|Sky`SM#=3e(6`@?z6!DW zSL0~P7S6s`2}hh+A?&domC`+SwzXDX9a~1Fef`y%5)Dta3!(6RL#KR;NGNl8u{w;-M_lSa%;{poLk4Y?Q*zF zhV0F|bPEiu`+FJghdxf_9AX*FD$g}65z1v_pAi->k2`hQ#U5JHZd$ZcB-epH@6Q}s z_RZga85d~By#S8jN~uh(#z<*FybFc&@N0fSxRcY46?WRe5_0v{WI zx3+0UO9*M{w%$taYBDiPOTNyjm~&~bprX8jn4%!8{FD}B_L~HRJeifB6vr)&0qSV4 zMXkT4Ih|9@5WkL!aJ^6)Rr~y`gyav-CO=l$olylgHN9e`f+|-s zBlE_GOFuPtYP=8WnSw|>iWI-n__f65bHBP4a#z$<^@|-H;jC`LN4Rf&3mU&cKPn{F zdgxon__%>jRq^A4_L7lfD)pP^s?T3)*fo}p)^LHDN1l`!jO|DC+;^=}Nq9XwWxaqX zQ`UwEmF!aU={du|+K%BVk25yDjBzaq3Mf80m-Z-Tf?JCjqW@(g`w??)wJ3)%*9_xq z)cUQ_?^89#?r)7}J+w80VRXdf-EN+j4q+)yx}%d4-Por#b&V%q9=b7LeOM`I+jxTF zjj55-nv=V|BeVXIdros8zI4*|&P(s>aoawh=Zbu#A7>cO8>^J(9+QOD&Mey4BYTD+ zx%x8HYp|i14$(rXnJj)&QxV~3@@*l-Mop&ZXKk>?$h15_A)NU|{Lj@;eP(y}_KxeM+(qoS#&O`%#5)|^s=9d3 zqpSGgLy%*9<{AvSJN~hjpEgJ;yr0fWr!cn5x&KaB-P3R*@elOaRF;ak0}M^-H{DdC zt>2ujQkzyB-ocP6t-#ZFEXT!oXHd-JzPM5=Z@cJhv-F*;kT2H)hg*{S%^>lC2&oM{ zr;oB_x#=<~Ckiv}>S;G({z-02vXG`v zB+W)!zM4QnZ|*$JqocRwyJmX%p;^ZL&99|%Fs5zgGTWf;%&w@=={5;;b;icc{g<+l zhC{hk84baD(9IlyVRBYkxvw;9FXLjWN^)!#f&{K>dAa+0+OZ065SBS`gG+6Xe^!>F z_kOoBv;fyfd$#BfcrR1pY+Rdy-%r1DA$ic2ufHCo#Jlj9(DGCin^=EFJy9!N4MCJbq!ZOeQ?EDF09f= ztX*5=u;iHv?15Y?Boz~j$jrN+db78gnJqDEFe(z==ke95?fQl=Cr-ymyR$sL`5{JQ z#gW-%wwk-;y!G##7U<^h$L()OGrIY9knMfuG1``5Eh}RV9(!HYU$WZ5TWW(}RGfTh z$lGuE!sY{;?mNxdm=g(;KRV-X4t+#caD}WBd0#2K?|X3Ac9p>RvBbVAy7d>6MN{|!EKAh`;1{;_?K&2)OUI;bqmqmmtDE&L z(^vf}uXokvBh>pOrR!gZGHx?}2sfNj|9HzUC(e=e9cIB6i*)tL%_=g!q+1%8F3g7A z>TT}7({>k(Wq8L^xLiD4D%<2hYT31x!;VqUSdc19>ywl zpDMHwX@87;V|13ogo6c}?eR_DGQJ5!qnm%x3x9(q+a zqpjV**7X?9y4fd;>70qbtu@^0fW;g)+3tKs*!v8y1f&2Rd?_GZ+qfc?>6;|CKAC8VXNni}QOluu0bZ-n=EB zJo>V+e5P_ms0yt6AGT>fgmN(3DYkEIHn+t0(Df@?+^}pbM7f@zm8r71F&9(&@u0Q1 z@!0Iobp*{pH#2s`v$!kL-G%fdOex1kT>MDnfW!Hp4G^Ua$p(l-25sZy1U?o&)v?eh z7YBDMcNB2Maz*-)B2b=Uc&H+*zT9))y62<^$eICtH&qPMWgfp7?pWLmSH@3!Z62JC zL?9w_AJEQEz{1DIFur4keh}@cQp4%N>8VK1ZC?9t-nVV@RdSG@kiq!A!Rg8eV2mWWSb}USDB1Q0Tnb?RE0Fmz)M7g!ovJL=)7-%FRXd_9*O~c4@KP_c-OB^ zDC}wuZvQgQS7Fw%@30iY`@dxA_M!RQfrHQJHr~>DIb4x2i59~P(x|!~U){m`WMxbK+9wwa7SAGo&gY5X&LG*)l8{K4Q>#0qdX86# z4=@1+bdoGa1&h=vK%|Q_<0IYg@YUae&f;a5H&c{{%Y|JVNzf7QrX!19XPROZzLW`H zb(p_qtep_W?bdb6&&|h4&vBGVVLzY?W&Ba5mA@lVa*e3}XLPQ0!2d5|Df2w_=|EUh z9FeYQ2Tx+Xm(nrRW11kp26ud+lXwW!7CUHHifIBCsT14~_hPmwGYuE9fAZA4b#F)& z$IrvtwP~Y3(K~iC`Pb6dyXI86-^^-hgt!zm2?Qwa8Kt`%`-o}BpkRLSC9aVC{O6xv z%v#o`*8Dm@+;#BHH9=o1(@)6*oLY4PnZugR@p;{<&vOKI2K$)r<~^nPs^aWJlY9Pz zdHXp|-SqAk4UydSF_sFtz5=(xJvd*|tCU@b9tkdTJ0Ne|E!t{(U9fs+>x+5 zy<<)nL?dqszg@^{-t?{{vR~l!Y}NOf%qv%3M^}A6h&%LfrW?9v7&)^`SEl2N{Gl%A zkp^LLrQmQ;Ozf_2tm|XZ_d6^M6cRWp-IHlcqqq$7i{FZLm>Dr2amgs(-5X$e>==;-y$kV}|o#&~%*_5^Oat%5j0i_A1QedV-n==2wh+nWB{{S#4BO>sx+N*f}X zd{qplCj9wjM_Sw89+4{IAE8g4zB|BhqpJ-rb^;jgEM z0xkMR1oFf4u<76H0A)r3QiNgC<+?7g&4F^$)jip-nQ z8yi@+Z_sdWyqA_CB9J!6k1Tzki01WG1IE!^^LYPTI^8icumV-G4T&!J8c0%JCSE_S zPNLZdv087yAU~j!++z?S$T&+`;>>3nSnb}oA|fNCrl%wF9?*uGUii^u5u7_yB)d>4 z`(0P_(?EyV6XBL93AmVssFrApL`qkasZ@@7zL4~xdQs`!Ldhw|io;}cHKvsgg99q1$%3(H~1Socd_ENo6*8)eVg{xE6iA-j_ko6gt<`hBCr z?0T4|O^MhPGbStjlq7RnHpl0wT}`K;2}~L-S2|iH)(vL8dXUz3!|3vDNx`qe@@Qn`^&}5h>#KOKb!{(vZE_vv=jU$dD&7-)HJ(!UveNEx|gFSp{siw^7 z@#*F%sVT>45vi8(`?VSP}TXCg63w5R3i`W zCP6@tB{xxC<*x238`g9+{EfgVpyT(3D?*XEi?em{0OS)tJ*|b^K(Pc zZp&m-14qC0I{K1=^+(EuWQq?j*Ft>dSB$W%gQ56epmC zH?#p7fGYR{8K_Es&2o&d?>sCVhMD*MF9!zAfi2JJZi_^(quoxo@@o@7C!jIWe8LyF z32fm+q2v$8a}fj>#TKrTuClz=GN4s4J}$s$TpZtCz) z0~Q;aE>H`2&bH(!0#B58@Z?;P-@g;#u26>q zGm`wKn+WE)aRuFo&y~nEBfS&F-(3?LF7Oj2*z$k?_bK@;HWB#9CJHT=1A`e)ehW>6 zeYP4bb$?Ul2Yh98vEe6{1#r2z=E!vSZH{IZ##5ial<84eg( z986!M@5b-g>1Z>kNSpts@@=8Ae5Ep1f6%8~d$V&A@thZFL z);bf2yy3010CohFR#KMXiJn;Gjl2X;)GlX5L@;Z~3*kgO6cBVPX-F;lDgC;9?ti3; zM~FCZ|B~0aiCP|Hr=}$V70hz-qBao|3yQf`f8z_(ga#ZDfG0DQY_3+WCigxm5BO>n zq2U70tr35%j4Vf=SR*=hZYXmRzQ(-R@VV$Ba5|ynhL9Htbv1G>NH-$OBdrb;Z3RY&AsQKEm55BHQXt=4;MT8sxgIw-uC~)Sap-}+-lLRKWhUli7N@4fCslf%JadwfApk|k zKiMAPa3%(Zi_8oRJU|^qiRJMrnaL%3DTyV$ll(J;9Yu~!&M~|7rC)B*+oisa6C(^= z_=N&yI&bW94iY)oJS}yz^;;J6-Gyb#mNmUN7@^I~)z{87kL6%Gv+Wkg(zoR^BxiiDdtd$eSOS_FXRt{2 zy$w3VqvgEwl{?qIgJ%Di9d?K;VJl7wP)QKbN^O+hk#(t2>64Jt-b}-i(g2z1+qwAV z1a9Vh%LqLcpU}{CCm`M1pm^uT-xsSltbCmF`g!*GgDc+}$9_C+->^W*J5?{7xBX^F z$`8?X&vJ#QrbHeSz5MdC)lsc`GY)*+aq@GMxxRJgv^QVgUHWlSI{x5a&Zh^tclvk4 zo{y8C_5b*>TuJTT#|9@Vis$P^mri&svhiMf$?CZa;<~4Xtrh&e-~G;}q;}=prz?|{ zixqx-%=?&7BDbgZ@yqaDZi(q?PCeHXmk7^FJ`%Sz$IB_1pWlC8e=SpS=jH41=l-#6 zZ@2Zm=2Y&n;!p+iw9g76H}mdA&iC5p^TyNQ=iJm?9T|@{7FyLc3x7MdAb5h68_oDAPPaVCmO+f*te0@FrbS`S2_V@F-#LYa>%Zkfu zX5~!PU96q1stTqTH*N){aYiN)X59G#Xb%`HX#`QIxddG=dJci;WMEj*Xar;;=NELH g=w3tU+ziZd$Q}*wW(8$i24*1K0Hno$xsibZ00Aa{^#A|> literal 0 HcmV?d00001 From 00cce7c7f698e0b63323ad8fad3cb4c4ea11b8bf Mon Sep 17 00:00:00 2001 From: mart-r Date: Thu, 2 Oct 2025 21:19:42 +0100 Subject: [PATCH 4/9] CU-869aprnhg: Allow an extra 10 minutes for tests in workflow --- .github/workflows/medcat-v2_main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/medcat-v2_main.yml b/.github/workflows/medcat-v2_main.yml index 23e46ad3e..a617e866a 100644 --- a/.github/workflows/medcat-v2_main.yml +++ b/.github/workflows/medcat-v2_main.yml @@ -38,7 +38,7 @@ jobs: uv run ruff check medcat --preview - name: Test run: | - timeout 20m uv run python -m unittest discover + timeout 30m uv run python -m unittest discover - name: Model regression run: | uv run bash tests/backwards_compatibility/run_current.sh @@ -46,4 +46,4 @@ jobs: run: | uv run bash tests/backwards_compatibility/check_backwards_compatibility.sh - name: Minimize uv cache - run: uv cache prune --ci \ No newline at end of file + run: uv cache prune --ci From 77b403714a863ba3d542f3a6cd9203f9c244ee36 Mon Sep 17 00:00:00 2001 From: mart-r Date: Fri, 3 Oct 2025 09:31:49 +0100 Subject: [PATCH 5/9] CU-869aprnhg: Move to a cleaner approach for the bert/pretrained config fix --- .../medcat/components/addons/meta_cat/models.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/models.py b/medcat-v2/medcat/components/addons/meta_cat/models.py index 98a2eda0c..ff834ee65 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/models.py +++ b/medcat-v2/medcat/components/addons/meta_cat/models.py @@ -2,7 +2,7 @@ from collections import OrderedDict from typing import Optional, Any, Iterable from torch import nn, Tensor -from transformers import BertModel, AutoConfig, PretrainedConfig +from transformers import BertModel, AutoConfig from medcat.config.config_meta_cat import ConfigMetaCAT import logging logger = logging.getLogger(__name__) @@ -95,17 +95,6 @@ def forward(self, return x6 -class MetaCATHFConfig(PretrainedConfig): - """This class provides a gap between ConfigMetaCAT and the HF config. - - Some parts expects a HG config that has dict-like capabilities. - But the local ones we use here are just pydantic models now. - So this class just wraps the values into a HF-specific config. - """ - def __init__(self, meta_config: ConfigMetaCAT): - super().__init__(**meta_config.model_dump()) - - class BertForMetaAnnotation(nn.Module): _keys_to_ignore_on_load_unexpected: list[str] = [r"pooler"] # type: ignore @@ -155,9 +144,9 @@ def __init__(self, config: ConfigMetaCAT, exc_info=e) self._config = config - self.config = MetaCATHFConfig(config) self.bert = bert - self.bert_config = _bertconfig + # NOTE: potentially used downstream + self.config = _bertconfig self.num_labels = config.model.nclasses for param in self.bert.parameters(): param.requires_grad = not config.model.model_freeze_layers From 8276003e843f9ff64c7061c511daad8db600a03f Mon Sep 17 00:00:00 2001 From: mart-r Date: Fri, 3 Oct 2025 09:58:28 +0100 Subject: [PATCH 6/9] CU-869aprnhg: Make sure bert_config attribute also exists --- medcat-v2/medcat/components/addons/meta_cat/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/models.py b/medcat-v2/medcat/components/addons/meta_cat/models.py index ff834ee65..c1a417431 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/models.py +++ b/medcat-v2/medcat/components/addons/meta_cat/models.py @@ -146,7 +146,7 @@ def __init__(self, config: ConfigMetaCAT, self._config = config self.bert = bert # NOTE: potentially used downstream - self.config = _bertconfig + self.config = self.bert_config = _bertconfig self.num_labels = config.model.nclasses for param in self.bert.parameters(): param.requires_grad = not config.model.model_freeze_layers From 238f4a85d2499c9ab4ba99dd0c4392a3237b6120 Mon Sep 17 00:00:00 2001 From: mart-r Date: Fri, 3 Oct 2025 11:18:09 +0100 Subject: [PATCH 7/9] CU-869aprnhg: Use spacy for MetaCAT tests --- medcat-v2/tests/components/addons/meta_cat/test_meta_cat2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/medcat-v2/tests/components/addons/meta_cat/test_meta_cat2.py b/medcat-v2/tests/components/addons/meta_cat/test_meta_cat2.py index bad3f8ee2..9fdf7c0b2 100644 --- a/medcat-v2/tests/components/addons/meta_cat/test_meta_cat2.py +++ b/medcat-v2/tests/components/addons/meta_cat/test_meta_cat2.py @@ -153,6 +153,7 @@ def setUpClass(cls) -> None: cls.vocab.init_cumsums() cls._temp_logs_folder = tempfile.TemporaryDirectory() cls.temp_dir = tempfile.TemporaryDirectory() + cls.cdb.config.general.nlp.provider = "spacy" cls.cdb.config.general.nlp.modelname = "en_core_web_md" cls.cdb.config.components.ner.min_name_len = 2 cls.cdb.config.components.ner.upper_case_limit_len = 3 From 4a6d11075f6aca92e73b4e5a406b20a7d996a343 Mon Sep 17 00:00:00 2001 From: mart-r Date: Fri, 3 Oct 2025 12:35:22 +0100 Subject: [PATCH 8/9] CU-869aprnhg: Port changes from v1 PR 155 --- .../components/addons/meta_cat/data_utils.py | 243 ++++++++++++------ .../components/addons/meta_cat/meta_cat.py | 5 +- 2 files changed, 161 insertions(+), 87 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py index 6e410a75a..96dd2828a 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py +++ b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py @@ -3,6 +3,7 @@ from medcat.components.addons.meta_cat.mctokenizers.tokenizers import ( TokenizerWrapperBase) +from medcat.config.config_meta_cat import ConfigMetaCAT import logging logger = logging.getLogger(__name__) @@ -209,10 +210,124 @@ def prepare_for_oversampled_data(data: list, return data_sampled +def find_alternate_classname(category_value2id: dict, category_values: set, + alternative_class_names: list[list]) -> dict: + """Find and map to alternative class names for the given category. + + Example: + For Temporality category, 'Recent' is an alternative to 'Present'. + + Args: + category_value2id (dict): + The pre-defined category_value2id + category_values (set): + Contains the classes (labels) found in the data + alternative_class_names (list[list]): + Contains the mapping of alternative class names + + Returns: + category_value2id (dict): + Updated category_value2id with keys corresponding to + alternative class names + + Raises: + Exception: If no alternatives are found for labels in + category_value2id that don't match any of the labels in + the data + Exception: If the alternatives defined for labels in + category_value2id that don't match any of the labels in + the data + """ + + updated_category_value2id = {} + for _class in category_value2id.keys(): + if _class in category_values: + updated_category_value2id[_class] = category_value2id[_class] + else: + found_in = [sub_map for sub_map in alternative_class_names + if _class in sub_map] + failed_to_find = False + if len(found_in) != 0: + class_name_matched = [label for label in found_in[0] + if label in category_values] + if len(class_name_matched) != 0: + updated_category_value2id[ + class_name_matched[0]] = category_value2id[_class] + logger.info( + "Class name '%s' does not exist in the data; however " + "a variation of it '%s' is present; updating it...", + _class, class_name_matched[0]) + else: + failed_to_find = True + else: + failed_to_find = True + if failed_to_find: + raise Exception( + "The classes set in the config are not the same as the " + "one found in the data. The classes present in the config " + "vs the ones found in the data - " + f"{set(category_value2id.keys())}, {category_values}. " + "Additionally, ensure the populate the " + "'alternative_class_names' attribute to accommodate for " + "variations.") + category_value2id = copy.deepcopy(updated_category_value2id) + logger.info("Updated categoryvalue2id mapping - %s", category_value2id) + return category_value2id + + +def undersample_data(data: list, category_value2id: dict, label_data_, + config: ConfigMetaCAT) -> list: + """Undersamples the data for 2 phase learning + + Args: + data (list): + Output of `prepare_from_json`. + category_value2id(dict): + Map from category_value to id. + label_data_: + Map that stores the number of samples for each label + config: + MetaCAT config + + Returns: + data_undersampled (list): + Return the data created for 2 phase learning) with integers + inplace of strings for category values + """ + + data_undersampled = [] + category_undersample = config.model.category_undersample + if category_undersample is None or category_undersample == '': + min_label = min(label_data_.values()) + + else: + if (category_undersample not in label_data_.keys() and + category_undersample in category_value2id.keys()): + min_label = label_data_[category_value2id[category_undersample]] + else: + min_label = label_data_[category_undersample] + + label_data_counter = {v: 0 for v in category_value2id.values()} + + for sample in data: + if label_data_counter[sample[-1]] < min_label: + data_undersampled.append(sample) + label_data_counter[sample[-1]] += 1 + + label_data = {v: 0 for v in category_value2id.values()} + for i in range(len(data_undersampled)): + if data_undersampled[i][2] in category_value2id.values(): + label_data[data_undersampled[i][2]] = ( + label_data[data_undersampled[i][2]] + 1) + logger.info("Updated number of samples per label (for 2-phase learning):" + " %s", label_data) + return data_undersampled + + def encode_category_values(data: list[tuple[list, list, str]], existing_category_value2id: Optional[dict] = None, - category_undersample: Optional[str] = None, - alternative_class_names: list[list[str]] = [] + alternative_class_names: list[list[str]] = [], + config: Optional[ConfigMetaCAT] = None, ) -> tuple[ list[tuple[list, list, str]], list, dict]: """Converts the category values in the data outputted by @@ -223,13 +338,12 @@ def encode_category_values(data: list[tuple[list, list, str]], Output of `prepare_from_json`. existing_category_value2id(Optional[dict]): Map from category_value to id (old/existing). - category_undersample (Optional[str]): - Name of class that should be used to undersample the data (for 2 - phase learning) alternative_class_names (list[list[str]]): A list of lists of strings, where each list contains variations of a class name. Usually read from the config at `config.general.alternative_class_names`. + config (Optional[ConfigMetaCAT]): + The MetaCAT Config. Returns: list[tuple[list, list, str]]: @@ -252,98 +366,59 @@ def encode_category_values(data: list[tuple[list, list, str]], category_values = set([x[2] for x in data_list]) - if (len(category_value2id) != 0 and - set(category_value2id.keys()) != category_values): - # if categoryvalue2id doesn't match the labels in the data, - # then 'alternative_class_names' has to be defined to check - # for variations - if len(alternative_class_names) == 0: - # Raise an exception since the labels don't match + if config: + if len(category_values) != config.model.nclasses: raise Exception( - "The classes set in the config are not the same as the one " - "found in the data. The classes present in the config vs the " - "ones found in the data - {set(category_value2id.keys())}, " - f"{category_values}. Additionally, ensure the populate the " - "'alternative_class_names' attribute to accommodate for " - "variations.") - updated_category_value2id = {} - for _class in category_value2id.keys(): - if _class in category_values: - updated_category_value2id[_class] = category_value2id[_class] - else: - found_in = [sub_map for sub_map in alternative_class_names - if _class in sub_map] - failed_to_find = False - if len(found_in) != 0: - class_name_matched = [label for label in found_in[0] - if label in category_values] - if len(class_name_matched) != 0: - updated_category_value2id[class_name_matched[0] - ] = category_value2id[_class] - logger.info( - "Class name '%s' does not exist in the data; " - "however a variation of it '%s' is present; " - "updating it...", _class, class_name_matched[0]) - else: - failed_to_find = True - else: - failed_to_find = True - if failed_to_find: - raise Exception( - "The classes set in the config are not the same as " - "the one found in the data. The classes present in " - "the config vs the ones found in the data - " - f"{set(category_value2id.keys())}, {category_values}. " - "Additionally, ensure the populate the " - "'alternative_class_names' attribute to accommodate " - "for variations.") - category_value2id = copy.deepcopy(updated_category_value2id) - logger.info("Updated categoryvalue2id mapping - %s", category_value2id) + "The number of classes found in the data - %s does not match " + "the number of classes defined in the config - %s " + "(config.model.nclasses). Please update the number of classes " + "and initialise the model again.", len(category_values), + config.model.nclasses) + # If categoryvalue2id is pre-defined or if all the classes aren't mentioned + if len(category_value2id) != 0: + # making sure it is same as the labels found in the data + if set(category_value2id.keys()) != category_values: + # if categoryvalue2id doesn't match the labels in the data, + # then 'alternative_class_names' has to be defined to check for + # variations + if len(alternative_class_names) == 0: + # Raise an exception since the labels don't match + raise Exception( + "The classes set in the config are not the same as the " + "one found in the data. The classes present in the config " + "vs the ones found in the data - " + f"{set(category_value2id.keys())}, {category_values}. " + "Additionally, ensure the populate the " + "'alternative_class_names' attribute to accommodate for " + "variations.") + + category_value2id = find_alternate_classname( + category_value2id, category_values, alternative_class_names) + # Else create the mapping from the labels found in the data - else: + if len(category_value2id) != len(category_values): for c in category_values: if c not in category_value2id: category_value2id[c] = len(category_value2id) - logger.info("Categoryvalue2id mapping created with labels found " - "in the data - %s", category_value2id) + logger.info("Categoryvalue2id mapping created with labels found in " + "the data - %s", category_value2id) # Map values to numbers - for i in range(len(data_list)): - # NOTE: internally, it's a a list so assingment will work - data_list[i][2] = category_value2id[data_list[i][2]] # type: ignore + for i in range(len(data)): + # represented as a tuple so that we can type hint, but it's a list + data[i][2] = category_value2id[data[i][2]] # type: ignore # Creating dict with labels and its number of samples label_data_ = {v: 0 for v in category_value2id.values()} - for i in range(len(data_list)): - if data_list[i][2] in category_value2id.values(): - label_data_[data_list[i][2]] = label_data_[data_list[i][2]] + 1 + for i in range(len(data)): + if data[i][2] in category_value2id.values(): + label_data_[data[i][2]] = label_data_[data[i][2]] + 1 logger.info("Original number of samples per label: %s", label_data_) - # Undersampling data - if category_undersample is None or category_undersample == '': - min_label = min(label_data_.values()) - - else: - if (category_undersample not in label_data_.keys() and - category_undersample in category_value2id.keys()): - min_label = label_data_[category_value2id[category_undersample]] - else: - min_label = label_data_[category_undersample] data_undersampled = [] - label_data_counter = {v: 0 for v in category_value2id.values()} - - for sample in data_list: - if label_data_counter[sample[-1]] < min_label: - data_undersampled.append(sample) - label_data_counter[sample[-1]] += 1 - - label_data = {v: 0 for v in category_value2id.values()} - for i in range(len(data_undersampled)): - if data_undersampled[i][2] in category_value2id.values(): - label_data[data_undersampled[i][2]] = label_data[ - data_undersampled[i][2]] + 1 - logger.info("Updated number of samples per label (for 2-phase learning): " - "%s", label_data) + if config and config.model.phase_number != 0: + data_undersampled = undersample_data( + data, category_value2id, label_data_, config) return data_list, data_undersampled, category_value2id diff --git a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py index 2fdd9c893..5386a8dc1 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py +++ b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py @@ -556,15 +556,14 @@ def train_raw(self, data_loaded: dict, save_dir_path: Optional[str] = None, # Encode the category values (full_data, data_undersampled, category_value2id) = encode_category_values( - data, - category_undersample=self.config.model.category_undersample, + data, config=self.config, alternative_class_names=g_config.alternative_class_names) else: # We already have everything, just get the data (full_data, data_undersampled, category_value2id) = encode_category_values( data, existing_category_value2id=category_value2id, - category_undersample=self.config.model.category_undersample, + config=self.config, alternative_class_names=g_config.alternative_class_names) g_config.category_value2id = category_value2id self.config.model.nclasses = len(category_value2id) From 3b2eed4f609eb89bb0ef6426ee8848b508245b9a Mon Sep 17 00:00:00 2001 From: mart-r Date: Fri, 3 Oct 2025 12:37:38 +0100 Subject: [PATCH 9/9] CU-869aprnhg: Add some improved typing --- medcat-v2/medcat/components/addons/meta_cat/data_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py index 96dd2828a..caf568b9f 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py +++ b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py @@ -210,8 +210,8 @@ def prepare_for_oversampled_data(data: list, return data_sampled -def find_alternate_classname(category_value2id: dict, category_values: set, - alternative_class_names: list[list]) -> dict: +def find_alternate_classname(category_value2id: dict, category_values: set[str], + alternative_class_names: list[list[str]]) -> dict: """Find and map to alternative class names for the given category. Example: @@ -220,9 +220,9 @@ def find_alternate_classname(category_value2id: dict, category_values: set, Args: category_value2id (dict): The pre-defined category_value2id - category_values (set): + category_values (set[str]): Contains the classes (labels) found in the data - alternative_class_names (list[list]): + alternative_class_names (list[list[str]]): Contains the mapping of alternative class names Returns: