# Fine-Tuning ChemGPT

## Librairies

In [1]:
# !pip install torch torchvision torchaudio
# !pip install --upgrade "transformers[torch]" accelerate
# !pip install transformers
# !pip install datasets

In [2]:
# On local MacOS
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, pipeline
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

Using device: mps


## Fine-Tuning

In [5]:
# Load pretrained ChemGPT (SELFIES version if available)
model_name = "ncfrey/ChemGPT-4.7M" # Or smaller variant for testing

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

model = AutoModelForCausalLM.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))
model.to(device)

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(684, 128)
    (wpe): Embedding(2048, 128)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPTNeoBlock(
        (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=128, out_features=128, bias=False)
            (v_proj): Linear(in_features=128, out_features=128, bias=False)
            (q_proj): Linear(in_features=128, out_features=128, bias=False)
            (out_proj): Linear(in_features=128, out_features=128, bias=True)
          )
        )
        (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=128, out_features=512, bias=True)
          (c_proj): Linear(in_featu

In [6]:
# Load dataset
dataset = load_dataset("csv", data_files="../dataset/toy_dataset.csv")

In [7]:
dataset["train"][3]

{'selfies': '\\[C]\\[C]\\[C]\\[C]\\[O]\\[H]', 'solubility': 'medium'}

In [8]:
# Preprocess: combine property + selfies
def preprocess(examples):
    sequences = [
        f"solubility={prop} [START] {s}" 
        for s, prop in zip(examples["selfies"], examples["solubility"])
    ]
    
    tokenized = tokenizer(
        sequences,
        padding="max_length",    # pad all sequences to max_length
        truncation=True,         # truncate longer sequences
        max_length=64,           # or a length that fits your data
    )
    
    # For causal LM, labels = input_ids
    tokenized["labels"] = tokenized["input_ids"].copy()
    
    return tokenized

tokenized = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)

In [9]:
# Training setup
training_args = TrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    num_train_epochs=3,
    logging_dir="./logs",
    save_strategy="epoch",
    fp16=False,  # True if GPU supports, False for M1 Mac
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    tokenizer=tokenizer,
)

# Fine-tune
trainer.train()

# Save fine-tuned model
trainer.save_model("./checkpoints")

  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': None, 'bos_token_id': None, 'pad_token_id': 1}.


Step,Training Loss




In [10]:
# # Path to your fine-tuned model
model_path = "./checkpoints"

# # Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForCausalLM.from_pretrained(model_path)
model.to(device)

# Inference function
def generate_molecule(solubility="high", max_length=32, max_new_tokens=50,
                      temperature=0.7, top_p=0.95, top_k=50):
    """
    Generate a SELFIES string conditioned on solubility.
    
    Args:
        solubility (str): property prompt, e.g., "high" or "low"
        max_length (int): maximum length for tokenizing the prompt
        max_new_tokens (int): number of tokens to generate
        temperature (float): randomness for sampling
        top_p (float): nucleus sampling
        top_k (int): top-k sampling
    Returns:
        str: generated SELFIES string
    """
    # Build prompt
    prompt = f"solubility={solubility} [START]"

    # Tokenize with padding and truncation
    encoding = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length
    )
    
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    
    # Generate new tokens
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id
    )
    
    # Decode and remove prompt
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    generated_selfies = generated_text[len(prompt):].strip()
    
    return generated_selfies

In [11]:
# Example usage
print("Generated molecule for low solubility:", generate_molecule("low"))
print("Generated molecule for high solubility:", generate_molecule("high"))

Generated molecule for low solubility: ch1_1] [C] [P] [Branch1_1] [=N] [Siexpl] [Branch1_1] [Ring1] [Seexpl] [Seexpl] [Teexpl] [Teexpl] [Seexpl] [Teexpl] [Ring1] [Branch1_2] [S] [Geexpl] [#S] [Ring1] [=N] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [Branch1_2] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [Branch1_2] [C] [Branch1_2] [C] [=O] [N] [C] [C]
Generated molecule for high solubility: ] [Ring1] [Branch2_3] [S] [Seexpl] [Seexpl] [Seexpl] [Seexpl] [Seexpl] [PHexpl] [Ring1] [Branch1_1] [S] [Ring1] [Branch2_1] [S] [Ring1] [N] [C] [C] [C] [C] [Branch1_1] [C] [C] [C] [Branch1_1] [C] [C] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [Branch1_2] [C] [C] [C] [C] [C] [C] [C]
