# AnglE-optimized Text Embeddings

> It is Angle 📐, not Angel 👼.

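🔥 A ready-to-use text embedding library built on AnglE. It supports both Chinese and English and can be used for text similarity, retrieval and recall, matching, and similar tasks. The code is built on 🤗transformers and provides an easy-to-use fine-tuning interface: LLaMA-7B can be fine-tuned on consumer GPUs such as the 3090Ti and 4090, and multi-GPU distributed training is supported.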


<a href="https://arxiv.org/abs/2309.12871">
    <img src="https://img.shields.io/badge/Arxiv-2309.12871-yellow.svg?style=flat-square" alt="https://arxiv.org/abs/2309.12871" />
</a>
<a href="https://pypi.org/project/angle_emb/">
    <img src="https://img.shields.io/pypi/v/angle_emb?style=flat-square" alt="PyPI version" />
</a>

<a href="https://pypi.org/project/angle_emb/">
    <img src="https://img.shields.io/pypi/dm/angle_emb?style=flat-square" alt="PyPI Downloads" />
</a>
<a href="http://makeapullrequest.com">
    <img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square" alt="http://makeapullrequest.com" />
</a>

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sick-r-1)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sick-r-1?p=angle-optimized-text-embeddings)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts16)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sts16?p=angle-optimized-text-embeddings)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts15)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sts15?p=angle-optimized-text-embeddings)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts14)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sts14?p=angle-optimized-text-embeddings)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts13)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sts13?p=angle-optimized-text-embeddings)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts12)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sts12?p=angle-optimized-text-embeddings)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts-benchmark)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sts-benchmark?p=angle-optimized-text-embeddings)



If you use our code or pretrained models, please support us in three ways:
1) Star this project on GitHub
2) Paste the citation below into your paper's BibTeX file
3) Cite it in the body of your paper

```bibtex
@article{li2023angle,
  title={AnglE-Optimized Text Embeddings},
  author={Li, Xianming and Li, Jing},
  journal={arXiv preprint arXiv:2309.12871},
  year={2023}
}
```

# 1. Install dependencies

In [2]:
!pip install -U angle-emb

Defaulting to user installation because normal site-packages is not writeable
Collecting angle-emb
  Downloading angle_emb-0.1.1-py3-none-any.whl (12 kB)
Collecting transformers>=4.32.1
  Downloading transformers-4.34.1-py3-none-any.whl (7.7 MB)
     |████████████████████████████████| 7.7 MB 23.4 MB/s            
Collecting bitsandbytes
  Using cached bitsandbytes-0.41.1-py3-none-any.whl (92.6 MB)
Collecting tokenizers<0.15,>=0.14
  Downloading tokenizers-0.14.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
     |████████████████████████████████| 3.8 MB 70.0 MB/s            
Collecting huggingface-hub<1.0,>=0.16.4
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
     |████████████████████████████████| 301 kB 76.6 MB/s            
  Using cached huggingface_hub-0.17.3-py3-none-any.whl (295 kB)


Installing collected packages: huggingface-hub, tokenizers, transformers, bitsandbytes, angle-emb
  Attempting uninstall: huggingface-hub
    Found existing installation: huggingface-hub 0.15.1
    Uninstalling huggingface-hub-0.15.1:
      Successfully uninstalled huggingface-hub-0.15.1
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.13.3
    Uninstalling tokenizers-0.13.3:
      Successfully uninstalled tokenizers-0.13.3
  Attempting uninstall: transformers
    Found existing installation: transformers 4.29.2
    Uninstalling transformers-4.29.2:
      Successfully uninstalled transformers-4.29.2
Successfully installed angle-emb-0.1.1 bitsandbytes-0.41.1 huggingface-hub-0.17.3 tokenizers-0.14.1 transformers-4.34.1


In [1]:
import os
import random
import numpy as np
import torch


os.environ['CUDA_VISIBLE_DEVICES'] = '0'

seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fb2e68b4f50>

# 2. Load data

The data must be wrapped as a `datasets.Dataset` with three columns, `text1`, `text2`, and `label`; the `label` column must be numeric.

In [2]:
from datasets import load_dataset

ds = load_dataset('shibing624/nli_zh', 'ATEC')
ds = ds.rename_column('sentence1', 'text1')
ds = ds.rename_column('sentence2', 'text2')
ds = ds.select_columns(["text1", "text2", "label"])

Found cached dataset nli_zh (/home/jupyter-sean/.cache/huggingface/datasets/shibing624___nli_zh/ATEC/1.0.0/65b555276ee420c801e1c9eb830db959e37f42fa60c68c8b07a4448b8c436706)


  0%|          | 0/3 [00:00<?, ?it/s]
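
For data that is not already hosted as a dataset, the three-column layout described above can be prepared from plain Python tuples first. The following is a minimal sketch; the helper `to_columns` and the example sentences are hypothetical and not part of `angle_emb` or `datasets`.

```python
from numbers import Real

# Column names required by the training interface, as stated above.
REQUIRED_COLUMNS = ("text1", "text2", "label")


def to_columns(rows):
    """Turn (text1, text2, label) tuples into a column-oriented dict.

    Hypothetical helper: it only enforces the layout described above,
    i.e. three columns and a numeric label.
    """
    columns = {name: [] for name in REQUIRED_COLUMNS}
    for text1, text2, label in rows:
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(label, bool) or not isinstance(label, Real):
            raise TypeError(f"label must be numeric, got {label!r}")
        columns["text1"].append(str(text1))
        columns["text2"].append(str(text2))
        columns["label"].append(label)
    return columns


# Illustrative sentence pairs, not taken from any dataset.
rows = [
    ("how do I reset my password", "I forgot my password", 1),
    ("how do I reset my password", "what is the weather today", 0),
]
columns = to_columns(rows)
```

Passing the resulting dict to `datasets.Dataset.from_dict(columns)` should then give a dataset with the expected columns.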

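# 3. Load the model and train

The main parameters to tune are in `loss_kwargs`; search over them for your own task. For the meaning of each parameter, see the paper: https://arxiv.org/abs/2309.12871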
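
To give a feel for what a temperature such as `cosine_tau` controls, here is a rough pure-Python sketch of a pairwise cosine ranking objective of the kind the paper describes. It assumes `tau` acts as a multiplicative scale on cosine differences; that is an assumption about the library, and the function names below are illustrative rather than `angle_emb` APIs.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def cosine_ranking_loss(pairs, labels, tau=20.0):
    """log(1 + sum exp(tau * (cos_low - cos_high))).

    The sum runs over every two sentence pairs where the first has a
    higher label than the second, so the loss is small when higher-labeled
    pairs also have higher cosine similarity.
    Assumption: tau is a scale factor, not a divisor.
    """
    sims = [cosine(u, v) for u, v in pairs]
    total = 0.0
    for i, label_i in enumerate(labels):
        for j, label_j in enumerate(labels):
            if label_i > label_j:
                total += math.exp(tau * (sims[j] - sims[i]))
    return math.log1p(total)


# Toy embeddings: the first pair is identical, the second is orthogonal.
pairs = [([1.0, 0.0], [1.0, 0.0]), ([1.0, 0.0], [0.0, 1.0])]
well_ordered = cosine_ranking_loss(pairs, labels=[1, 0])
mis_ordered = cosine_ranking_loss(pairs, labels=[0, 1])
```

In this sketch a larger `tau` penalizes mis-ordered pairs more sharply, which is one reason the temperatures are worth searching over together with the weights.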


In [None]:
from angle_emb import AnglE

angle = AnglE.from_pretrained('hfl/chinese-roberta-wwm-ext', max_length=128, pooling_strategy='cls').cuda()

train_ds = ds['train'].shuffle()
valid_ds = ds['validation']
test_ds = ds['test']

Loading cached shuffled indices for dataset at /home/jupyter-sean/.cache/huggingface/datasets/shibing624___nli_zh/ATEC/1.0.0/65b555276ee420c801e1c9eb830db959e37f42fa60c68c8b07a4448b8c436706/cache-bc417f8e3c8845d1.arrow
Loading cached processed dataset at /home/jupyter-sean/.cache/huggingface/datasets/shibing624___nli_zh/ATEC/1.0.0/65b555276ee420c801e1c9eb830db959e37f42fa60c68c8b07a4448b8c436706/cache-f15bff99645a64a2_*_of_00008.arrow
Loading cached processed dataset at /home/jupyter-sean/.cache/huggingface/datasets/shibing624___nli_zh/ATEC/1.0.0/65b555276ee420c801e1c9eb830db959e37f42fa60c68c8b07a4448b8c436706/cache-cea1adf8fcf578ef_*_of_00008.arrow
Loading cached processed dataset at /home/jupyter-sean/.cache/huggingface/datasets/shibing624___nli_zh/ATEC/1.0.0/65b555276ee420c801e1c9eb830db959e37f42fa60c68c8b07a4448b8c436706/cache-e2aa4a5813639a58_*_of_00008.arrow


In [None]:
angle.fit(
    train_ds=train_ds,
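    valid_ds=test_ds,  # this run selects the best checkpoint on the test split; pass valid_ds to keep the test set held out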
    output_dir='ckpts/atec',
    batch_size=64,
    epochs=5,
    learning_rate=3e-5,
    save_steps=1000,
    eval_steps=1000,
    warmup_steps=0,
    gradient_accumulation_steps=1,
    loss_kwargs={
        'w1': 1.0,
        'w2': 5.0,
        'w3': 1.0,
        'cosine_tau': 20,
        'ibn_tau': 20,
        'angle_tau': 1.0
    },
    fp16=True,
    logging_steps=500
)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
500,13.4152
1000,12.8942
1500,12.1724
2000,12.1944
2500,11.5382
3000,11.3915
3500,10.8896
4000,10.7976
4500,10.3419


Evaluate: 313it [00:14, 21.63it/s]


new best corrcoef!
save to ckpts/atec/best-checkpoint
corrcoef: 0.4922603377898468, accuracy: 0.8564, best corrcoef: 0.4922603377898468


Evaluate: 313it [00:14, 21.59it/s]


new best corrcoef!
save to ckpts/atec/best-checkpoint
corrcoef: 0.5103606507039081, accuracy: 0.8622, best corrcoef: 0.5103606507039081


Evaluate: 313it [00:14, 21.81it/s]


new best corrcoef!
save to ckpts/atec/best-checkpoint
corrcoef: 0.5121040463618336, accuracy: 0.8621, best corrcoef: 0.5121040463618336


Evaluate: 313it [00:14, 21.80it/s]


new best corrcoef!
save to ckpts/atec/best-checkpoint
corrcoef: 0.5130152790494402, accuracy: 0.8646, best corrcoef: 0.5130152790494402
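
How often training logs, evaluates, and saves follows from the step counts passed to `fit`. The sketch below estimates them, assuming one optimizer step per `batch_size * gradient_accumulation_steps` examples with the last partial batch kept. The helper name `schedule` is hypothetical, and `num_examples=62_000` is an illustrative placeholder, not the measured size of the ATEC train split.

```python
import math


def schedule(num_examples, batch_size, epochs, grad_accum=1,
             eval_steps=1000, logging_steps=500):
    """Estimate optimizer steps and how often evaluation and logging fire."""
    # Assumption: the final partial batch of each epoch still counts as a step.
    steps_per_epoch = math.ceil(num_examples / (batch_size * grad_accum))
    total_steps = steps_per_epoch * epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "evaluations": total_steps // eval_steps,
        "log_entries": total_steps // logging_steps,
    }


# Placeholder dataset size; the other values mirror the fit() call above.
plan = schedule(num_examples=62_000, batch_size=64, epochs=5)
```

With these placeholder numbers the estimate comes out at four evaluations and nine logged losses, the same counts as in the output above.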


# 4. Evaluate

In [None]:
# load best checkpoint and evaluate

angle = AnglE.from_pretrained('hfl/chinese-roberta-wwm-ext', pretrained_model_path='ckpts/atec/best-checkpoint').cuda()
angle.evaluate(test_ds, device=angle.device)
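
The training log reports a `corrcoef` between predicted similarities and labels. Assuming this is a Spearman rank correlation (the usual metric for semantic textual similarity, but an assumption here, not confirmed by the source), it can be sketched in pure Python as follows. The function names and the toy scores are illustrative.

```python
import math


def average_ranks(values):
    """1-based ranks, with ties sharing the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        shared = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = shared
        i = j + 1
    return ranks


def pearson(x, y):
    """Pearson correlation of two equal-length sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


def spearman(x, y):
    """Spearman correlation: Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))


# Toy example: predicted cosine similarities against binary labels.
predicted = [0.9, 0.2, 0.7, 0.1]
gold = [1, 0, 1, 0]
rho = spearman(predicted, gold)
```

Because only ranks matter, this kind of score rewards a model for ordering pairs correctly even when its raw similarity values are poorly calibrated.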