In [1]:
import html
import random
import re
import sys
from typing import List, Dict

from IPython.display import display, HTML

sys.path.append('..')  # must run before the project import below
from llm.tokenizer import LBPETokenizer

def colorize_tokens(tokens: List[str], colors: List[str]) -> str:
    """
    Render tokens as HTML, giving each token its own text (foreground) color.

    Args:
        tokens: Token strings in display order.
        colors: CSS color values, paired positionally with ``tokens``. Pairing
            uses ``zip``, so surplus entries of the longer list are ignored.

    Returns:
        An HTML fragment. A token that is exactly one newline character becomes
        ``<br>``; every other token is HTML-escaped and wrapped in a ``<span>``
        styled with ``white-space: pre`` so spaces are kept as-is.
    """
    pieces = []
    for token, color in zip(tokens, colors):
        if token == "\n":
            pieces.append("<br>")  # Replace newlines with <br> for HTML
            continue
        # Escape markup characters so tokens such as "<LIGAND>" are displayed
        # literally instead of being parsed (and hidden) as HTML tags.
        token_display = html.escape(token, quote=False)
        pieces.append(f"<span style='color: {color}; white-space: pre;'>{token_display}</span>")
    return "".join(pieces)

def highlight_tokens_with_map(tokens: List[str], color_map: Dict[str, str]) -> str:
    """
    Render tokens as HTML, using a per-token background color from a lookup table.

    Args:
        tokens: Token strings in display order.
        color_map: Maps a token string to a CSS color. The lookup uses the raw
            (unescaped) token; tokens missing from the map get "#FFFFFF".

    Returns:
        An HTML fragment. A token that is exactly one newline character becomes
        ``<br>``; every other token is HTML-escaped and wrapped in a ``<span>``
        styled with ``white-space: pre`` so spaces are kept as-is.
    """
    pieces = []
    for token in tokens:
        if token == "\n":
            pieces.append("<br>")  # Replace newlines with <br> for HTML
            continue
        color = color_map.get(token, "#FFFFFF")  # Default to white if token not found
        # Escape "&" as well as "<" and ">": otherwise a token like "&lt;" would be
        # rendered as "<" rather than shown verbatim.
        token_display = html.escape(token, quote=False)
        pieces.append(f"<span style='background-color: {color}; white-space: pre;'>{token_display}</span>")
    return "".join(pieces)

def generate_color_map(tokens: List[str]) -> Dict[str, str]:
    """
    Assign a distinct random color to each unique token.

    Args:
        tokens: Token strings; duplicates are allowed and collapse to one entry.

    Returns:
        Dict mapping each unique token to a "#rrggbb" string (lowercase hex).
        Keys follow the order in which tokens first appear in ``tokens``.

    Raises:
        ValueError: If there are more unique tokens than 24-bit RGB colors
            (0x1000000), since distinct colors are then impossible.
    """
    # dict.fromkeys de-duplicates while keeping first-appearance order, so the
    # token/draw pairing does not depend on set iteration order.
    unique_tokens = list(dict.fromkeys(tokens))

    # Sample without replacement: independent randint draws could hand the same
    # color to two different tokens, breaking the "unique color" promise.
    rgb_values = random.sample(range(0x1000000), len(unique_tokens))

    return {token: f"#{rgb:06x}" for token, rgb in zip(unique_tokens, rgb_values)}



In [2]:
# Small hand-written sample used to exercise the tokenizer in the cells below:
# a <LIGAND> tag, one line that looks like a SMILES string, an <XYZ> tag, then
# one "element x y z" row per line, closed by <eos>.
test = """<LIGAND>
Cn1c(=O)c2c(ncn2C)n(C)c1=O
<XYZ>
C 3.2932 0.3895 0.2537
N 2.1108 -0.4198 0.1026
C 2.0698 -1.7874 0.0333
N 0.8358 -2.2304 -0.1052
C 0.0712 -1.1032 -0.1240
C 0.8238 0.0247 0.0014
C 0.2701 1.3335 0.0113
O 0.9576 2.3440 0.1250
N -1.1216 1.3332 -0.1210
C -1.8079 2.6102 -0.1253
C -1.9352 0.1881 -0.2531
O -3.1606 0.2905 -0.3648
N -1.2957 -1.0520 -0.2516
C -2.0603 -2.2787 -0.3829
<eos>"""

