
Allow extra_epochs flag in Trainer.fit to control finetuning time #13273

Open
franchesoni opened this issue Jun 12, 2022 · 7 comments
Labels: question (Further information is requested), trainer: argument

Comments

franchesoni commented Jun 12, 2022

🚀 Feature

Trainer(max_epochs=100).fit(model, train_dl, ckpt_path=ckpt_path, extra_epochs=True) would finetune for 100 epochs

Motivation

Finetuning for N additional epochs currently requires knowing the previous number of epochs M and setting Trainer(max_epochs=M+N). Searching did not turn up a way around this.

Pitch

Finetuning training time or number of epochs should be configurable.

Alternatives

Setting a very large max_epochs and stopping training manually.

Additional context

It would be nice to have the same for max_time too. I hope this is already solved and this issue is unnecessary.
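
For example (hypothetical API, mirroring the proposed extra_epochs flag above; extra_time does not exist today):

Trainer(max_time="00:12:00:00").fit(model, train_dl, ckpt_path=ckpt_path, extra_time=True) would finetune for 12 more hours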


cc @justusschock @kaushikb11 @awaelchli @Borda @rohitgr7

franchesoni added the needs triage (Waiting to be triaged by maintainers) label on Jun 12, 2022
carmocca (Contributor) commented

You can accomplish this by doing:

trainer.fit_loop.max_epochs += 100

before trainer.fit() is called.

carmocca added the question (Further information is requested) and trainer: argument labels and removed the needs triage (Waiting to be triaged by maintainers) label on Jun 14, 2022
franchesoni (Author) commented Jun 19, 2022

If this worked, it would be very counter-intuitive, because the current number of epochs is only known after calling .fit(..., ckpt_path=ckpt_path).
I think you are assuming I'm loading the model first, which is not the case, since I'm using the ckpt_path argument.

When I try your solution:

trainer = pl.Trainer(**trainer_params)
trainer.fit_loop.max_epochs += 2
trainer.fit(model, train_dl, val_dl, ckpt_path=best_ckpt)

I don't get the desired behavior.

cc @justusschock @kaushikb11 @awaelchli @Borda @rohitgr7

carmocca reopened this on Jun 22, 2022
franchesoni (Author) commented

Hello, is there any further news or an alternative answer?

awaelchli (Member) commented Jul 24, 2022

@franchesoni If I understand correctly, you are saying this is not an option for you?

model = Model.load_from_checkpoint("path/to/pretrained/checkpoint.ckpt")
trainer = pl.Trainer(**trainer_params, max_epochs=N) 
trainer.fit(model, train_dl, val_dl)

I assume it's because you want some parts of the trainer state restored from the checkpoint, e.g. the optimizer state, but not the full loop state.

Then I think this is just another version of request #5339, to be able to control what gets restored. I think this is something we need to start adding to the roadmap and think hard about.

carmocca (Contributor) commented

There are 2 potential solutions:

  1. Pre-load the checkpoint manually:

ckpt = torch.load(...)
current_epoch = ckpt["current_epoch"]
trainer = Trainer(max_epochs=current_epoch + N)

An issue with this method is that it loads the full checkpoint just for this change. This relates to #5339 and #12712.

  2. Extract the state from the checkpoint in on_load_checkpoint and modify the Trainer's max_epochs. This requires editing the LightningModule hook to do this or creating a Callback just for it (see the sketch below).
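
A rough, untested sketch of option 2 as a standalone Callback. Assumptions: a recent Lightning version where Callback.on_load_checkpoint receives the full checkpoint dict, the finished epoch count lives under the "epoch" key, and trainer.fit_loop.max_epochs can still be adjusted at that point:

import pytorch_lightning as pl

class ExtendMaxEpochs(pl.Callback):
    # Hypothetical helper: extend the epoch budget relative to the resumed checkpoint.
    def __init__(self, extra_epochs: int):
        self.extra_epochs = extra_epochs

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # Assumption: the checkpoint dict stores the last epoch under "epoch";
        # adjust the key if your checkpoint stores it differently.
        trainer.fit_loop.max_epochs = checkpoint["epoch"] + self.extra_epochs

# Usage sketch: finetune for N more epochs on top of whatever the checkpoint already ran.
# trainer = pl.Trainer(callbacks=[ExtendMaxEpochs(extra_epochs=N)], **trainer_params)
# trainer.fit(model, train_dl, val_dl, ckpt_path=best_ckpt)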

stale bot commented Apr 15, 2023

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

stale bot added the won't fix (This will not be worked on) label on Apr 15, 2023
JIAOJIAYUASD commented Jul 26, 2023


Hello, I have an image inpainting project, Paint-by-Example, implemented in pytorch_lightning. I want to finetune the Stable Diffusion model using LoRA, but I can't find the model definition and don't know how to add the LoRA finetuning process to the project. Can you give me some advice?

stale bot removed the won't fix (This will not be worked on) label on Jul 26, 2023