ADD japanese-reranker-cross-encoder-large-v1 #1443

Open
kyakuno opened this issue Apr 4, 2024 · 1 comment

@kyakuno
Collaborator

kyakuno commented Apr 4, 2024

Model: https://huggingface.co/hotchpotch/japanese-reranker-cross-encoder-large-v1
How to use the reranker: https://note.com/npaka/n/n906b23636ac8?sub_rt=share_h
Reranker overview: https://secon.dev/entry/2024/04/02/070000-japanese-reranker-release/
About cross-encoders: https://qiita.com/warper/items/fd84e740e62ad1a67703
License: MIT
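A cross-encoder reranker such as this model reads each (query, passage) pair jointly and produces one relevance logit per pair; the candidate passages are then sorted by that score. The sketch below illustrates only the ranking step, assuming the logit is mapped to a score with a sigmoid (as the `Exp` wrapper in the next comment does). `logit_fn` is a hypothetical stand-in for the real model forward pass, and `rerank` and `sigmoid` are illustrative names, not part of any library.

```python
import math
from typing import Callable, List, Sequence, Tuple


def sigmoid(x: float) -> float:
    """Map a raw relevance logit to a score in [0, 1] without overflowing."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def rerank(
    query: str,
    passages: Sequence[str],
    logit_fn: Callable[[str, str], float],
) -> List[Tuple[str, float]]:
    """Score every (query, passage) pair and return passages best-first.

    `logit_fn` is a hypothetical placeholder for the cross-encoder: it must
    return one raw logit for the given pair.
    """
    scored = [(p, sigmoid(logit_fn(query, p))) for p in passages]
    # sorted() is stable, so passages with equal scores keep their input order.
    return sorted(scored, key=lambda item: item[1], reverse=True)


if __name__ == "__main__":
    # Toy scorer for illustration only: counts query characters present in the passage.
    def toy_logit(q: str, p: str) -> float:
        return float(sum(ch in p for ch in q)) - 2.0

    for passage, score in rerank("cat", ["dog", "a cat", "act"], toy_logit):
        print(f"{score:.3f}  {passage}")
```

Because the sigmoid is monotonic, sorting by score gives the same order as sorting by raw logit; the sigmoid only makes the scores comparable on a fixed [0, 1] scale.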

@ooe1123
Contributor

ooe1123 commented May 27, 2024

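import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = "hotchpotch/japanese-reranker-cross-encoder-large-v1"


class Exp(nn.Module):
    """Wraps the cross-encoder so the exported graph outputs sigmoid scores."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.activation = nn.Sigmoid()

    def forward(self, input_ids, attention_mask, token_type_ids):
        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
        logits = self.model(**inputs).logits  # raw relevance logits
        scores = self.activation(logits)  # relevance scores in [0, 1]
        return scores


tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()

# Any (query, passage) pair works here; it is only used to trace the graph.
inputs = tokenizer(
    ["query"],
    ["passage"],
    padding=True,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    wrapped = Exp(model)
    # torch.autograd.Variable is deprecated; plain tensors can be passed directly.
    args = (
        inputs["input_ids"],
        inputs["attention_mask"],
        inputs["token_type_ids"],
    )
    torch.onnx.export(
        wrapped,
        args,
        "xxx.onnx",
        input_names=["input_ids", "attention_mask", "token_type_ids"],
        output_names=["scores"],
        # Axis 0 = batch (number of pairs), axis 1 = padded sequence length.
        dynamic_axes={
            "input_ids": [0, 1],
            "attention_mask": [0, 1],
            "token_type_ids": [0, 1],
            "scores": [0],
        },
        verbose=False,
        opset_version=17,
    )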
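The export marks axes 0 and 1 of every input as dynamic because a reranking batch varies both in the number of (query, passage) pairs and in the padded sequence length, while `scores` only varies along the batch axis. The sketch below shows, in plain Python, how BERT-style pair inputs are typically assembled, assuming a `[CLS] query [SEP] passage [SEP]` layout. The special-token ids `CLS_ID`, `SEP_ID`, and `PAD_ID` are hypothetical placeholders, and `encode_pairs` is an illustrative helper; the real tokenizer also handles subword splitting and truncation, which are omitted here.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical special-token ids for illustration only; read the real values
# from the tokenizer (e.g. tokenizer.cls_token_id) instead of hard-coding them.
CLS_ID, SEP_ID, PAD_ID = 2, 3, 0


def encode_pairs(
    pairs: Sequence[Tuple[List[int], List[int]]],
) -> Dict[str, List[List[int]]]:
    """Build padded pair inputs laid out as [CLS] query [SEP] passage [SEP].

    token_type_ids are 0 for the query segment and 1 for the passage segment;
    attention_mask is 1 for real tokens and 0 for padding.
    """
    if not pairs:
        raise ValueError("need at least one (query, passage) pair")

    rows = []
    for query_ids, passage_ids in pairs:
        ids = [CLS_ID] + query_ids + [SEP_ID] + passage_ids + [SEP_ID]
        types = [0] * (len(query_ids) + 2) + [1] * (len(passage_ids) + 1)
        rows.append((ids, types))

    # Axis 0 depends on how many pairs are in the batch and axis 1 on the
    # longest pair, which is why neither can be a fixed size in the ONNX graph.
    max_len = max(len(ids) for ids, _ in rows)
    batch: Dict[str, List[List[int]]] = {
        "input_ids": [],
        "attention_mask": [],
        "token_type_ids": [],
    }
    for ids, types in rows:
        pad = max_len - len(ids)
        batch["input_ids"].append(ids + [PAD_ID] * pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * pad)
        batch["token_type_ids"].append(types + [0] * pad)
    return batch
```

For example, `encode_pairs([([10, 11], [20]), ([12], [21, 22, 23])])` yields two rows of length 7, with the first row padded by one `PAD_ID` and masked out at that position.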


2 participants