# ESM の埋め込みベクトルを取得する

## 1. 必要なライブラリのインストールと埋め込みベクトルの取得

In [2]:
!pip install fair-esm



In [3]:
import os

# Redirect the Hugging Face cache directory (HF_HOME), just in case.
# NOTE(review): fair-esm appears to download its weights through torch.hub, whose
# cache location is governed by TORCH_HOME — confirm whether HF_HOME has any effect here.
os.environ.update({"HF_HOME": "D:/hf-home"})

In [10]:
import torch
import esm

# Load the pretrained ESM-1b model together with its alphabet.
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()

# Switch the model to evaluation mode before running inference.
model.eval()

# Treat each of the 20 amino-acid letters as its own one-residue sequence;
# the letter serves as both the label and the sequence of each batch entry.
amino_acids = "ACDEFGHIKLMNPQRSTVWY"
seqs = list(zip(amino_acids, amino_acids))

# Tokenize the batch; only the third return value (the token tensor) is used.
batch_converter = alphabet.get_batch_converter()
_, _, tokens = batch_converter(seqs)

# Forward pass without gradient tracking, requesting the layer-33 representations.
with torch.no_grad():
    results = model(tokens, repr_layers=[33], return_contacts=False)

# Per-token embeddings taken from layer 33.
token_representations = results["representations"][33]

# Each tokenized entry consists of <cls>, the residue token, and <eos>,
# so position 1 along the token axis holds the residue's embedding.
embedding_dict = {
    residue: token_representations[row, 1, :].numpy()
    for row, residue in enumerate(amino_acids)
}

print(embedding_dict)

{'A': array([ 0.11339767, -0.01012399,  0.25220194, ..., -0.19726606,
       -0.02746568, -0.18330735], dtype=float32), 'C': array([-0.12331015, -0.11629812,  0.2763299 , ..., -0.21442588,
       -0.05947747, -0.21635804], dtype=float32), 'D': array([ 0.02541378, -0.15991558,  0.1884571 , ..., -0.1508156 ,
       -0.06527089, -0.11141248], dtype=float32), 'E': array([ 0.06360463, -0.04435542,  0.19132927, ..., -0.19241521,
        0.0078241 , -0.26709664], dtype=float32), 'F': array([ 0.05718051,  0.06145927,  0.24460267, ..., -0.12219335,
       -0.05262642, -0.3649092 ], dtype=float32), 'G': array([ 0.21209703, -0.093178  ,  0.25682455, ..., -0.11816211,
       -0.06181881, -0.32026413], dtype=float32), 'H': array([ 0.01737899, -0.08932362,  0.1520881 , ..., -0.17507413,
       -0.00407022, -0.26191962], dtype=float32), 'I': array([ 0.0723353 ,  0.08427574,  0.1600373 , ..., -0.13815896,
       -0.08360748, -0.34114602], dtype=float32), 'K': array([ 0.18613508,  0.24342108, -0.167378

## 2. CSV へ保存

In [22]:
import csv
from pathlib import Path

save_filename = "../data/esm-embedding-dim-1280.csv"

# Create the output directory if it is missing; otherwise open() below would
# raise FileNotFoundError on a fresh checkout.
Path(save_filename).parent.mkdir(parents=True, exist_ok=True)

# newline="" leaves line-ending handling to the csv module.
with open(save_filename, "w", encoding="utf-8", newline="") as csv_file:
    writer = csv.writer(csv_file)

    # One row per amino acid: the letter followed by its embedding values.
    # No header row is written.
    writer.writerows(
        [key, *vector.tolist()] for key, vector in embedding_dict.items()
    )