In [1]:
from datasets import Dataset
from datasets import load_dataset
import pandas as pd
import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaConfig
from transformers import Trainer, TrainingArguments
from transformers import EvalPrediction
import torch_optimizer as optim
from transformers import AutoTokenizer
from sklearn.metrics import root_mean_squared_error, mean_absolute_error
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model
import transformers
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
smiles_model_name = 'seyonec/ChemBERTa-zinc-base-v1'
smiles_tokenizer = AutoTokenizer.from_pretrained(smiles_model_name)

In [3]:
dataset = load_dataset('AngryWolffyy/polymer_propery')

dataset = dataset['train']


def is_valid(row):
    return any(row[k] is not None for k in ['Tg', 'FFV', 'Tc', 'Density', 'Rg'])

dataset = dataset.filter(is_valid)

print(dataset[0])

{'id': 87817, 'SMILES': '*CC(*)c1ccccc1C(=O)OCCCCCC', 'Tg': None, 'FFV': 0.37464529, 'Tc': 0.2056666666666666, 'Density': None, 'Rg': None, 'smiles_can': '*CC(*)c1ccccc1C(=O)OCCCCCC'}


In [4]:
MAX_LENGTH = 256
def encode(row):
    smiles_encoding = smiles_tokenizer(
        row['smiles_can'], 
        padding='max_length', 
        truncation=True, 
        max_length=MAX_LENGTH, 
        add_special_tokens=False
    )

    input_ids = torch.tensor(smiles_encoding.get('input_ids', []))
    attention_mask = torch.tensor(smiles_encoding.get('attention_mask', []))

    label_values = [
        row['Tg'],
        row['FFV'],
        row['Tc'],
        row['Density'],
        row['Rg']
    ]

    existing_values = [v for v in label_values if v is not None]

    row_mean = sum(existing_values) / len(existing_values)
    
    labels = torch.tensor(
        [v if v is not None else row_mean for v in label_values],
        dtype=torch.float32
    )

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }
   


In [5]:
encoded_row = encode(dataset[0])
print(encoded_row['labels'].shape)
print(encoded_row['input_ids'].shape)

# Select input_ids where attention_mask is zero
masked_input_ids = [id for id, mask in zip(encoded_row['input_ids'], encoded_row['attention_mask']) if mask == 1]

# Convert to tensor
masked_input_ids_tensor = torch.tensor(masked_input_ids)

print(masked_input_ids_tensor)

print(encoded_row['labels'])

torch.Size([5])
torch.Size([256])
tensor([ 14, 262,  12,  14,  13,  71,  21, 269,  21,  39, 263,  51,  13,  51,
        365])
tensor([0.2902, 0.3746, 0.2057, 0.2902, 0.2902])


In [6]:
dataset = dataset.map(encode, remove_columns=dataset.column_names)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Map: 100%|██████████| 7973/7973 [00:04<00:00, 1966.45 examples/s]


In [7]:
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset['train']
val_dataset = dataset['test']

In [8]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.smiles_model_name = 'seyonec/ChemBERTa-zinc-base-v1'

        smiles_config = RobertaConfig.from_pretrained(self.smiles_model_name)
        smiles_config.output_hidden_states = True
        self.config = smiles_config 

        self.smiles_model = RobertaModel.from_pretrained(self.smiles_model_name, config=smiles_config)

        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(smiles_config.hidden_size, 256)
        self.relu = nn.ReLU()
        self.norm = nn.LayerNorm(256)
        self.regression_head = nn.Linear(256, 5)

    def forward(self, 
                input_ids=None, 
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None, 
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        
        outputs = self.smiles_model(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
        
        hidden_state = outputs.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()

        sum_embeddings = torch.sum(hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        
        pooled_output = sum_embeddings / sum_mask

        x = self.fc1(pooled_output)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.norm(x)
        
        logits = self.regression_head(x)

        loss_fct = nn.MSELoss()
        loss = loss_fct(logits, labels)
        #if labels is not None:
            #loss_fct = nn.MSELoss(reduction='none')

            #loss = loss_fct(logits, labels)
            #mask = ~torch.isnan(labels)

            #masked_loss = torch.where(mask, loss, torch.zeros_like(loss))
            #loss = masked_loss.sum() / mask.sum().clamp(min=1)

        return {
            "loss": loss,
            "logits": logits
        }

In [9]:
from torch.optim import AdamW

class MyTrainer(Trainer):
    def create_optimizer(self):
        optimizer = AdamW(
            self.model.parameters(),
            lr = 3e-4,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0.01
        )
        self.optimizer = optimizer
        return optimizer

In [10]:
def compute_metrics(p: EvalPrediction):
    preds = (p.predictions)
    labels = (p.label_ids)

    mse = (root_mean_squared_error(labels, preds))**2
    mae = mean_absolute_error(labels, preds)

    return {
        'mse': mse,
        'mae': mae
    }

In [11]:
model = Model()
print([name for name, _ in model.named_modules()])

['', 'smiles_model', 'smiles_model.embeddings', 'smiles_model.embeddings.word_embeddings', 'smiles_model.embeddings.position_embeddings', 'smiles_model.embeddings.token_type_embeddings', 'smiles_model.embeddings.LayerNorm', 'smiles_model.embeddings.dropout', 'smiles_model.encoder', 'smiles_model.encoder.layer', 'smiles_model.encoder.layer.0', 'smiles_model.encoder.layer.0.attention', 'smiles_model.encoder.layer.0.attention.self', 'smiles_model.encoder.layer.0.attention.self.query', 'smiles_model.encoder.layer.0.attention.self.key', 'smiles_model.encoder.layer.0.attention.self.value', 'smiles_model.encoder.layer.0.attention.self.dropout', 'smiles_model.encoder.layer.0.attention.output', 'smiles_model.encoder.layer.0.attention.output.dense', 'smiles_model.encoder.layer.0.attention.output.LayerNorm', 'smiles_model.encoder.layer.0.attention.output.dropout', 'smiles_model.encoder.layer.0.intermediate', 'smiles_model.encoder.layer.0.intermediate.dense', 'smiles_model.encoder.layer.0.intermed

In [12]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=['query', 'value', 'output.dense'],
    lora_dropout=0.5,
    bias='none'
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 405,504 || all params: 44,708,101 || trainable%: 0.9070


In [None]:
print(transformers.__version__)
for batch in val_dataset:
    print(batch['labels'])
    break

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    lr_scheduler_type='linear',
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=3e-5,
    gradient_accumulation_steps=1,
    report_to=None,
    fp16=True,
    seed=42,
    eval_strategy='steps',
    eval_steps=20,
    save_strategy='steps',
    save_steps=99999,
    label_names=["labels"],
)

trainer = MyTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

result = trainer.train()

4.53.0
tensor([0.3848, 0.3848, 0.3848, 0.3848, 0.3848])


  return torch._C._cuda_getDeviceCount() > 0


Step,Training Loss,Validation Loss,Mse,Mae
20,203.1927,1636.790527,1636.707794,9.397746
40,1862.8959,1636.182007,1636.100105,9.36967




KeyboardInterrupt: 