In [1]:
from transformers import AutoTokenizer, AutoModel
import torch
import os
from tqdm import tqdm

In [2]:
# Input root (one sub-directory per story holding .txt files, as consumed by the
# export loop below) and output root for the CLS-token features.
text_root = "./results/split_text"
save_root = "./text_feature"
os.makedirs(save_root, exist_ok=True)

# Encoder checkpoint, named once so the tokenizer and the model cannot drift apart.
# NOTE(review): an earlier comment here mentioned KoSimCSE / a multilingual model,
# but the checkpoint actually loaded is all-mpnet-base-v2 -- confirm it is the
# intended encoder for these texts.
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Prefer the GPU, but fall back to CPU so this cell itself runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# eval() switches off dropout; the trailing expression still displays the model repr.
model.eval().to(device)


MPNetModel(
  (embeddings): MPNetEmbeddings(
    (word_embeddings): Embedding(30527, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): MPNetEncoder(
    (layer): ModuleList(
      (0-11): 12 x MPNetLayer(
        (attention): MPNetAttention(
          (attn): MPNetSelfAttention(
            (q): Linear(in_features=768, out_features=768, bias=True)
            (k): Linear(in_features=768, out_features=768, bias=True)
            (v): Linear(in_features=768, out_features=768, bias=True)
            (o): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (intermediate): MPNetIntermediate(
          (dense): Linear(in_

## CLS token

In [3]:
def get_text_embedding(text):
    """Embed ``text`` using the hidden state at the first token position.

    Relies on the module-level ``tokenizer`` and ``model``.

    Args:
        text: Text to encode; it is passed straight to ``tokenizer``.

    Returns:
        A CPU tensor holding ``last_hidden_state[:, 0, :]`` with a leading
        dimension of size 1 squeezed away.
    """
    # Inputs longer than 512 tokens are truncated by the tokenizer.
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Follow whatever device the model lives on instead of assuming CUDA.
    target_device = model.device
    encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)

    # Position 0 is the first (start-of-sequence / [CLS]-style) token.
    first_token_state = outputs.last_hidden_state[:, 0, :]
    return first_token_state.squeeze(0).cpu()

# Walk every story directory under text_root, embed each .txt file and store the
# result as a .pt file with the same stem under save_root/<story_name>/.
for story_name in os.listdir(text_root):
    story_path = os.path.join(text_root, story_name)
    if not os.path.isdir(story_path):
        continue  # ignore stray files sitting next to the story folders

    save_story_dir = os.path.join(save_root, story_name)
    os.makedirs(save_story_dir, exist_ok=True)

    text_files = sorted(name for name in os.listdir(story_path) if name.endswith(".txt"))
    for file_name in tqdm(text_files, desc=story_name):
        with open(os.path.join(story_path, file_name), "r", encoding="utf-8") as handle:
            text = handle.read().strip()
        emb = get_text_embedding(text)
        # Swap only the extension; str.replace would also rewrite a ".txt"
        # occurring elsewhere in the file name.
        stem, _ = os.path.splitext(file_name)
        torch.save(emb, os.path.join(save_story_dir, stem + ".pt"))

tunnel: 100%|██████████████████████████████████| 51/51 [00:00<00:00, 143.40it/s]
vinny: 100%|█████████████████████████████████████| 5/5 [00:00<00:00, 281.33it/s]
prettymouth: 100%|█████████████████████████████| 22/22 [00:00<00:00, 303.60it/s]
notthefallshortscram: 100%|████████████████████| 18/18 [00:00<00:00, 303.99it/s]
21styear: 100%|██████████████████████████████| 111/111 [00:00<00:00, 313.43it/s]
santa: 100%|█████████████████████████████████████| 5/5 [00:00<00:00, 291.04it/s]
milkywaysynonyms: 100%|████████████████████████| 13/13 [00:00<00:00, 316.16it/s]
shapesphysical: 100%|██████████████████████████| 13/13 [00:00<00:00, 313.83it/s]
milkywayoriginal: 100%|████████████████████████| 13/13 [00:00<00:00, 298.61it/s]
shapessocial: 100%|████████████████████████████| 13/13 [00:00<00:00, 326.73it/s]
milkywayvodka: 100%|███████████████████████████| 13/13 [00:00<00:00, 277.24it/s]
black: 100%|███████████████████████████████████| 26/26 [00:00<00:00, 317.61it/s]
friends: 100%|██████████████

In [4]:
# Same input root as before; mean-pooled features go to their own output root.
text_root = "./results/split_text"
save_root = "./text_feature_pooling"
os.makedirs(save_root, exist_ok=True)

# Encoder checkpoint, named once so the tokenizer and the model cannot drift apart.
# NOTE(review): an earlier comment here mentioned KoSimCSE / a multilingual model,
# but the checkpoint actually loaded is all-mpnet-base-v2 -- confirm it is the
# intended encoder for these texts.
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Prefer the GPU, but fall back to CPU so this cell itself runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The encoder is loaded again here so this cell does not rely on earlier kernel state.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# eval() switches off dropout; the trailing expression still displays the model repr.
model.eval().to(device)


MPNetModel(
  (embeddings): MPNetEmbeddings(
    (word_embeddings): Embedding(30527, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): MPNetEncoder(
    (layer): ModuleList(
      (0-11): 12 x MPNetLayer(
        (attention): MPNetAttention(
          (attn): MPNetSelfAttention(
            (q): Linear(in_features=768, out_features=768, bias=True)
            (k): Linear(in_features=768, out_features=768, bias=True)
            (v): Linear(in_features=768, out_features=768, bias=True)
            (o): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (intermediate): MPNetIntermediate(
          (dense): Linear(in_

## Mean pooling

In [5]:
# Mean pooling embedding
def get_text_embedding(text):
    """Embed ``text`` as the attention-mask-weighted mean of its token states.

    Relies on the module-level ``tokenizer`` and ``model``.

    NOTE: this reuses the name of the CLS-token variant defined earlier in the
    notebook and replaces it once this cell has run.

    Args:
        text: Text to encode; it is passed straight to ``tokenizer``.

    Returns:
        A CPU tensor with the masked mean of ``last_hidden_state`` over the
        sequence axis, with a leading dimension of size 1 squeezed away.
    """
    # Inputs longer than 512 tokens are truncated by the tokenizer.
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Follow whatever device the model lives on instead of assuming CUDA.
    target_device = model.device
    encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)
        token_embeddings = outputs.last_hidden_state  # (1, seq_len, hidden_size) for one string
        attention_mask = encoded["attention_mask"]    # (1, seq_len) for one string

        # Broadcast the mask across the hidden dimension so padded positions
        # contribute zero to the sum.
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

        summed = torch.sum(token_embeddings * mask, dim=1)
        # Clamp the token count so an all-zero mask can never divide by zero.
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        mean_embeddings = summed / counts

    return mean_embeddings.squeeze(0).cpu()

# Save features: walk every story directory under text_root, embed each .txt
# file and store the result as a .pt file with the same stem under
# save_root/<story_name>/.
for story_name in os.listdir(text_root):
    story_path = os.path.join(text_root, story_name)
    if not os.path.isdir(story_path):
        continue  # ignore stray files sitting next to the story folders

    save_story_dir = os.path.join(save_root, story_name)
    os.makedirs(save_story_dir, exist_ok=True)

    text_files = sorted(name for name in os.listdir(story_path) if name.endswith(".txt"))
    for file_name in tqdm(text_files, desc=story_name):
        with open(os.path.join(story_path, file_name), "r", encoding="utf-8") as handle:
            text = handle.read().strip()
        emb = get_text_embedding(text)
        # Swap only the extension; str.replace would also rewrite a ".txt"
        # occurring elsewhere in the file name.
        stem, _ = os.path.splitext(file_name)
        torch.save(emb, os.path.join(save_story_dir, stem + ".pt"))

tunnel: 100%|██████████████████████████████████| 51/51 [00:00<00:00, 289.95it/s]
vinny: 100%|█████████████████████████████████████| 5/5 [00:00<00:00, 285.56it/s]
prettymouth: 100%|█████████████████████████████| 22/22 [00:00<00:00, 299.07it/s]
notthefallshortscram: 100%|████████████████████| 18/18 [00:00<00:00, 309.21it/s]
21styear: 100%|██████████████████████████████| 111/111 [00:00<00:00, 310.99it/s]
santa: 100%|█████████████████████████████████████| 5/5 [00:00<00:00, 296.94it/s]
milkywaysynonyms: 100%|████████████████████████| 13/13 [00:00<00:00, 318.49it/s]
shapesphysical: 100%|██████████████████████████| 13/13 [00:00<00:00, 317.72it/s]
milkywayoriginal: 100%|████████████████████████| 13/13 [00:00<00:00, 306.56it/s]
shapessocial: 100%|████████████████████████████| 13/13 [00:00<00:00, 324.32it/s]
milkywayvodka: 100%|███████████████████████████| 13/13 [00:00<00:00, 289.33it/s]
black: 100%|███████████████████████████████████| 26/26 [00:00<00:00, 317.00it/s]
friends: 100%|██████████████