In [1]:
import re
import sys
import random
from typing import List, Dict
sys.path.append('..')
from llm.tokenizer import LigandTokenizer
from IPython.display import display, HTML

def colorize_tokens(tokens: List[str], colors: List[str]) -> str:
    """
    Render tokens as an HTML fragment, setting the text color of each token.

    Args:
        tokens: Token strings in display order.
        colors: CSS color values; ``colors[i]`` is applied to ``tokens[i]``.
            Tokens and colors are paired with ``zip``, so surplus items in the
            longer list are ignored.

    Returns:
        The concatenated HTML. Each ``"\\n"`` token becomes ``<br>``; every other
        token is wrapped in a ``<span>`` with ``white-space: pre`` so that
        space tokens stay visible.
    """
    pieces = []
    for token, color in zip(tokens, colors):
        if token == "\n":
            pieces.append("<br>")  # Replace newlines with <br> for HTML
        else:
            # Escape markup characters so tokens such as "<LIGAND>" are shown
            # literally instead of being parsed as (unknown) HTML tags.
            # "&" must be replaced first, otherwise the entities get re-escaped.
            token_display = (
                token.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
            )
            pieces.append(
                f"<span style='color: {color}; white-space: pre;'>{token_display}</span>"
            )
    return "".join(pieces)

def highlight_tokens_with_map(tokens: List[str], color_map: Dict[str, str]) -> str:
    """
    Render tokens as an HTML fragment, highlighting each one via a color map.

    Args:
        tokens: Token strings in display order.
        color_map: Maps a token string to the CSS background color to use for
            it. Tokens missing from the map get a white (``#FFFFFF``) background.

    Returns:
        The concatenated HTML. Each ``"\\n"`` token becomes ``<br>``; every other
        token is wrapped in a ``<span>`` with ``white-space: pre`` so that
        space tokens stay visible.
    """
    pieces = []
    for token in tokens:
        if token == "\n":
            pieces.append("<br>")  # Replace newlines with <br> for HTML
            continue
        # Escape markup characters so the token is displayed literally.
        # "&" is handled first so the entities produced below are not re-escaped.
        token_display = (
            token.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
        )
        # The lookup uses the raw token, not the escaped text.
        color = color_map.get(token, "#FFFFFF")  # Default to white if token not found
        pieces.append(
            f"<span style='background-color: {color}; white-space: pre;'>{token_display}</span>"
        )
    return "".join(pieces)

def generate_color_map(tokens: List[str]) -> Dict[str, str]:
    """
    Assign a distinct random color to each distinct token.

    Args:
        tokens: Token strings; duplicates are allowed and collapse to one entry.

    Returns:
        A dict mapping each distinct token to a ``#rrggbb`` hex color. Keys are
        in order of first appearance in ``tokens`` and no two tokens share a
        color.

    Raises:
        ValueError: If there are more distinct tokens than available 24-bit
            colors (raised by ``random.sample``).
    """
    # dict.fromkeys de-duplicates while keeping first-appearance order. Iterating
    # a set of strings instead would pair tokens with the random draws in an
    # order that can differ between interpreter runs, even with a fixed seed.
    unique_tokens = list(dict.fromkeys(tokens))

    # Sampling without replacement guarantees the colors are pairwise distinct;
    # independent randint draws could collide.
    color_values = random.sample(range(0x1000000), len(unique_tokens))

    return {
        token: f"#{value:06x}" for token, value in zip(unique_tokens, color_values)
    }



In [2]:
# Sample record used to exercise the tokenizer: a <LIGAND> tag, one line with
# the molecule string, an <XYZ> tag, one "element x y z" line per atom, and a
# closing <eos> tag.
test = """<LIGAND>
Cn1c(=O)c2c(ncn2C)n(C)c1=O
<XYZ>
C 3.2932 0.3895 0.2537
N 2.1108 -0.4198 0.1026
C 2.0698 -1.7874 0.0333
N 0.8358 -2.2304 -0.1052
C 0.0712 -1.1032 -0.1240
C 0.8238 0.0247 0.0014
C 0.2701 1.3335 0.0113
O 0.9576 2.3440 0.1250
N -1.1216 1.3332 -0.1210
C -1.8079 2.6102 -0.1253
C -1.9352 0.1881 -0.2531
O -3.1606 0.2905 -0.3648
N -1.2957 -1.0520 -0.2516
C -2.0603 -2.2787 -0.3829
<eos>"""

In [3]:
# Create a tokenizer and build its vocabulary from the single sample record.
tknz = LigandTokenizer()
tknz.build_vocab([test])

In [4]:
# Split the sample record into a list of token strings; the bare `tokens`
# expression on the last line echoes the whole list as the cell output.
tokens = tknz.tokenize(test)
tokens


