Conversation
Force-pushed d168f99 to be4603d
Need to do a couple of test runs to make sure.
pzhanggit
left a comment
Thanks @aprokop. The PR looks good; see my comments. I realize we might have some redundant code that could be cleaned up now. Also, could you run some small training runs as a sanity check? I wonder if it's time for us to write the belated CI tests.
matey/models/basemodel.py
Outdated
# patch_ids: [npatches]  # selected token ids with sample pos inside batch considered
# t_pos_area: [B, T, ntoken_len_tot, 5]
"""
assert conditioning == False or refineind is None
Could we remove conditioning == False?
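A minimal sketch of the simplified guard, assuming the intent of the comment is to drop the explicit `== False` comparison (the function name `check_args` is a hypothetical stand-in; `conditioning` and `refineind` are the names from the diff above):

```python
def check_args(conditioning: bool = False, refineind=None):
    # Hypothetical stand-in for the guard in basemodel.py.
    # Equivalent to `conditioning == False or refineind is None`, written
    # without the `== False` comparison: refinement and conditioning are
    # mutually exclusive.
    assert not conditioning or refineind is None, \
        "refineind must be None when conditioning is enabled"
```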
matey/models/basemodel.py
Outdated
sys.exit(-1)

-    def get_unified_preembedding(self, x, state_labels, ilevel=0):
+    def get_unified_preembedding(self, x, state_labels, ilevel=0, conditioning: bool = False):
Would it be more concise to pass the module directly? Something like `(..., op=self.space_bag[ilevel], ...)` and `(..., op=self.space_bag_cond[ilevel], ...)`?
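A hedged sketch of the suggested refactor. The class, the lambda "modules", and the call sites below are placeholders standing in for the real `space_bag` / `space_bag_cond` attributes in `basemodel.py`; the point is only the shape of the API (caller selects the module, the method takes it as `op`):

```python
# Hypothetical sketch: instead of a `conditioning` flag that selects the
# module inside the method, the caller passes the module itself as `op`.
class Model:
    def __init__(self):
        # stand-ins for self.space_bag / self.space_bag_cond in basemodel.py
        self.space_bag = [lambda x: x + 1]
        self.space_bag_cond = [lambda x: x * 2]

    def get_unified_preembedding(self, x, op):
        # `op` is the embedding module chosen by the caller
        return op(x)

m = Model()
out_plain = m.get_unified_preembedding(3, op=m.space_bag[0])      # unconditioned path
out_cond = m.get_unified_preembedding(3, op=m.space_bag_cond[0])  # conditioned path
```

This keeps `get_unified_preembedding` agnostic of which bag exists, so adding further conditioning variants would not touch its signature.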
matey/models/basemodel.py
Outdated
return x

-    def get_structured_sequence(self, x, embed_index, tkhead_name, ilevel=0):
+    def get_structured_sequence(self, x, embed_index, tkhead_name, ilevel=0, conditioning: bool = False):
Similarly here. Could we condense the three arguments `tkhead_name`, `ilevel=0`, and `conditioning: bool = False` into one by directly passing `self.tokenizer_ensemble_heads[ilevel][tkhead_name]["embed"][embed_index]` or `self.tokenizer_ensemble_heads[ilevel][tkhead_name]["embed_cond"][embed_index]`?
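A minimal sketch of this second condensation, with a toy stand-in for the nested head dict (the structure and key names mirror the comment above; the lambdas and values are purely illustrative, not the real modules in `basemodel.py`):

```python
# Hypothetical stand-in for self.tokenizer_ensemble_heads:
# a list over levels, of dicts keyed by head name, each holding
# "embed" and "embed_cond" lists indexed by embed_index.
tokenizer_ensemble_heads = [
    {"head_a": {"embed": [lambda x: x + 10],
                "embed_cond": [lambda x: x - 10]}}
]

def get_structured_sequence(x, embed):
    # `embed` replaces the (tkhead_name, ilevel, conditioning, embed_index)
    # tuple: the caller has already resolved the module.
    return embed(x)

ilevel, tkhead_name, embed_index = 0, "head_a", 0
key = "embed_cond"  # or "embed" when conditioning is off
embed = tokenizer_ensemble_heads[ilevel][tkhead_name][key][embed_index]
out = get_structured_sequence(5, embed)
```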
Working on it, will run together with #9.

Am able to run; need to fix a couple of things.
Thanks, Andrey. Are the loss histories consistent? Let me know when it's ready for review.
Running with the Flow3D dataset (on a 48^3 patch). One caveat: when I say "this branch with n_states_cond=0", it requires a minor change to this branch (as I can only test it together with #9):

```diff
--- a/matey/train.py
+++ b/matey/train.py
@@ -462,7 +462,7 @@ class Trainer:
 with record_function_opt("model forward", enabled=self.profiling):
     output = self.model(inp, field_labels, bcs, imod=imod,
                         sequence_parallel_group=self.current_group, leadtime=leadtime,
-                        refineind=refineind, tkhead_name=tkhead_name, blockdict=blockdict) #, cond_dict=cond_dict)
+                        refineind=refineind, tkhead_name=tkhead_name, blockdict=blockdict, cond_dict=cond_dict)
 ###full resolution###
 spatial_dims = tuple(range(output.ndim))[2:]  # B,C,D,H,W
 residuals = output - tar
@@ -603,7 +603,7 @@ class Trainer:
 imod = self.params.hierarchical["nlevels"]-1 if hasattr(self.params, "hierarchical") else 0
 output = self.model(inp, field_labels, bcs, imod=imod,
                     sequence_parallel_group=self.current_group, leadtime=leadtime,
-                    refineind=refineind, tkhead_name=tkhead_name, blockdict=blockdict) #, cond_dict=cond_dict)
+                    refineind=refineind, tkhead_name=tkhead_name, blockdict=blockdict, cond_dict=cond_dict)
 #################################
 ###full resolution###
 spatial_dims = tuple(range(output.ndim))[2:]
```
TsChala
left a comment
I went through this; it looks good to me.