
Update models to allow conditioning #1

Merged
pzhanggit merged 5 commits into main from xap/add_conditioning_to_models on Jan 7, 2026

Conversation

Collaborator

@aprokop aprokop commented Jul 15, 2025

No description provided.

@aprokop aprokop force-pushed the xap/add_conditioning_to_models branch from d168f99 to be4603d Compare November 13, 2025 22:25
Collaborator Author

aprokop commented Nov 13, 2025

Need to do a couple of test runs to make sure.

Collaborator

@pzhanggit pzhanggit left a comment


Thanks @aprokop. The PR looks good; see my comments. I realize we have some redundant code that could be improved now. Also, could you run some small training runs as a sanity check? I wonder if it's time for us to write the belated CI tests.

# patch_ids: [npatches] #selected token ids with sample pos inside batch considered
# t_pos_area: [B, T, ntoken_len_tot, 5]
"""
assert conditioning == False or refineind is None
Collaborator


Could we remove conditioning == False?
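A minimal sketch of the suggested simplification (the standalone function here is hypothetical; only the assert condition is taken from the snippet above):

```python
def check_conditioning_args(conditioning: bool, refineind):
    # Equivalent to `assert conditioning == False or refineind is None`,
    # written idiomatically: conditioning and refineind are mutually exclusive.
    assert not conditioning or refineind is None
```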

sys.exit(-1)

-    def get_unified_preembedding(self, x, state_labels, ilevel=0):
+    def get_unified_preembedding(self, x, state_labels, ilevel=0, conditioning: bool = False):
Collaborator


Would it be more concise if we pass directly the module? Something like (...,op=self.space_bag[ilevel],...) and (...,op=self.space_bag_cond[ilevel]...)?
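The suggested refactor, passing the module instead of a conditioning flag, could look roughly like this (`space_bag`/`space_bag_cond` are from the snippet; the class body and lambdas are a hypothetical illustration, not the actual matey code):

```python
class ModelSketch:
    def __init__(self):
        # stand-ins for module lists; the real model would hold nn.Modules
        self.space_bag = [lambda x: x + 1]
        self.space_bag_cond = [lambda x: x * 2]

    # Before: a boolean flag selects the module inside the method.
    def get_unified_preembedding_flag(self, x, ilevel=0, conditioning=False):
        op = self.space_bag_cond[ilevel] if conditioning else self.space_bag[ilevel]
        return op(x)

    # After: the caller passes the module directly, so the flag disappears.
    def get_unified_preembedding(self, x, op):
        return op(x)

m = ModelSketch()
assert m.get_unified_preembedding(3, m.space_bag[0]) == m.get_unified_preembedding_flag(3)
```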

Collaborator Author


I'll try.

Collaborator Author


Done.

return x

-    def get_structured_sequence(self, x, embed_index, tkhead_name, ilevel=0):
+    def get_structured_sequence(self, x, embed_index, tkhead_name, ilevel=0, conditioning: bool = False):
Collaborator


Similar here. Could we condense the three arguments tkhead_name, ilevel=0, conditioning: bool = False into one via passing self.tokenizer_ensemble_heads[ilevel][tkhead_name]["embed"][embed_index] or self.tokenizer_ensemble_heads[ilevel][tkhead_name]["embed_cond"][embed_index] directly?
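The same pattern applied here, sketched with a hypothetical nested dict standing in for `self.tokenizer_ensemble_heads` (the lambdas are placeholders for the real embed modules):

```python
# Hypothetical structure mirroring tokenizer_ensemble_heads[ilevel][tkhead_name]
tokenizer_ensemble_heads = [
    {"head_a": {"embed": [lambda x: x + 10], "embed_cond": [lambda x: x + 20]}}
]

# Before: tkhead_name, ilevel, and conditioning are all needed to look up the module.
def get_structured_sequence_flag(x, embed_index, tkhead_name, ilevel=0, conditioning=False):
    key = "embed_cond" if conditioning else "embed"
    return tokenizer_ensemble_heads[ilevel][tkhead_name][key][embed_index](x)

# After: the caller resolves the module once and passes it in.
def get_structured_sequence(x, embed):
    return embed(x)

embed = tokenizer_ensemble_heads[0]["head_a"]["embed_cond"][0]
assert get_structured_sequence(5, embed) == get_structured_sequence_flag(
    5, 0, "head_a", conditioning=True)
```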

Collaborator Author


Done.

Collaborator Author

aprokop commented Dec 2, 2025

Also, could you run some small training runs as a sanity check? I wonder if it's time for us to write the belated CI tests.

Working on it, will run together with #9.

Collaborator Author

aprokop commented Dec 2, 2025

Am able to run; need to fix a couple of things.

@pzhanggit
Collaborator

Am able to run; need to fix a couple of things.

Thanks, Andrey. Are the loss histories consistent? Let me know when it's ready for review.

@aprokop aprokop requested a review from TsChala December 16, 2025 19:23
Collaborator Author

aprokop commented Dec 16, 2025

Running using the Flow3D dataset (on a 48^3 patch).

One caveat: when I say "this branch with n_states_cond=0", it requires a minor change to this branch (as I can only test it together with #9):

--- a/matey/train.py
+++ b/matey/train.py
@@ -462,7 +462,7 @@ class Trainer:
                 with record_function_opt("model forward", enabled=self.profiling):
                     output= self.model(inp, field_labels, bcs, imod=imod,
                                     sequence_parallel_group=self.current_group, leadtime=leadtime,
-                                    refineind=refineind, tkhead_name=tkhead_name, blockdict=blockdict) #, cond_dict=cond_dict)
+                                    refineind=refineind, tkhead_name=tkhead_name, blockdict=blockdict, cond_dict=cond_dict)
                 ###full resolution###
                 spatial_dims = tuple(range(output.ndim))[2:] # B,C,D,H,W
                 residuals = output - tar
@@ -603,7 +603,7 @@ class Trainer:
                     imod = self.params.hierarchical["nlevels"]-1 if hasattr(self.params, "hierarchical") else 0
                     output= self.model(inp, field_labels, bcs, imod=imod,
                                        sequence_parallel_group=self.current_group, leadtime=leadtime,
-                                       refineind=refineind, tkhead_name=tkhead_name, blockdict=blockdict) #, cond_dict=cond_dict)
+                                       refineind=refineind, tkhead_name=tkhead_name, blockdict=blockdict, cond_dict=cond_dict)
                     #################################
                     ###full resolution###
                     spatial_dims = tuple(range(output.ndim))[2:]

VIT

main

Train loss: 0.6939765214920044. Valid loss: 0.486361563205719
Train loss: 0.4495169520378113. Valid loss: 0.44182640314102173
Train loss: 0.3927997946739197. Valid loss: 0.3793710172176361
Train loss: 0.3674124479293823. Valid loss: 0.35462433099746704
Train loss: 0.33511653542518616. Valid loss: 0.33494308590888977
Train loss: 0.31036806106567383. Valid loss: 0.32899996638298035
Train loss: 0.3129328191280365. Valid loss: 0.31377172470092773
Train loss: 0.2922043800354004. Valid loss: 0.31285133957862854
Train loss: 0.2900574505329132. Valid loss: 0.302879273891449
Train loss: 0.28926289081573486. Valid loss: 0.3068097233772278

This branch (with n_states_cond = 0)

Train loss: 0.6861850023269653. Valid loss: 0.4853196144104004
Train loss: 0.4494728147983551. Valid loss: 0.42999234795570374
Train loss: 0.3956623673439026. Valid loss: 0.38326168060302734
Train loss: 0.368410587310791. Valid loss: 0.35956692695617676
Train loss: 0.3387361764907837. Valid loss: 0.3376031517982483
Train loss: 0.3120647966861725. Valid loss: 0.32652899622917175
Train loss: 0.3160429298877716. Valid loss: 0.3157729506492615
Train loss: 0.29402101039886475. Valid loss: 0.31497421860694885
Train loss: 0.2919299304485321. Valid loss: 0.3046448528766632
Train loss: 0.29090601205825806. Valid loss: 0.30892127752304077

AViT

main

Train loss: 0.6840741038322449. Valid loss: 0.4846583604812622
Train loss: 0.4476206600666046. Valid loss: 0.4347662031650543
Train loss: 0.3944677710533142. Valid loss: 0.382155179977417
Train loss: 0.36767876148223877. Valid loss: 0.35942336916923523
Train loss: 0.3419545590877533. Valid loss: 0.3392013609409332
Train loss: 0.31563523411750793. Valid loss: 0.3365303874015808
Train loss: 0.31790655851364136. Valid loss: 0.31734612584114075
Train loss: 0.29747945070266724. Valid loss: 0.31693828105926514
Train loss: 0.2959226965904236. Valid loss: 0.30728742480278015
Train loss: 0.29511934518814087. Valid loss: 0.31103259325027466

This branch (with n_states_cond = 0)

Train loss: 0.6979491710662842. Valid loss: 0.48853570222854614
Train loss: 0.4496990442276001. Valid loss: 0.435271680355072
Train loss: 0.39476752281188965. Valid loss: 0.38508138060569763
Train loss: 0.36811399459838867. Valid loss: 0.3562987148761749
Train loss: 0.34042227268218994. Valid loss: 0.33719101548194885
Train loss: 0.3142472803592682. Valid loss: 0.3359690308570862
Train loss: 0.31439208984375. Valid loss: 0.31662094593048096
Train loss: 0.2964993715286255. Valid loss: 0.3158888816833496
Train loss: 0.29508739709854126. Valid loss: 0.30576395988464355
Train loss: 0.29404446482658386. Valid loss: 0.3096076250076294
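As a quick numerical check that n_states_cond=0 reproduces main up to run-to-run noise, the VIT validation losses above stay within about 0.012 of each other (a small sanity-check sketch; the lists are copied from the runs above, rounded to five digits):

```python
# Valid losses copied from the VIT runs above (main vs. this branch).
main_valid = [0.48636, 0.44183, 0.37937, 0.35462, 0.33494,
              0.32900, 0.31377, 0.31285, 0.30288, 0.30681]
branch_valid = [0.48532, 0.42999, 0.38326, 0.35957, 0.33760,
                0.32653, 0.31577, 0.31497, 0.30464, 0.30892]

# Largest pointwise gap between the two loss histories.
max_diff = max(abs(a - b) for a, b in zip(main_valid, branch_valid))
assert max_diff < 0.02  # tolerance is a judgment call for stochastic training
```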

Collaborator

@TsChala TsChala left a comment


I went through this; looks good to me.

@pzhanggit pzhanggit merged commit 8ff7a8b into main Jan 7, 2026
@aprokop aprokop mentioned this pull request Jan 8, 2026
@aprokop aprokop deleted the xap/add_conditioning_to_models branch January 11, 2026 20:53
