<a href="https://colab.research.google.com/github/JoonYoung-Sohn/Protein_seq_prediction/blob/main/BPE_tokenizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import argparse
import collections
import os
import random
import re
import shutil
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
# google drive mount
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
fasta_file = "/content/drive/MyDrive/preprocessed_data.fasta"

In [4]:
args = {
    # random seed value
    "seed": 1234,
    # CPU 또는 GPU 사용여부 결정
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # 말뭉치 파일
    "corpus": fasta_file,
}
args = argparse.Namespace(**args)

print(args)

Namespace(corpus='/content/drive/MyDrive/preprocessed_data.fasta', device=device(type='cuda'), seed=1234)


In [5]:
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

In [6]:
word_counter = collections.defaultdict(int)

In [11]:
with open('/content/drive/MyDrive/preprocessed_data.fasta') as f:
    for i, line in enumerate(f):
        line = line.strip()
        for w in line.split():
            word_counter[w] += 1

In [12]:
print(len(word_counter))

20325


In [13]:
bpe_counter = collections.defaultdict(int)
for w, n in word_counter.items():
    w = f"\u2581{w}"
    bpe_counter[" ".join(w)] = n

print(bpe_counter)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [14]:
def update_vocab(vocab, counter):
    """
    vocab 변경
    :param vocab: vocabulary
    :param counter: BPE counter
    """
    for w in counter:
        for s in w.split():
            if s not in vocab:
                vocab[s] = len(vocab)
    return vocab

In [16]:
# bpe 일련번호 부여
bpe_to_id = {'[PAD]': 0, '[UNK]': 1}
bpe_to_id = update_vocab(bpe_to_id, bpe_counter)

print(bpe_to_id)

{'[PAD]': 0, '[UNK]': 1, '▁': 2, '/': 3, 'c': 4, 'o': 5, 'n': 6, 't': 7, 'e': 8, 'd': 9, 'r': 10, 'i': 11, 'v': 12, 'M': 13, 'y': 14, 'D': 15, 'p': 16, 's': 17, '_': 18, 'a': 19, '.': 20, 'f': 21, 'W': 22, 'L': 23, 'S': 24, 'P': 25, 'E': 26, 'V': 27, 'A': 28, 'N': 29, 'T': 30, 'R': 31, 'F': 32, 'Q': 33, 'G': 34, 'H': 35, 'Y': 36, 'I': 37, 'C': 38, 'K': 39, 'U': 40}


In [17]:
def get_stats(counter):
    """
    bi-gram 빈도수 계산
    :param counter: BPE counter
    """
    pairs = collections.defaultdict(int)
    for word, freq in counter.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

In [18]:
def merge_vocab(pair, v_in):
    """
    bi-gram merge
    :param counter: BPE counter
    """
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

In [19]:
pairs = get_stats(bpe_counter)

print(pairs)

defaultdict(<class 'int'>, {('▁', '/'): 1, ('/', 'c'): 1, ('c', 'o'): 1, ('o', 'n'): 1, ('n', 't'): 2, ('t', 'e'): 1, ('e', 'n'): 1, ('t', '/'): 1, ('/', 'd'): 1, ('d', 'r'): 1, ('r', 'i'): 2, ('i', 'v'): 2, ('v', 'e'): 2, ('e', '/'): 2, ('/', 'M'): 1, ('M', 'y'): 1, ('y', 'D'): 1, ('D', 'r'): 1, ('/', 'p'): 1, ('p', 'r'): 2, ('r', 'e'): 1, ('e', 'p'): 1, ('r', 'o'): 1, ('o', 'c'): 1, ('c', 'e'): 1, ('e', 's'): 1, ('s', 's'): 1, ('s', 'e'): 1, ('e', 'd'): 1, ('d', '_'): 1, ('_', 'd'): 1, ('d', 'a'): 1, ('a', 't'): 1, ('t', 'a'): 2, ('a', '.'): 1, ('.', 'f'): 1, ('f', 'a'): 1, ('a', 's'): 1, ('s', 't'): 1, ('▁', 'M'): 20315, ('M', 'W'): 2762, ('W', 'L'): 13980, ('L', 'S'): 91350, ('S', 'P'): 67887, ('P', 'E'): 52114, ('E', 'E'): 90589, ('E', 'V'): 46951, ('V', 'L'): 70895, ('L', 'V'): 63313, ('V', 'A'): 48572, ('A', 'N'): 22292, ('N', 'A'): 22764, ('A', 'L'): 82688, ('L', 'W'): 13345, ('W', 'V'): 8276, ('V', 'T'): 42836, ('T', 'E'): 39281, ('E', 'R'): 45238, ('R', 'A'): 45187, ('N', 'P'

In [20]:
# find most freq bigram pair
best = max(pairs, key=pairs.get)

print(best)

('L', 'L')


In [21]:
# merge most freq bigram pair
bpe_counter = merge_vocab(best, bpe_counter)

print(bpe_counter)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [23]:
# update vocab
bpe_to_id = update_vocab(bpe_to_id, bpe_counter)

print(len(bpe_to_id))

42


In [24]:
pairs = get_stats(bpe_counter)
best = max(pairs, key=pairs.get)
bpe_counter = merge_vocab(best, bpe_counter)
bpe_to_id = update_vocab(bpe_to_id, bpe_counter)
print(len(bpe_to_id))

43


In [25]:
while len(bpe_to_id) < 100:
  pairs = get_stats(bpe_counter)
  best = max(pairs, key=pairs.get)
  bpe_counter = merge_vocab(best, bpe_counter)
  bpe_to_id = update_vocab(bpe_to_id, bpe_counter)
print(len(bpe_to_id))
print(bpe_to_id)

100
{'[PAD]': 0, '[UNK]': 1, '▁': 2, '/': 3, 'c': 4, 'o': 5, 'n': 6, 't': 7, 'e': 8, 'd': 9, 'r': 10, 'i': 11, 'v': 12, 'M': 13, 'y': 14, 'D': 15, 'p': 16, 's': 17, '_': 18, 'a': 19, '.': 20, 'f': 21, 'W': 22, 'L': 23, 'S': 24, 'P': 25, 'E': 26, 'V': 27, 'A': 28, 'N': 29, 'T': 30, 'R': 31, 'F': 32, 'Q': 33, 'G': 34, 'H': 35, 'Y': 36, 'I': 37, 'C': 38, 'K': 39, 'U': 40, 'LL': 41, 'SS': 42, 'EE': 43, 'AA': 44, 'SL': 45, 'PP': 46, 'AL': 47, 'VL': 48, 'GL': 49, 'EL': 50, 'EK': 51, 'SP': 52, 'GG': 53, 'TL': 54, 'RL': 55, 'DL': 56, 'SG': 57, 'KK': 58, 'QL': 59, 'IL': 60, 'RR': 61, 'PL': 62, 'SA': 63, 'SV': 64, 'PG': 65, 'EA': 66, 'FL': 67, 'ST': 68, 'KL': 69, 'ED': 70, 'SR': 71, 'EV': 72, 'NL': 73, 'AV': 74, 'EG': 75, 'SQ': 76, 'TV': 77, 'AG': 78, 'SD': 79, 'SK': 80, 'TG': 81, 'ER': 82, 'PV': 83, 'PA': 84, 'HL': 85, 'SI': 86, 'EN': 87, 'QQ': 88, 'SF': 89, 'PR': 90, 'EI': 91, 'VV': 92, 'YL': 93, 'TA': 94, 'DG': 95, 'KA': 96, 'EQ': 97, 'ET': 98, 'KV': 99}
