NeurIPS Open Polymer Prediction Challenge (Kaggle)
- This notebook tokenizes the SMILES data and extracts embeddings for feature prediction using the TransPolymer architecture.

(Author: Saima)

In [0]:
%sh
# Fetch the TransPolymer repo into the exact location the config cell expects as TP_DIR,
# so the `TP_DIR.exists()` assertion there passes on a fresh machine.
# The clone is skipped when the checkout already exists, so re-running this cell
# does not fail with "destination path already exists".
TP_DIR=/Transformer/TransPolymer-master
if [ -d "$TP_DIR" ]; then
  echo "TransPolymer already present at $TP_DIR; skipping clone"
else
  mkdir -p "$(dirname "$TP_DIR")"
  git clone https://github.com/ChangwenXu98/TransPolymer.git "$TP_DIR"
fi

In [0]:
import sys
from pathlib import Path
import pandas as pd
import torch

# --- Configuration ------------------------------------------------------------
# Local checkout of https://github.com/ChangwenXu98/TransPolymer (not added to the repo).
TP_DIR = Path("/Transformer/TransPolymer-master")
# SMILES CSV to tokenize; point this at the train or test file as needed.
csv_path = Path("/data/test.csv")

# Vocabulary and merge files the tokenizer is built from (expected inside the checkout).
VOCAB = TP_DIR / "tokenizer" / "vocab.json"
MERGES = TP_DIR / "tokenizer" / "merges.txt"

# Fail fast with a readable message instead of an obscure error from the tokenizer later.
for required_path in (TP_DIR, csv_path, VOCAB, MERGES):
  assert required_path.exists(), f"Path not found: {required_path}"

# Make PolymerSmilesTokenization importable. The guard keeps re-runs of this cell
# from appending duplicate entries to sys.path.
if str(TP_DIR) not in sys.path:
  sys.path.append(str(TP_DIR))

In [0]:
# Build the TransPolymer SMILES tokenizer from the vocab/merges files in the checkout.
from PolymerSmilesTokenization import PolymerSmilesTokenizer

tok = PolymerSmilesTokenizer(
  vocab_file=str(VOCAB),
  merges_file=str(MERGES),
  # '<s>' serves as both the BOS and CLS token; '</s>' as both EOS and SEP.
  bos_token='<s>',
  cls_token="<s>",
  eos_token='</s>',
  sep_token='</s>',
  # Padding / unknown / masking placeholders.
  pad_token='<pad>',
  unk_token='<unk>',
  mask_token='<mask>',
)

In [0]:
# Sanity check: encode a single SMILES string and inspect the padded/truncated result.
s = "CC(C)COC(=O)C1=CC=CC=C1"
enc = tok(
  s,
  return_tensors="pt",
  max_length=128,  # keep in sync with MAXLEN in the dataset tokenization cell
  truncation=True,
  padding="max_length",
)
print("input_ids", enc["input_ids"].shape)
print("attention_mask", enc["attention_mask"].shape)
print("first 30 ids: ", enc["input_ids"][0, :30].tolist())



In [0]:
# Tokenize SMILES data: encode every SMILES string in `csv_path` in fixed-size chunks,
# then save the padded token ids, attention masks and each row's index to a
# `.tokenized.pt` file next to the CSV.
MAXLEN = 128
CHUNK = 32 #1024 for train data

df = pd.read_csv(csv_path)
# astype(str) would silently turn a missing SMILES into the literal string "nan".
assert df["SMILES"].notna().all(), "SMILES column contains missing values"
smiles = df["SMILES"].astype(str).tolist()

ids, masks = [], []
for i in range(0, len(smiles), CHUNK):
  batch = smiles[i:i+CHUNK]
  enc = tok(batch, padding="max_length", truncation=True, max_length=MAXLEN, return_tensors="pt")
  ids.append(enc["input_ids"])
  masks.append(enc["attention_mask"])

# Concatenate and save once, after every chunk has been encoded (not per chunk).
if ids:
  input_ids = torch.cat(ids, dim=0)
  attention_mask = torch.cat(masks, dim=0)
else:
  # torch.cat raises on an empty list; an empty CSV yields empty (0, MAXLEN) tensors.
  # dtype long presumably matches the tokenizer's "pt" output -- confirm if it matters.
  input_ids = torch.empty((0, MAXLEN), dtype=torch.long)
  attention_mask = torch.empty((0, MAXLEN), dtype=torch.long)

# row_index below assumes one token row per CSV row.
assert input_ids.shape[0] == len(df), "tokenized rows do not line up with CSV rows"

out = csv_path.with_suffix(".tokenized.pt")
torch.save(
  {"input_ids": input_ids, "attention_mask": attention_mask,
   "row_index": torch.arange(len(df), dtype=torch.long)}, out
)

print("Saved: ", out, "Shapes: ", input_ids.shape, attention_mask.shape)