['<LIGAND>',
 '\n',
 'C',
 'n',
 '1',
 'c',
 '(',
 '=',
 'O',
 ')',
 'c',
 '2',
 'c',
 '(',
 'n',
 'c',
 'n',
 '2',
 'C',
 ')',
 'n',
 '(',
 'C',
 ')',
 'c',
 '1',
 '=',
 'O',
 '\n',
 '<XYZ>',
 '\n',
 'C',
 ' ',
 '3',
 '.2932',
 ' ',
 '0',
 '.3895',
 ' ',
 '0',
 '.2537',
 '\n',
 'N',
 ' ',
 '2',
 '.1108',
 ' ',
 '-0',
 '.4198',
 ' ',
 '0',
 '.1026',
 '\n',
 'C',
 ' ',
 '2',
 '.0698',
 ' ',
 '-1',
 '.7874',
 ' ',
 '0',
 '.0333',
 '\n',
 'N',
 ' ',
 '0',
 '.8358',
 ' ',
 '-2',
 '.2304',
 ' ',
 '-0',
 '.1052',
 '\n',
 'C',
 ' ',
 '0',
 '.0712',
 ' ',
 '-1',
 '.1032',
 ' ',
 '-0',
 '.1240',
 '\n',
 'C',
 ' ',
 '0',
 '.8238',
 ' ',
 '0',
 '.0247',
 ' ',
 '0',
 '.0014',
 '\n',
 'C',
 ' ',
 '0',
 '.2701',
 ' ',
 '1',
 '.3335',
 ' ',
 '0',
 '.0113',
 '\n',
 'O',
 ' ',
 '0',
 '.9576',
 ' ',
 '2',
 '.3440',
 ' ',
 '0',
 '.1250',
 '\n',
 'N',
 ' ',
 '-1',
 '.1216',
 ' ',
 '1',
 '.3332',
 ' ',
 '-0',
 '.1210',
 '\n',
 'C',
 ' ',
 '-1',
 '.8079',
 ' ',
 '2',
 '.6102',
 ' ',
 '-0',
 '.1253',
 '\n',


In [5]:
# Encode the token list with the vocabulary built above and print the result
# (the recorded output is a flat list of integer ids).
ids = tknz.encode(tokens)
print(ids)

[54, 0, 58, 62, 51, 61, 2, 57, 60, 3, 61, 52, 61, 2, 62, 61, 62, 52, 58, 3, 62, 2, 58, 3, 61, 51, 57, 60, 0, 55, 0, 58, 1, 53, 34, 1, 50, 41, 1, 50, 30, 0, 59, 1, 52, 19, 1, 4, 42, 1, 50, 16, 0, 58, 1, 52, 14, 1, 5, 44, 1, 50, 11, 0, 59, 1, 50, 47, 1, 6, 27, 1, 4, 18, 0, 58, 1, 50, 15, 1, 5, 17, 1, 4, 22, 0, 58, 1, 50, 46, 1, 50, 10, 1, 50, 8, 0, 58, 1, 50, 31, 1, 51, 37, 1, 50, 9, 0, 60, 1, 50, 49, 1, 52, 38, 1, 50, 23, 0, 59, 1, 5, 21, 1, 51, 36, 1, 4, 20, 0, 58, 1, 5, 45, 1, 52, 43, 1, 4, 24, 0, 58, 1, 5, 48, 1, 50, 26, 1, 4, 29, 0, 60, 1, 7, 25, 1, 50, 33, 1, 4, 39, 0, 59, 1, 5, 35, 1, 5, 12, 1, 4, 28, 0, 58, 1, 6, 13, 1, 6, 32, 1, 4, 40, 0, 56]


In [6]:
# # Assign a random text color to each token position
# colors = [f"#{random.randint(0, 0xFFFFFF):06x}" for _ in tokens]

# # Generate the colored HTML
# colored_text = colorize_tokens(tokens, colors)

# # Display the colored tokens in a notebook
# display(HTML(colored_text))

In [7]:
# Build a token -> color mapping so every occurrence of the same token shares
# one background color. The colors come from the `random` module and no seed is
# set in this notebook, so they change on every run.
color_map = generate_color_map(tokens)

# Render the token sequence as HTML, using the map for background colors
highlighted_text = highlight_tokens_with_map(tokens, color_map)

# Display the highlighted tokens in a notebook
display(HTML(highlighted_text))

# Data Loader

In [8]:
import sys
import yaml
sys.path.append('..')
from llm.preprocessing.loader import create_data_loader, split_data

In [9]:
# Read the whole cleaned dataset into memory as one string.
# NOTE(review): no encoding is passed, so the platform default is used — confirm
# the file is plain ASCII/UTF-8.
with open('../data/xyz_mols/cleaned_input.txt', 'r') as f:
    data = f.read()

In [10]:
# Load the training configuration. The context manager makes sure the file
# handle is closed as soon as parsing finishes.
# NOTE(review): FullLoader is kept to preserve current behavior; if train.yml
# only contains plain scalars/lists/mappings, yaml.safe_load is the safer choice.
with open('../llm/config/train.yml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [11]:
# Build the tokenizer vocabulary from the full dataset text. This happens
# before the split below, so the vocabulary is built from text that ends up in
# both the training and the validation portions.
tokenizer = LigandTokenizer()
tokenizer.build_vocab([data])
# Save the tokenizer under the relative path "tokenizer.json"
# (the recorded output suggests an existing file of that name is removed first).
tokenizer.save("tokenizer.json")
# Split the text into training and validation parts using split_data's defaults
train_data, val_data = split_data(data)


Removing existing file: tokenizer.json


In [12]:
# Print the tail of the training split to check where it ends.
# NOTE(review): the slice is applied to train_data directly, before any
# tokenization; the recorded output starts mid-number, so this appears to be
# the last 500 characters rather than 500 tokens.
print(train_data[-500:])

26 1.2581 0.7176
C -1.1890 0.6825 0.2245
C 0.0434 1.2783 0.4748
C 1.2073 0.6520 0.0395
N 2.4680 1.2245 0.3325
C 1.0927 -0.6365 -0.5203
N 2.2523 -1.3463 -0.8767
N -0.0768 -1.2666 -0.7231
C -1.1855 -0.6135 -0.3318
N -2.4076 -1.2884 -0.5043
<eos>
<LIGAND>
Nc1ccc(N)c(N)n1
<XYZ>
N 2.9131 0.2017 -0.9565
C 1.6077 0.1051 -0.4601
C 1.3988 -0.0562 0.9045
C 0.0949 -0.1545 1.3839
C -0.9720 -0.0565 0.4969
N -2.2880 -0.2737 0.9531
C -0.6587 0.0265 -0.8762
N -1.7052 0.0176 -1.8227
N 0.5982 0.1384 -1.3425
<eos>


In [13]:
# Build the training data loader. Batch size, max_tokens and stride come from
# the loaded train.yml config; the tokenizer is the one built on the full
# dataset above. num_workers is fixed at 1 here rather than read from config.
train_loader = create_data_loader(
    input=train_data, 
    batch_size=config["training"]["batch_size"], 
    max_tokens=config["tokenizer"]["max_tokens"], 
    stride=config["tokenizer"]["stride"],
    tokenizer=tokenizer,
    num_workers=1,
)

In [16]:
# Sanity-check the loader: for each batch, decode the first sample (index 0) of
# the inputs and of the targets back to token strings and print both.
for index, (inputs, targets) in enumerate(train_loader):
    # decode ids -> token strings for the first sample in the batch
    decoded_inputs = tokenizer.decode_tkn(inputs[0])
    decoded_targets = tokenizer.decode_tkn(targets[0])
    print(decoded_inputs)
    print(decoded_targets)
    # Stop after batch index 10, i.e. once 11 batches have been printed
    if index == 10:
        break

['.5893', '\n', 'O', ' ', '-2', '.0711', ' ', '1', '.3342', ' ', '0', '.6820', '\n', '<eos>', '\n', '<LIGAND>', '\n', 'O', 'C', '[', 'C', '@', '@', ']', '1', '(', 'O', ')', 'O', 'C', '[', 'C', '@', '@', 'H', ']', '(', 'O', ')', '[', 'C', '@', 'H', ']', '(', 'O', ')', '[', 'C', '@', 'H', ']', '1', 'O', '\n', '<XYZ>', '\n', 'O', ' ', '2', '.9077', ' ', '0', '.3289', ' ', '0', '.5185', '\n', 'C', ' ', '1', '.7827', ' ', '-0', '.5494', ' ', '0', '.6086', '\n', 'C', ' ', '0', '.9002', ' ', '-0', '.4432', ' ', '-0', '.6662', '\n', 'O', ' ', '1', '.7807', ' ', '-0', '.4008', ' ', '-1', '.8035', '\n', 'O', ' ', '0', '.0810', ' ', '-1', '.6067', ' ', '-0', '.8037', '\n', 'C', ' ', '-0', '.9973', ' ', '-1', '.6888', ' ', '0', '.1282', '\n', 'C', ' ', '-1', '.9232', ' ', '-0', '.4788', ' ', '0', '.0475', '\n', 'O', ' ', '-2', '.9323', ' ', '-0', '.5583', ' ', '1', '.0552', '\n', 'C', ' ', '-1', '.1178', ' ', '0', '.8039', ' ', '0', '.2377', '\n', 'O', ' ', '-1', '.9830', ' ', '1', '.9351', ' ', '