-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add multigpu support & UI to test embeddings #109
Conversation
srmsoumya
commented
Jan 4, 2024
- Update MEAN & STD for 10% of the data used for CLAY v0 model training
- Add streamlit UI with support for vector search & arithmetic
- Fix issue with tensors on different cuda devices when doing multi-gpu training
UI supports vector search & arithmetic for embeddings generated from CLAY. --------- Co-authored-by: SRM <soumya@developmentseed.org>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy New Year @srmsoumya! Just some comments to help you pass the CI tests and linter warnings, plus a few optional recommendations/notes.
# LearningRateMonitor(logging_interval="step"), | ||
# LogIntermediatePredictions(), | ||
LearningRateMonitor(logging_interval="step"), | ||
LogIntermediatePredictions(), | ||
], | ||
"logger": False, # WandbLogger(project="CLAY-v0", log_model=False) | ||
"logger": [WandbLogger(project="CLAY-v0", log_model=False)], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok to add these loggers for now, but I'd like to disable them by default so that new users don't need to install and setup wandb
and such on the first run. We could potentially use LightningCLI's YAML config files - https://lightning.ai/docs/pytorch/2.0.2/cli/lightning_cli_advanced.html to store some of these advanced configurations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another reason is that turning on the WandbLogger by default here also means wandb
is invoked during inference/prediction, and we would need to do something like WANDB_MODE=disabled python trainer.py predict ...
to not use wandb at all.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, just one more thing to fix and should be good to merge!
clayground.py
Outdated
filter = " OR ".join( | ||
[ | ||
f"(tile == '{chip['tile']}' AND idx == '{chip['idx']}') AND year == {chip['year']}" | ||
for chip in chips | ||
] | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just need to wrap L27 to <88 characters to make the linter happy.
filter = " OR ".join( | |
[ | |
f"(tile == '{chip['tile']}' AND idx == '{chip['idx']}') AND year == {chip['year']}" | |
for chip in chips | |
] | |
) | |
filter = " OR ".join( | |
[ | |
f"(tile == '{chip['tile']}' " | |
f"AND idx == '{chip['idx']}') " | |
f"AND year == {chip['year']}" | |
for chip in chips | |
] | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By the way @srmsoumya, filter
is a built-in Python function, so best to avoid it. Can we change the variable name to something else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I just went ahead and renamed filter
to tile_filter
at commit f4731c6
def connect_to_database(): | ||
db = lancedb.connect("nbs/embeddings") | ||
tbl = db.open_table("clay-v001") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will need to document how the `nbs/embeddings' LanceDB vector database is created. Can do this afterwards.
Avoid the use of a variable name that is the same as Python's built-in filter function. Also wrap the where clause statement to under 88 characters.