In [3]:
# Fit a tokenizer on the sample text, then register the special tags under
# consecutive ids starting at 257 (list order fixes which tag gets which id).
tknz = LBPETokenizer()
tknz.train(test)
tknz.register_special_tokens(
    {
        tag: tag_id
        for tag_id, tag in enumerate(
            ["<LIGAND>", "<PAD>", "<MASK>", "<UNK>", "<XYZ>", "<eos>"],
            start=257,
        )
    }
)

In [4]:
# Split the sample into its token strings; the bare `tokens` expression on the
# last line makes the notebook echo the resulting list.
tokens = tknz.get_tokens(test)
tokens

{257, 261, 262, 10, 32, 40, 41, 45, 46, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 61, 67, 78, 79, 99, 110}


['<LIGAND>',
 '\n',
 'C',
 'n',
 '1',
 'c',
 '(',
 '=',
 'O',
 ')',
 'c',
 '2',
 'c',
 '(',
 'n',
 'c',
 'n',
 '2',
 'C',
 ')',
 'n',
 '(',
 'C',
 ')',
 'c',
 '1',
 '=',
 'O',
 '\n',
 '<XYZ>',
 '\n',
 'C',
 ' ',
 '3',
 '.',
 '2',
 '9',
 '3',
 '2',
 ' ',
 '0',
 '.',
 '3',
 '8',
 '9',
 '5',
 ' ',
 '0',
 '.',
 '2',
 '5',
 '3',
 '7',
 '\n',
 'N',
 ' ',
 '2',
 '.',
 '1',
 '1',
 '0',
 '8',
 ' ',
 '-',
 '0',
 '.',
 '4',
 '1',
 '9',
 '8',
 ' ',
 '0',
 '.',
 '1',
 '0',
 '2',
 '6',
 '\n',
 'C',
 ' ',
 '2',
 '.',
 '0',
 '6',
 '9',
 '8',
 ' ',
 '-',
 '1',
 '.',
 '7',
 '8',
 '7',
 '4',
 ' ',
 '0',
 '.',
 '0',
 '3',
 '3',
 '3',
 '\n',
 'N',
 ' ',
 '0',
 '.',
 '8',
 '3',
 '5',
 '8',
 ' ',
 '-',
 '2',
 '.',
 '2',
 '3',
 '0',
 '4',
 ' ',
 '-',
 '0',
 '.',
 '1',
 '0',
 '5',
 '2',
 '\n',
 'C',
 ' ',
 '0',
 '.',
 '0',
 '7',
 '1',
 '2',
 ' ',
 '-',
 '1',
 '.',
 '1',
 '0',
 '3',
 '2',
 ' ',
 '-',
 '0',
 '.',
 '1',
 '2',
 '4',
 '0',
 '\n',
 'C',
 ' ',
 '0',
 '.',
 '8',
 '2',
 '3',
 '8',
 ' ',
 '0',
 '.',
 '0

In [5]:
# Re-join the token strings into a single text and encode it to token ids.
ids = tknz.encode("".join(tokens))
print(ids)

[257, 10, 67, 110, 49, 99, 40, 61, 79, 41, 99, 50, 99, 40, 110, 99, 110, 50, 67, 41, 110, 40, 67, 41, 99, 49, 61, 79, 10, 261, 10, 67, 32, 51, 46, 50, 57, 51, 50, 32, 48, 46, 51, 56, 57, 53, 32, 48, 46, 50, 53, 51, 55, 10, 78, 32, 50, 46, 49, 49, 48, 56, 32, 45, 48, 46, 52, 49, 57, 56, 32, 48, 46, 49, 48, 50, 54, 10, 67, 32, 50, 46, 48, 54, 57, 56, 32, 45, 49, 46, 55, 56, 55, 52, 32, 48, 46, 48, 51, 51, 51, 10, 78, 32, 48, 46, 56, 51, 53, 56, 32, 45, 50, 46, 50, 51, 48, 52, 32, 45, 48, 46, 49, 48, 53, 50, 10, 67, 32, 48, 46, 48, 55, 49, 50, 32, 45, 49, 46, 49, 48, 51, 50, 32, 45, 48, 46, 49, 50, 52, 48, 10, 67, 32, 48, 46, 56, 50, 51, 56, 32, 48, 46, 48, 50, 52, 55, 32, 48, 46, 48, 48, 49, 52, 10, 67, 32, 48, 46, 50, 55, 48, 49, 32, 49, 46, 51, 51, 51, 53, 32, 48, 46, 48, 49, 49, 51, 10, 79, 32, 48, 46, 57, 53, 55, 54, 32, 50, 46, 51, 52, 52, 48, 32, 48, 46, 49, 50, 53, 48, 10, 78, 32, 45, 49, 46, 49, 50, 49, 54, 32, 49, 46, 51, 51, 51, 50, 32, 45, 48, 46, 49, 50, 49, 48, 10, 67, 32, 4

In [6]:
# # Assign a random text color to each token (disabled alternative to the next cell)
# colors = [f"#{random.randint(0, 0xFFFFFF):06x}" for _ in tokens]

# # Generate the colored HTML
# colored_text = colorize_tokens(tokens, colors)

# # Display the colored tokens in a notebook
# display(HTML(colored_text))

In [7]:
# Generate a unique color for each unique token
# Build the token -> color lookup (colors are drawn at random by generate_color_map)
color_map = generate_color_map(tokens)

# Render every token as a <span> whose background color comes from color_map
highlighted_text = highlight_tokens_with_map(tokens, color_map)

# Show the resulting HTML inline in the notebook
display(HTML(highlighted_text))

# Data Loader

In [8]:
import sys
import yaml
sys.path.append('..')
from llm.preprocessing.loader import create_data_loader, split_data

In [9]:
# Read the whole corpus into memory as a single string.
# The encoding is pinned so decoding does not depend on the platform/locale default.
with open('../data/xyz_mols/cleaned_input.txt', 'r', encoding='utf-8') as f:
    data = f.read()

In [10]:
# Load the training configuration. The file is opened in a `with` block so the
# handle is closed once parsing finishes (the inline open() never closed it).
# NOTE(review): yaml.load with FullLoader accepts more than plain data tags; if
# train.yml only holds scalars/lists/dicts, consider yaml.safe_load instead.
with open('../llm/config/train.yml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [11]:

# Fit the corpus tokenizer, then register the special tags under consecutive
# ids starting at 257 (list order fixes which tag gets which id).
tokenizer = LBPETokenizer()
tokenizer.train(data)
tokenizer.register_special_tokens(
    {
        tag: tag_id
        for tag_id, tag in enumerate(
            ["<LIGAND>", "<PAD>", "<MASK>", "<UNK>", "<XYZ>", "<eos>"],
            start=257,
        )
    }
)

# Persist the trained tokenizer under the name "LBPETokenizer"
tokenizer.save("LBPETokenizer")

# Divide the raw corpus into training and validation parts
train_data, val_data = split_data(data)


In [12]:
# Inspect a single vocabulary entry by id (the recorded output for id 49 is b'1').
tokenizer.vocab[49]

b'1'

In [13]:
# Print the tail of the training split. The slice takes the last 500 elements of
# train_data; the recorded output is plain text, so these appear to be characters
# rather than tokens.
print(train_data[-500:])

5
C 0.0914 0.9221 0.9321
C -1.2920 1.4259 0.6337
C -2.0751 0.8953 -0.3166
C -1.6368 -0.2291 -1.2081
C -0.1288 -0.4818 -1.1541
<eos>
<LIGAND>
O=C(N[C@@H]1CCC[C@H]1O)C(F)(F)F
<XYZ>
O 2.1407 1.5924 -0.5740
C 2.2813 0.5350 0.0421
N 1.1946 -0.2629 0.3991
C -0.1331 -0.0153 -0.1362
C -0.7549 1.3135 0.3004
C -2.2702 1.0842 0.2997
C -2.4799 -0.3982 0.0077
C -1.1529 -1.0425 0.3717
O -0.9733 -2.2993 -0.2593
C 3.6524 -0.0019 0.5126
F 4.6010 0.0399 -0.4617
F 3.5741 -1.3061 0.9140
F 4.1459 0.6922 1.5745
<eos>


In [14]:
# data loader
# batch_size is read from the "training" section of the config; max_tokens and
# stride are read from its "tokenizer" section.
# NOTE(review): num_workers=20 is hard-coded — confirm it suits the machine running this.
train_loader = create_data_loader(
    input=train_data, 
    batch_size=config["training"]["batch_size"], 
    max_tokens=config["tokenizer"]["max_tokens"], 
    stride=config["tokenizer"]["stride"],
    tokenizer=tokenizer,
    num_workers=20,
)

In [15]:
# Sanity-check the loader: for each batch, decode element 0 of the inputs and of
# the targets back to text and print both.
# NOTE(review): the recorded output suggests targets are the inputs shifted by
# one token — confirm against create_data_loader.
for index, (inputs, targets) in enumerate(train_loader):
    # decode
    decoded_inputs = tokenizer.decode(inputs[0].tolist())
    decoded_targets = tokenizer.decode(targets[0].tolist())    
    print(decoded_inputs)
    print(decoded_targets)
    # index starts at 0, so this stops after 11 batches have been printed
    if index == 10:
        break

C -1.2175 1.3713 0.1507
O -0.8130 2.0466 1.3430
<eos>
<LIGAND>
CC(=O)N1CCNC(=O)C1
<XYZ>
C -2.6423 0.2253 0.9374
C -1.7284 -0.7155 0.1923
O -2.1351 -1.8184 -0.1723
N -0.4298 -0.2845 -0.0371
C 0.4341 -1.1098 -0.8786
C 1.8899 -0.9812 -0.4421
N 2.3182 0.4002 -0.3945
C 1.5017 1.3928 0.0827
O 1.9732 2.4929 0.3647
C -0.0079 1.1090 0.1518
<eos>
<LIGAND>
COC(OC)[C@H](O)CO
<XYZ>
C -2.3205 -1.2753 0.0031
O -0.9909 -1.0715 0.4609
C -0.4020 0.1069 -0.1089
O -1.1040 1.2489 0.4018
C -0.7527 2.4720 -0.2263
C 1.1068 0.1368 0.2627
O 1.8158 1.1087 -0.5368
 -1.2175 1.3713 0.1507
O -0.8130 2.0466 1.3430
<eos>
<LIGAND>
CC(=O)N1CCNC(=O)C1
<XYZ>
C -2.6423 0.2253 0.9374
C -1.7284 -0.7155 0.1923
O -2.1351 -1.8184 -0.1723
N -0.4298 -0.2845 -0.0371
C 0.4341 -1.1098 -0.8786
C 1.8899 -0.9812 -0.4421
N 2.3182 0.4002 -0.3945
C 1.5017 1.3928 0.0827
O 1.9732 2.4929 0.3647
C -0.0079 1.1090 0.1518
<eos>
<LIGAND>
COC(OC)[C@H](O)CO
<XYZ>
C -2.3205 -1.2753 0.0031
O -0.9909 -1.0715 0.4609
C -0.4020 0.1069 -0.1089
O -1.1040 1