In [None]:
import torch

# Report which CUDA devices (if any) PyTorch can see in this runtime.
if not torch.cuda.is_available():
    print("CUDA is not available. No GPUs found.")
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")

    # Optional: list every visible device by index and name.
    for gpu_index in range(num_gpus):
        gpu_name = torch.cuda.get_device_name(gpu_index)
        print(f"GPU {gpu_index}: {gpu_name}")

Number of available GPUs: 1
GPU 0: Tesla T4


In [None]:
# Dependencies used by train.py: 4-bit quantization (bitsandbytes), LoRA (peft),
# the ESOL/Delaney loader (deepchem) and OLMo support (ai2-olmo).
# `--pre` allows pip to pick a DeepChem pre-release.
# NOTE(review): none of these are version-pinned; pin them for reproducible runs.
! pip install bitsandbytes
! pip install peft
! pip install --pre deepchem
! pip install ai2-olmo

Collecting ai2-olmo
  Using cached ai2_olmo-0.6.0-py3-none-any.whl.metadata (25 kB)
Collecting ai2-olmo-core==0.1.0 (from ai2-olmo)
  Using cached ai2_olmo_core-0.1.0-py3-none-any.whl.metadata (14 kB)
Collecting boto3 (from ai2-olmo)
  Downloading boto3-1.42.36-py3-none-any.whl.metadata (6.8 kB)
Collecting cached_path>=1.6.2 (from ai2-olmo)
  Downloading cached_path-1.8.1-py3-none-any.whl.metadata (19 kB)
Collecting botocore<1.43.0,>=1.42.36 (from boto3->ai2-olmo)
  Downloading botocore-1.42.36-py3-none-any.whl.metadata (5.9 kB)
Collecting jmespath<2.0.0,>=0.7.1 (from boto3->ai2-olmo)
  Downloading jmespath-1.1.0-py3-none-any.whl.metadata (7.6 kB)
Collecting s3transfer<0.17.0,>=0.16.0 (from boto3->ai2-olmo)
  Downloading s3transfer-0.16.0-py3-none-any.whl.metadata (1.7 kB)
