Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multigpu support & UI to test embeddings #109

Merged
merged 12 commits into from
Jan 11, 2024
Merged

Add multigpu support & UI to test embeddings #109

merged 12 commits into from
Jan 11, 2024

Conversation

srmsoumya
Copy link
Collaborator

  • Update MEAN & STD for 10% of the data used for CLAY v0 model training
  • Add streamlit UI with support for vector search & arithmetic
  • Fix issue with tensors on different cuda devices when doing multi-gpu training

@srmsoumya srmsoumya requested a review from weiji14 January 5, 2024 05:56
Copy link
Contributor

@weiji14 weiji14 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy New Year @srmsoumya! Just some comments to help you pass the CI tests and linter warnings, plus a few optional recommendations/notes.

src/datamodule.py Outdated Show resolved Hide resolved
clayground.py Show resolved Hide resolved
trainer.py Outdated Show resolved Hide resolved
src/model_clay.py Outdated Show resolved Hide resolved
Comment on lines -46 to +56
# LearningRateMonitor(logging_interval="step"),
# LogIntermediatePredictions(),
LearningRateMonitor(logging_interval="step"),
LogIntermediatePredictions(),
],
"logger": False, # WandbLogger(project="CLAY-v0", log_model=False)
"logger": [WandbLogger(project="CLAY-v0", log_model=False)],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok to add these loggers for now, but I'd like to disable them by default so that new users don't need to install and setup wandb and such on the first run. We could potentially use LightningCLI's YAML config files - https://lightning.ai/docs/pytorch/2.0.2/cli/lightning_cli_advanced.html to store some of these advanced configurations.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another reason is that turning on the WandbLogger by default here also means wandb is invoked during inference/prediction, and we would need to do something like WANDB_MODE=disabled python trainer.py predict ... to not use wandb at all.

src/model_clay.py Outdated Show resolved Hide resolved
Copy link
Contributor

@weiji14 weiji14 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, just one more thing to fix and should be good to merge!

clayground.py Outdated
Comment on lines 25 to 30
filter = " OR ".join(
[
f"(tile == '{chip['tile']}' AND idx == '{chip['idx']}') AND year == {chip['year']}"
for chip in chips
]
)
Copy link
Contributor

@weiji14 weiji14 Jan 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just need to wrap L27 to <88 characters to make the linter happy.

Suggested change
filter = " OR ".join(
[
f"(tile == '{chip['tile']}' AND idx == '{chip['idx']}') AND year == {chip['year']}"
for chip in chips
]
)
filter = " OR ".join(
[
f"(tile == '{chip['tile']}' "
f"AND idx == '{chip['idx']}') "
f"AND year == {chip['year']}"
for chip in chips
]
)

Copy link
Contributor

@weiji14 weiji14 Jan 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way @srmsoumya, filter is a built-in Python function, so best to avoid it. Can we change the variable name to something else?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I just went ahead and renamed filter to tile_filter at commit f4731c6

Comment on lines +37 to +39
def connect_to_database():
db = lancedb.connect("nbs/embeddings")
tbl = db.open_table("clay-v001")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will need to document how the `nbs/embeddings' LanceDB vector database is created. Can do this afterwards.

Avoid the use of a variable name that is the same as Python's built-in filter function. Also wrap the where clause statement to under 88 characters.
@weiji14 weiji14 added the model-architecture Pull requests about the neural network model architecture label Jan 11, 2024
@weiji14 weiji14 added this to the v0 Release milestone Jan 11, 2024
@weiji14 weiji14 enabled auto-merge (squash) January 11, 2024 21:55
@weiji14 weiji14 merged commit de1556b into main Jan 11, 2024
2 checks passed
@weiji14 weiji14 deleted the ddp branch January 11, 2024 21:56
@weiji14 weiji14 mentioned this pull request Jan 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
model-architecture Pull requests about the neural network model architecture
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants