RuntimeError: Error(s) in loading state_dict for CLAP: Unexpected key(s) in state_dict: "text_branch.embeddings.position_ids" #127

Closed
rjac-ml opened this issue Sep 25, 2023 · 4 comments


@rjac-ml

rjac-ml commented Sep 25, 2023

I was running the following code in Colab:

import laion_clap
import glob
import json
import torch
import numpy as np

device = torch.device('cuda:0')

# download https://drive.google.com/drive/folders/1scyH43eQAcrBz-5fAw44C6RNBhC3ejvX?usp=sharing and extract ./ESC50_1/test/0.tar to ./ESC50_1/test/
esc50_test_dir = './ESC50_1/test/*/'
class_index_dict_path = './class_labels/ESC50_class_labels_indices_space.json'

# Load the model
model = laion_clap.CLAP_Module(enable_fusion=False, device=device)
model.load_ckpt()

# Get the class index dict
class_index_dict = {v: k for v, k in json.load(open(class_index_dict_path)).items()}

# Get all the data
audio_files = sorted(glob.glob(esc50_test_dir + '**/*.flac', recursive=True))
json_files = sorted(glob.glob(esc50_test_dir + '**/*.json', recursive=True))
ground_truth_idx = [class_index_dict[json.load(open(jf))['tag'][0]] for jf in json_files]

with torch.no_grad():
    ground_truth = torch.tensor(ground_truth_idx).view(-1, 1)

    # Get text features
    all_texts = ["This is a sound of " + t for t in class_index_dict.keys()]
    text_embed = model.get_text_embedding(all_texts)
    audio_embed = model.get_audio_embedding_from_filelist(x=audio_files)

    ranking = torch.argsort(torch.tensor(audio_embed) @ torch.tensor(text_embed).t(), descending=True)
    preds = torch.where(ranking == ground_truth)[1]
    preds = preds.cpu().numpy()

    metrics = {}
    metrics[f"mean_rank"] = preds.mean() + 1
    metrics[f"median_rank"] = np.floor(np.median(preds)) + 1
    for k in [1, 5, 10]:
        metrics[f"R@{k}"] = np.mean(preds < k)
    # map@10
    metrics[f"mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))

    print(
        f"Zeroshot Classification Results: "
        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    )

I got the error at model.load_ckpt():

python==3.10
torch==2.0.1+cu118

error

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Load our best checkpoint in the paper.
The checkpoint is already downloaded
Load Checkpoint...
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-53d44b376007> in <cell line: 16>()
     14 # Load the model
     15 model = laion_clap.CLAP_Module(enable_fusion=False, device=device)
---> 16 model.load_ckpt()
     17 
     18 # Get the class index dict

1 frames
/usr/local/lib/python3.10/dist-packages/laion_clap/hook.py in load_ckpt(self, ckpt, model_id)
    112         print('Load Checkpoint...')
    113         ckpt = load_state_dict(ckpt, skip_params=True)
--> 114         self.model.load_state_dict(ckpt)
    115         param_names = [n for n, p in self.model.named_parameters()]
    116         for n in param_names:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   2039 
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for CLAP:
	Unexpected key(s) in state_dict: "text_branch.embeddings.position_ids".
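For context, torch.nn.Module.load_state_dict runs with strict=True by default, so any key that is present in the checkpoint but not exposed by the freshly built model is reported as unexpected, and loading aborts. The sketch below mimics that comparison on plain key lists. It is a simplified illustration, not PyTorch's actual implementation; compare_keys is a hypothetical helper, and every key name except text_branch.embeddings.position_ids is made up.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Mimic a strict key check and return (missing, unexpected) as sorted lists."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)      # expected by the model, absent from the file
    unexpected = sorted(ckpt_set - model_set)   # present in the file, unknown to the model
    return missing, unexpected


# Hypothetical key lists: the checkpoint carries one extra entry the model does not expose.
model_keys = ["text_branch.embeddings.word_embeddings.weight"]
checkpoint_keys = model_keys + ["text_branch.embeddings.position_ids"]

missing, unexpected = compare_keys(model_keys, checkpoint_keys)
if missing or unexpected:
    print(f"strict loading would fail: missing={missing}, unexpected={unexpected}")
```

With strict=False, PyTorch would instead return the mismatched names (as missing_keys and unexpected_keys) rather than raising. Here, however, the strict call happens inside laion_clap's load_ckpt, so the thread's workarounds change the environment or the checkpoint instead.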
@fatchord

fatchord commented Sep 27, 2023

I get the same error with this code:

import laion_clap
from pathlib import Path

audiopaths = list(Path("test_wavs/").glob("*.wav"))

model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base")
model.load_ckpt("checkpoints/music_speech_epoch_15_esc_89.25.pt")
audio_embed = model.get_audio_embedding_from_filelist(x=audiopaths, use_tensor=False)
print(audio_embed[:, -20:])
print(audio_embed.shape)

versions:

python==3.10
torch==2.0.1+cu117

error:

/home/fc/miniconda3/envs/clap/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Load the specified checkpoint checkpoints/music_speech_epoch_15_esc_89.25.pt from users.
Load Checkpoint...
Traceback (most recent call last):
  File "/home/ollie/CLAP/test.py", line 7, in <module>
    model.load_ckpt("checkpoints/music_speech_epoch_15_esc_89.25.pt")
  File "/home/ollie/miniconda3/envs/clap/lib/python3.10/site-packages/laion_clap/hook.py", line 114, in load_ckpt
    self.model.load_state_dict(ckpt)
  File "/home/ollie/miniconda3/envs/clap/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CLAP:
        Unexpected key(s) in state_dict: "text_branch.embeddings.position_ids".

@fatchord

Ah, I see, it's related to #118.
I rolled back transformers to v4.30.0 and it seems to be working OK now.
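If you take the rollback route (for example pip install "transformers==4.30.0"), a small guard can fail fast with a clear message instead of erroring inside load_ckpt(). This is a sketch under assumptions: 4.31.0 is assumed to be the first release that triggers the error (the comment above only confirms that 4.30.0 works), and release_tuple and transformers_too_new are hypothetical helper names. It only needs the standard library's importlib.metadata.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Assumption: first transformers release that triggers the unexpected-key error.
FIRST_BAD_TRANSFORMERS = "4.31.0"


def release_tuple(ver: str) -> tuple[int, ...]:
    """Turn '4.30.0', '4.31.0rc1' or '4.30.0.dev0' into a 3-int tuple such as (4, 30, 0)."""
    parts = []
    for piece in ver.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at purely non-numeric segments such as 'dev0'
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # stop after a segment carrying a suffix such as '0rc1'
    parts += [0] * (3 - len(parts))  # pad short versions: '4.31' -> (4, 31, 0)
    return tuple(parts[:3])


def transformers_too_new(installed: str, first_bad: str = FIRST_BAD_TRANSFORMERS) -> bool:
    """True if the installed version is at or above the assumed first problematic release."""
    return release_tuple(installed) >= release_tuple(first_bad)


if __name__ == "__main__":
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        installed = None  # transformers not installed; nothing to check
    if installed is not None and transformers_too_new(installed):
        print(f"transformers {installed} may trigger the position_ids error; "
              "consider pinning 4.30.0 before calling load_ckpt()")
```

Parsing only the leading numeric components keeps the sketch dependency-free. If the packaging library is available, packaging.version.Version is the more robust choice.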

@flu0r1ne

Here is a quick and dirty script to remove the offending key from the checkpoint:

import argparse
import os
import torch

OFFENDING_KEY = "module.text_branch.embeddings.position_ids"

def main(args):
    # Load the checkpoint from the given path
    checkpoint = torch.load(
        args.input_checkpoint, map_location="cpu"
    )

    # Extract the state_dict from the checkpoint
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    # Delete the specific key from the state_dict
    if OFFENDING_KEY in state_dict:
        del state_dict[OFFENDING_KEY]

    # Save the modified state_dict back to the checkpoint
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint["state_dict"] = state_dict

    # Build the output filename by swapping the final extension for ".patched.pt"
    # (os.path.splitext only touches the last extension, unlike str.replace)
    root, _ = os.path.splitext(args.input_checkpoint)
    output_checkpoint_path = root + '.patched.pt'

    # Save the modified checkpoint
    torch.save(checkpoint, output_checkpoint_path)
    print(f"Saved patched checkpoint to {output_checkpoint_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Patch a PyTorch checkpoint by removing a specific key.')
    parser.add_argument('input_checkpoint', type=str, help='Path to the input PyTorch checkpoint (.pt) file.')

    try:
        import argcomplete
        argcomplete.autocomplete(parser)
    except ImportError:
        pass

    args = parser.parse_args()
    main(args)
python remove_key.py 630k-audioset-fusion-best.pt

Then, use 630k-audioset-fusion-best.patched.pt in load_ckpt. I don't know if the text_branch.embeddings.position_ids parameters are used elsewhere. I've verified this works with the basic examples.
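The same idea can be packaged as a function that works on any state-dict mapping, with or without the module. prefix seen in OFFENDING_KEY (the prefix DistributedDataParallel adds when a wrapped model's state_dict is saved). This is a sketch: strip_position_ids is a hypothetical name, it matches any key ending in text_branch.embeddings.position_ids rather than one hard-coded key, and the usage below uses placeholder values instead of real tensors.

```python
from collections.abc import MutableMapping

SUFFIX = "text_branch.embeddings.position_ids"


def strip_position_ids(state_dict: MutableMapping) -> list[str]:
    """Delete every key equal to SUFFIX or ending in '.' + SUFFIX; return the removed keys."""
    # Collect the keys first so the mapping is not mutated while iterating over it.
    removed = [k for k in state_dict if k == SUFFIX or k.endswith("." + SUFFIX)]
    for key in removed:
        del state_dict[key]
    return removed


# Placeholder values; with a real checkpoint these would be tensors from torch.load(...).
ckpt = {
    "module.text_branch.embeddings.position_ids": None,
    "module.text_branch.embeddings.word_embeddings.weight": None,
}
print(strip_position_ids(ckpt))
```

In the script above, you would call it on the extracted state_dict in place of the single del, then save the checkpoint with torch.save as before.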

@RetroCirce
Contributor

This problem is fixed! We pushed laion-clap 1.1.6 to PyPI to fix some bugs: https://pypi.org/project/laion-clap/