Downloading ai2_olmo-0.6.0-py3-none-any.whl (144.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m144.9/144.9 MB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ai2_olmo_core-0.1.0-py3

In [None]:
# PyTorch Lightning provides the Trainer/LightningModule used by train.py.
# NOTE(review): unpinned; pin the version for reproducible runs.
! pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.6.0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics>0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading pytorch_lightning-2.6.0-py3-none-any.whl (849 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.5/849.5 kB[0m [31m20.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m35.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.15.2 pytorch_lightning-2.6.0 torchmetrics-1.8.2


In [None]:
%%writefile train.py
import torch
import pytorch_lightning as pl
import deepchem as dc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
from deepchem.molnet import load_delaney
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from sklearn.metrics import mean_squared_error
import seaborn as sns
import re

class OlmoDataset(Dataset):
    """ESOL (Delaney) solubility regression rendered as instruction-tuning text.

    Each molecule becomes one prompt::

        ### Instruction:
        Predict the ESOL water solubility for the following molecule:
        <SMILES>

        ### Response:
        <label formatted to 5 decimals><eos>

    ``__getitem__`` tokenizes a prompt to exactly ``max_length`` tokens and
    builds causal-LM labels in which only the answer and its EOS are
    supervised (instruction and padding positions are set to -100).

    Args:
        mode: Scaffold split to expose: "train", "valid" or "test"
            (case-insensitive).
        max_length: Length every example is padded/truncated to, in tokens.

    Raises:
        ValueError: If ``mode`` does not name one of the three splits.
    """

    _SPLITS = ("train", "valid", "test")
    _RESPONSE_SEPARATOR = "### Response:\n"

    def __init__(self, mode="Train", max_length=350):
        self.mode = mode.lower()
        # Fail fast, before downloading the tokenizer and dataset; an unknown
        # mode would otherwise leave self.data unset and crash much later.
        if self.mode not in self._SPLITS:
            raise ValueError(
                f"mode must be one of {self._SPLITS} (case-insensitive), got {mode!r}"
            )

        self.tokenizer = AutoTokenizer.from_pretrained(
            "Codemaster67/OLMo-7B-USPTO-1k-ZINC",
            trust_remote_code=True,
            padding_side="right"
        )
        # EOS doubles as the pad token, so both share one token id.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # NOTE(review): load_delaney's `transformers` argument is left at its
        # default, which may normalize y; labels would then be in normalized
        # units rather than log(mol/L) -- confirm before quoting metrics.
        tasks, datasets, transformers = load_delaney(featurizer="raw", splitter='scaffold')
        # Kept so predictions can be mapped back through the same transforms.
        self.transformers = transformers
        train, valid, test = datasets
        self.data = {"train": train, "valid": valid, "test": test}[self.mode]

        self.max_length = max_length
        self.samples = []
        self._filldataset()

    def _filldataset(self):
        """Append one prompt string per molecule in ``self.data`` to ``self.samples``."""
        # Read ids/y once: on a deepchem DiskDataset these properties may
        # reassemble the full arrays from disk on every access.
        smiles_list = self.data.ids
        label_rows = self.data.y
        for smiles, label_row in zip(smiles_list, label_rows):
            self.samples.append(self._create_prompt(smiles, label_row[0]))
        print(f"[{self.mode.upper()}] Number of samples: {len(self.samples)}")

    def _create_prompt(self, smiles, label):
        """Return the full training text for one molecule, answer and EOS included."""
        eos_token = self.tokenizer.eos_token
        answer = f"{label:.5f}"

        full_prompt = (
            "### Instruction:\n"
            # Trailing space matters: adjacent literals are joined verbatim.
            "Predict the ESOL water solubility "
            f"for the following molecule:\n{smiles}\n\n"
            f"{self._RESPONSE_SEPARATOR}"
            f"{answer}{eos_token}"
        )
        return full_prompt

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx``; return input_ids, attention_mask and labels (1-D each)."""
        text = self.samples[idx]
        encodings = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        input_ids = encodings["input_ids"].squeeze(0)
        attention_mask = encodings["attention_mask"].squeeze(0)
        labels = input_ids.clone()

        prompt, separator, _ = text.partition(self._RESPONSE_SEPARATOR)
        if separator:
            prompt_len = self.tokenizer(
                prompt + separator,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )["input_ids"].shape[1]
            # Do not supervise the instruction. When the prompt alone fills the
            # window nothing is masked, so the row never becomes all -100.
            if prompt_len < len(labels):
                labels[:prompt_len] = -100

        # Mask padding by position, not by token id: pad_token == eos_token, so
        # an id-based mask would also hide the answer's EOS and the model would
        # never be trained to stop after the number.
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

class OLMO_QLoRA(pl.LightningModule):
    """QLoRA fine-tuning of the OLMo-7B USPTO/ZINC checkpoint on ESOL prompts.

    The base causal LM is loaded in 4-bit (NF4, double quantization, fp16
    compute) inside ``configure_model`` and wrapped with rank-32 LoRA adapters
    on the attention and MLP projections. After training, ``on_train_end``
    greedily decodes the test split and prints the RMSE of the parsed numbers.
    """

    def __init__(self):
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(
            "Codemaster67/OLMo-7B-USPTO-1k-ZINC",
            trust_remote_code=True,
            padding_side="right"
        )
        # EOS doubles as the pad token.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # 4-bit NF4 weights with nested (double) quantization; compute in fp16.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16
        )

        # LoRA on every attention (q/k/v/o) and MLP (gate/up/down) projection.
        self.peft_config = LoraConfig(
            r=32,
            lora_alpha=64,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type=TaskType.CAUSAL_LM
        )

    def configure_model(self):
        """Load the quantized base model and wrap it with LoRA adapters.

        Guarded so a repeated call (Lightning documents that this hook may run
        once per stage in the same process) does not reload the 7B model.
        """
        if getattr(self, "model", None) is not None:
            return
        base_model = AutoModelForCausalLM.from_pretrained(
            "Codemaster67/OLMo-7B-USPTO-1k-ZINC",
            quantization_config=self.bnb_config,
            trust_remote_code=True,
        )
        base_model = prepare_model_for_kbit_training(base_model)
        self.model = get_peft_model(base_model, self.peft_config)
        self.model.print_trainable_parameters()

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the PEFT-wrapped causal LM; its output carries ``.loss`` when labels are given."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

    def training_step(self, batch, batch_idx):
        """Teacher-forced LM loss on a training batch."""
        outputs = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"]
        )
        loss = outputs.loss
        self.log("Train_loss", loss, prog_bar=True, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Teacher-forced LM loss on a validation batch, logged per epoch as ``Val_loss``."""
        outputs = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"]
        )
        loss = outputs.loss
        self.log("Val_loss", loss, prog_bar=True, on_step=False, on_epoch=True,
                 logger=True, sync_dist=True)

        return loss

    @staticmethod
    def _prompt_only(batch):
        """Return ``(input_ids, attention_mask)`` cut just before the answer.

        The answer starts at the first label that is not -100. If no label is
        supervised (e.g. the answer was truncated away), fall back to the
        non-padding tokens, which come first because padding is on the right.
        Assumes batch size 1, as used by the test loader.
        """
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        supervised = batch["labels"] != -100
        if bool(supervised.any()):
            answer_start = int(supervised.int().argmax(dim=1).item())
        else:
            answer_start = int(attention_mask.sum().item())

        if answer_start > 0:
            return input_ids[:, :answer_start], attention_mask[:, :answer_start]
        return input_ids, attention_mask

    @staticmethod
    def _parse_prediction(generated_text):
        """Return the first number after "### Response:" in ``generated_text``, or None."""
        # split(...)[-1] is the whole text when the marker is absent.
        response_part = generated_text.split("### Response:")[-1].strip()
        # Optionally signed integer/decimal that is not glued to a word character.
        match = re.search(r"(?<!\w)-?\d+(?:\.\d+)?(?!\w)", response_part)
        if match is None:
            print(f"Warning: Could not parse number from: {response_part[:50]}...")
            return None
        return float(match.group())

    def on_train_end(self):
        """Greedy-decode the ESOL test split on rank zero and print the RMSE.

        Unparseable generations are scored as 0.0 so every sample contributes,
        and their count is printed alongside the RMSE.

        NOTE(review): ground truth is ``test_dataset.data.y`` exactly as loaded;
        if those targets are normalized, the RMSE is in normalized units too.
        """
        if not self.trainer.is_global_zero:
            return
        print("\nStarting test set evaluation (RMSE) after training...")

        test_dataset = OlmoDataset(mode="test", max_length=350)
        # shuffle=False + batch_size=1 keeps loader order aligned with true_values.
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
        true_values = test_dataset.data.y.flatten()

        self.model.eval()

        preds = []
        trues = []
        n_unparsed = 0

        print(f"Evaluating on {len(test_loader)} samples...")

        with torch.no_grad():
            for i, batch in enumerate(test_loader):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                # Feed only instruction + molecule; the ground-truth answer is cut off.
                prompt_ids, prompt_mask = self._prompt_only(batch)

                outputs = self.model.generate(
                    input_ids=prompt_ids,
                    attention_mask=prompt_mask,
                    max_new_tokens=10,
                    pad_token_id=self.tokenizer.eos_token_id,
                    do_sample=False
                )
                generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

                val = self._parse_prediction(generated_text)
                if val is None:
                    n_unparsed += 1
                    val = 0.0

                preds.append(val)
                trues.append(true_values[i])

                if i % 10 == 0:
                    print(f"Sample {i}: True={true_values[i]:.5f}, Pred={val:.5f}")

        rmse = np.sqrt(mean_squared_error(np.array(trues), np.array(preds)))

        print("\n=== Test Set Metrics ===")
        print(f"RMSE: {rmse:.4f}")
        print(f"Unparsed predictions (scored as 0.0): {n_unparsed}/{len(preds)}")

    def configure_optimizers(self):
        """AdamW over the trainable (LoRA) parameters with linear warmup then cosine decay, stepped per optimizer step."""
        # Only LoRA weights require grad; frozen/quantized base weights never
        # receive gradients, so there is no reason to hand them to AdamW.
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=2e-4, weight_decay=1e-4)

        total_steps = self.trainer.estimated_stepping_batches

        warmup_steps = 10
        scheduler_warmup = LinearLR(
            optimizer,
            start_factor=0.001,
            end_factor=1.0,
            total_iters=warmup_steps,
        )

        # Clamp: a non-positive T_max (very short runs) would divide by zero
        # inside CosineAnnealingLR.
        scheduler_cosine = CosineAnnealingLR(
            optimizer,
            T_max=max(1, total_steps - warmup_steps),
        )

        scheduler = SequentialLR(
            optimizer,
            schedulers=[scheduler_warmup, scheduler_cosine],
            milestones=[warmup_steps]
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step"
            }
        }


if __name__ == "__main__":
    # Seed the Python/NumPy/torch RNGs so shuffling, LoRA init and dropout repeat.
    pl.seed_everything(42)

    train_dataset = OlmoDataset()
    valid_dataset = OlmoDataset(mode="valid")

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    # Handed to trainer.fit so the validation split is actually evaluated.
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy="ddp",
        max_epochs=16,
        precision="16-mixed",
        # batch_size=1 per step x 16 accumulated steps per optimizer update.
        accumulate_grad_batches=16,
        enable_checkpointing=False,
        gradient_clip_val=1,
    )

    model = OLMO_QLoRA()

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

Writing train.py


In [None]:
# Run the script written above in a separate Python process, not in this kernel.
# NOTE(review): presumably because strategy="ddp" in train.py cannot launch from an interactive session — confirm.
! python train.py

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
No normalization for NumAmideBonds. Feature removed!
No normalization for NumAtomStereoCenters. Feature removed!
No normalization for NumBridgeheadAtoms. Feature removed!
No normalization for NumHeterocycles. Feature removed!
No normalization for NumSpiroAtoms. Feature removed!
No normalization for NumUnspecifiedAtomStereoCenters. Feature removed!
No normalization for Phi. Feature removed!
2026-01-28 16:13:03.506763: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769616783.546344    1603 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769616783.557222    1603 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin