In [2]:
from transformers import AutoModelForSeq2SeqLM,AutoModelForSequenceClassification, AutoTokenizer
import re
from rdkit.Chem import MolFromSmiles
from rdkit import RDLogger 
import pandas as pd
from transformers import T5ForSequenceClassification, T5Config
import torch
import os
import numpy as np
from tqdm import tqdm

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
RDLogger.DisableLog('rdApp.*')
atoms_tokens = ['Ag','Al','As','Au','B','Ba','Bi','Br','C','Ca',
              'Cd','Cl','Co','Cr','Cs','Cu','F','Fe','Ga','Gd',
              'Ge','H','Hg','I','In','K','Li','M','Mg','Mn',
              'Mo','N','Na','O','P','Pt','Ru','S','Sb','Sc',
              'Se','Si','Sn','V','W','Z','Zn','c','e','n','o','p','s']
atoms_tokens = sorted(atoms_tokens, key=lambda s: len(s), reverse=True)
SMI_REGEX_PATTERN = r"(\[|\]|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9]|" + \
                                                                  '|'.join(atoms_tokens) + ")"
regex = re.compile(SMI_REGEX_PATTERN)
def clean_output_sequence(output_sequence):
    return output_sequence.replace('</s>', '').replace('<sm_', '').replace(' sm_', '').replace('>', '').strip()
def add_special_symbols(text):
  output = []
  for word in text.split():
      tokens = [token for token in regex.findall(word)]
      if len(tokens) > 4 and (word == ''.join(tokens)) and MolFromSmiles(word):
          output.append(''.join(['<sm_'+t+'>' for t in tokens]))
      else:
          output.append(word)
  return ' '.join(output)
  

In [4]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)

train_data = pd.DataFrame(pd.read_csv('../train_split_fluor.csv'))[['Chromophore','Solvent','Absorption max (nm)','Emission max (nm)','Quantum yield']].dropna()
validate_data  = pd.DataFrame(pd.read_csv('../test_split_fluor.csv'))[['Chromophore','Solvent','Absorption max (nm)','Emission max (nm)','Quantum yield']].dropna()

df = pd.DataFrame(train_data).dropna()
df["input_text"] = df["Chromophore"] + " " + df["Solvent"]

# Separating features and targets
X = df["input_text"].tolist()
#, "Emission max (nm)", "Quantum yield"
y = df[["Absorption max (nm)"]].values.tolist()

# Random shuffle the data
np.random.seed(42)  # For reproducibility
indices = np.arange(len(X))
np.random.shuffle(indices)

# Split the indices for training and evaluation
split = int(len(indices) * 0.8)  # 60% train, 20% eval
train_indices = indices[:split]
eval_indices = indices[split:]

# Create train and eval datasets
X_train = [add_special_symbols(X[i]) for i in train_indices]
y_train = [y[i] for i in train_indices]  # Use train_indices directly
X_eval = [add_special_symbols(X[i]) for i in eval_indices]
y_eval = [y[i] for i in eval_indices]  # Use eval_indices directly

train_labels = torch.tensor(y_train).float()
eval_labels = torch.tensor(y_eval).float()


In [5]:
# train_dataset = CustomDataset(tokenizer(X_train), train_labels)
# eval_dataset = CustomDataset(tokenizer(X_eval), eval_labels)

In [19]:
config = T5Config.from_pretrained('output/checkpoint-8684')
config.num_labels=1
model = T5ForSequenceClassification.from_pretrained('output/checkpoint-8684',
 config=config,
 ignore_mismatched_sizes=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained('output/checkpoint-8684')

In [43]:

def process_sample(sample):
    # Prepare the input string

    prompt = add_special_symbols("Given the following Chromophore and Solvent, please provide an Absorption max (nm): " + str(sample))
    print(prompt)
    # Tokenize the input
    encoding = tokenizer.encode_plus(
        prompt,
        max_length=100,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    
    # Move the input to the appropriate device
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']
    
    # Generate the prediction
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    
    return tokenizer.decode(outputs.logits.argmax(dim=-1)[0], skip_special_tokens=True)
   

# Process all samples in X_train
predictions = []
i = 1
for sample in tqdm(X_train):
    prediction = process_sample(sample)
    predictions.append(prediction)
    i -= 1
    if (i == 0):
        break

print(f"Generated {len(predictions)} predictions")
for i in predictions:
    print(i)

  0%|          | 0/8684 [00:00<?, ?it/s]

Given the following Chromophore and Solvent, please provide an Absorption max (nm): <sm_C><sm_N><sm_(><sm_C><sm_)><sm_c><sm_1><sm_c><sm_c><sm_c><sm_2><sm_n><sm_c><sm_3><sm_c><sm_c><sm_c><sm_(><sm_N><sm_(><sm_C><sm_)><sm_C><sm_)><sm_c><sm_c><sm_3><sm_[><sm_s><sm_+><sm_]><sm_c><sm_2><sm_c><sm_1> CCO
Generated 1 predictions






In [None]:
from transformers import DataCollatorWithPadding, TrainingArguments, Trainer

# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    
# Training arguments
training_args = TrainingArguments(
    output_dir="./output",
    # run_name=run_name,
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    # per_device_train_batch_size=32,
    per_device_eval_batch_size=4,
    num_train_epochs=4,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
    report_to="tensorboard",
    resume_from_checkpoint=False,
    lr_scheduler_type="cosine",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset =eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
)
trainer.train()
model.save_pretrained('./output/nach-model-1')
tokenizer.save_pretrained('./output/nach-model-1')



Epoch,Training Loss,Validation Loss
1,99689.48,99519.265625


There were missing keys in the checkpoint model loaded: ['transformer.encoder.embed_tokens.weight', 'transformer.decoder.embed_tokens.weight'].


('./output/nach-model-1/tokenizer_config.json',
 './output/nach-model-1/special_tokens_map.json',
 './output/nach-model-1/tokenizer.json')

476018      C   /home/tetkin/.venv/bin/python                7638MiB |
|    0   N/A  N/A    kill 498203      C   /home/kruchkov/biot5/.venv/bin/python3      18676MiB |
|    1   N/A  N/A    kill 476018      C   /home/tetkin/.venv/bin/python               10762MiB |
|    2   N/A  N/A    kill 476018      C   /home/tetkin/.venv/bin/python               10866MiB |
|    3   N/A  N/A    kill 114147      C   ...iniconda3/envs/torch_311/bin/python        446MiB |
|    3   N/A  N/A    kill 124414      C   ...iniconda3/envs/torch_311/bin/python      16874MiB |
|    3   N/A  N/A    kill 387586      C   ...iniconda3/envs/torch_311/bin/python      13900MiB |
|    3   N/A  N/A    kill 409845      C   ...iniconda3/envs/torch_311/bin/python        396MiB |

('./output/nach-model-1/tokenizer_config.json',
 './output/nach-model-1/special_tokens_map.json',
 './output/nach-model-1/tokenizer.json')