This repository has been archived by the owner on May 14, 2024. It is now read-only.

Resume training not working, and a simple solution #322

Open
TsGrolken opened this issue Dec 16, 2023 · 0 comments


When resuming training from a checkpoint, the script prints:
load network weights from /content/drive/MyDrive/XXXXX.safetensors: None

I looked briefly at the code and found that it always passes 'FALSE' for the 'dtype' parameter in the 'load_weights' function, whereas the correct 'dtype' is always supplied when saving the checkpoint. There is a simple fix for it:

In the file 'kohya-trainer/train_network.py', starting from line 206, replace

    if args.network_weights is not None:
        info = network.load_weights(args.network_weights)
        print(f"load network weights from {args.network_weights}: {info}")

with:

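    # Relies on `os` and `torch` already being imported in train_network.py.
    def load_weights_2(network, file, dtype):
        # Pick the loader from the file extension.
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        # Pass `dtype` as the second argument instead of the hard-coded False.
        info = network.load_state_dict(weights_sd, dtype)
        return info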

    if args.network_weights is not None:
        info = load_weights_2(network, args.network_weights, save_dtype)
        print(f"load network weights from {args.network_weights}: {info}")

This should work for LoRA, LoCon and LoHa; I am not sure whether it works for XL. The change is easy to make in Google Colab's editing mode.
I hope it gets fixed officially soon.
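
One thing that may be worth verifying before relying on the fix: on a plain `torch.nn.Module`, the second positional parameter of `load_state_dict` is `strict`, so a value passed in that position is read as the strict flag unless the network class overrides the method. The sketch below is only an illustration under that assumption. `describe_load_result` and `load_weights_checked` are hypothetical names that do not exist in kohya-trainer. It passes `strict` by keyword and turns the returned value into a readable summary, which also makes a bare `None` result easier to interpret.

```python
def describe_load_result(info):
    """Hypothetical helper: summarize what a load_state_dict-style call returned."""
    if info is None:
        # The loader returned nothing, so no key report is available.
        return "loader returned None (no key report available)"
    # torch.nn.Module.load_state_dict returns an object with these two fields.
    missing = list(getattr(info, "missing_keys", []))
    unexpected = list(getattr(info, "unexpected_keys", []))
    if not missing and not unexpected:
        return "all keys matched"
    return f"{len(missing)} missing, {len(unexpected)} unexpected keys"


def load_weights_checked(network, file, strict=False):
    """Hypothetical variant of the replacement loader with an explicit `strict` keyword."""
    import os
    import torch  # imported lazily so describe_load_result can be used without torch

    # Pick the loader from the file extension, as in the snippet above.
    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import load_file

        weights_sd = load_file(file)
    else:
        weights_sd = torch.load(file, map_location="cpu")

    # Naming the argument avoids any ambiguity about what the second position means.
    return network.load_state_dict(weights_sd, strict=strict)
```

With this in place, the print statement could report `describe_load_result(info)` instead of the raw value, so that missing or unexpected keys are visible when a resume does not behave as expected.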
