# In [None]:
######################################## Pre-Training ####################################################

import pandas as pd
import torch
import re
from torch.utils.data import Dataset
from transformers import BertTokenizer, BertForMaskedLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling, TrainerCallback
from sklearn.model_selection import train_test_split

# Define custom tokenization function for 'composition'
def custom_tokenize(composition, implicit_amount="1"):
    """Normalise a composition string into sorted ``<Element><amount>`` tokens.

    Every element symbol (an uppercase letter followed by optional lowercase
    letters) is paired with the digits/dots that immediately follow it.
    Elements written without an explicit amount (e.g. the ``Cl`` in ``"NaCl"``)
    receive ``implicit_amount`` instead of being silently dropped. Tokens are
    sorted by element symbol, so the same composition always yields the same
    text regardless of the order it was written in.

    Args:
        composition: Raw composition string such as ``"Co1.2 Fe0.8 Ni1"``.
            Missing values (``None``/``NaN``, e.g. empty CSV cells read by
            pandas) yield an empty string.
        implicit_amount: Amount appended to elements written without one.

    Returns:
        Space-separated tokens, e.g. ``"Co1.2 Fe0.8 Ni1"``.
    """
    # Missing cells arrive from pandas as NaN floats; re.findall would raise on them.
    if composition is None or (not isinstance(composition, str) and pd.isna(composition)):
        return ''
    matches = re.findall(r'([A-Z][a-z]*)([0-9.]*)', composition)
    # sorted() is stable, so repeated symbols keep their original relative order.
    ordered = sorted(matches, key=lambda pair: pair[0])
    return ' '.join(f"{element}{amount or implicit_amount}" for element, amount in ordered)

# Test the function
print(custom_tokenize("Co1.2 Fe0.8 Ni1"))

# Load your unlabeled data
unlabeled_data = pd.read_csv('6K.csv') #<---------------------------------------------------------------------------------

# Apply custom tokenization to 'composition' column
unlabeled_data['custom_composition'] = unlabeled_data['composition'].apply(custom_tokenize)

# Convert numeric columns to strings so they can be joined into the text
numeric_cols = unlabeled_data.select_dtypes(['float64', 'int64']).columns
for col in numeric_cols:
    unlabeled_data[col] = unlabeled_data[col].astype(str)

# Concatenate them with the custom composition tokens
unlabeled_data['concat_text'] = unlabeled_data['custom_composition'] + ' ' + unlabeled_data[numeric_cols].agg(' '.join, axis=1)

# Split the data into training and validation sets
train_texts, val_texts = train_test_split(unlabeled_data['concat_text'].values, test_size=0.2, random_state=42)

# Tokenize using BERT tokenizer.
# No padding here: DataCollatorForLanguageModeling pads each batch to that
# batch's longest sequence, instead of every example being padded to the
# longest sequence in the whole dataset (wasted compute on [PAD] tokens).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_inputs = tokenizer(list(train_texts), truncation=True, max_length=512)
val_inputs = tokenizer(list(val_texts), truncation=True, max_length=512)

# Custom Dataset class for MLM
class MLM_Dataset(Dataset):
    """Map-style dataset over tokenizer output for masked-language-model training.

    ``encodings`` is a mapping (e.g. a ``BatchEncoding`` or a plain dict) from
    feature name (``input_ids``, ``attention_mask``, ...) to a per-example
    sequence of lists or a batched tensor. Each item is a dict holding the
    ``idx``-th entry of every feature as its own tensor. No masking happens here.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # as_tensor accepts python lists as well as tensor rows (return_tensors="pt")
        # without torch.tensor's copy-construct warning; clone() keeps the returned
        # item independent of the stored encodings, as the original copy did.
        return {key: torch.as_tensor(values[idx]).clone() for key, values in self.encodings.items()}

    def __len__(self):
        # Item access (not attribute access) works for plain dicts and BatchEncoding.
        return len(self.encodings["input_ids"])

# Wrap the tokenized train/validation splits in MLM_Dataset objects.
train_dataset, val_dataset = (MLM_Dataset(split) for split in (train_inputs, val_inputs))

# Masked-LM collator: tokens are selected for masking with probability 0.15.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

# Pretrained bert-base-uncased weights with a masked-LM head.
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

class SaveBestModelCallback(TrainerCallback):
    """A custom callback to save the best model based on validation loss.

    Whenever an evaluation reports an ``eval_loss`` lower than any seen so
    far, the model and tokenizer handed to the callback by the Trainer are
    written to ``save_dir`` with ``save_pretrained``. They are not taken from
    notebook globals.

    Args:
        save_dir: Directory the best model and tokenizer are saved to.
    """

    def __init__(self, save_dir="6K_Pretraining"):  #<---------------------------------------------
        super().__init__()
        self.save_dir = save_dir
        self.best_loss = float('inf')  # lowest eval_loss observed so far

    def on_evaluate(self, args, state, control, **kwargs):
        # Prefer the metrics of this evaluation passed as `metrics`; fall back to
        # the most recent log entry if they are not supplied.
        metrics = kwargs.get("metrics")
        if metrics is None:
            metrics = state.log_history[-1] if state.log_history else {}
        eval_loss = metrics.get("eval_loss")

        # `is None` rather than truthiness so a loss of exactly 0.0 still counts;
        # `not <` (rather than `>=`) also skips NaN losses.
        if eval_loss is None or not eval_loss < self.best_loss:
            return
        self.best_loss = eval_loss

        # In multi-process runs only the main process should write files.
        if not state.is_world_process_zero:
            return

        print(f"New best model with loss: {eval_loss}, saving model...")
        model = kwargs.get("model")
        # Depending on the transformers version, the tokenizer is passed to
        # callbacks as `processing_class` or as `tokenizer`.
        processor = kwargs.get("processing_class")
        if processor is None:
            processor = kwargs.get("tokenizer")
        if model is not None:
            model.save_pretrained(self.save_dir)
        if processor is not None:
            processor.save_pretrained(self.save_dir)

# Training arguments
# NOTE(review): newer transformers releases deprecate `evaluation_strategy` in
# favour of `eval_strategy`, and Trainer's `tokenizer=` in favour of
# `processing_class=`; update both names if you upgrade transformers.
training_args = TrainingArguments(
    output_dir="6K_Pretraining", #<---------------------------------------------------------------
    overwrite_output_dir=True,
    num_train_epochs=40,
    per_device_train_batch_size=16,
    evaluation_strategy="epoch",
    save_strategy="epoch",  # must match the evaluation strategy for load_best_model_at_end
    save_total_limit=2,     # the best checkpoint is retained in addition to these
    # Reload the lowest-eval-loss checkpoint when training finishes. Without
    # this, the final save below would overwrite the best model that
    # SaveBestModelCallback wrote to the same directory with the last epoch's weights.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    logging_dir='./logs',
    logging_steps=50
)

# Initialize Trainer and train
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    callbacks=[SaveBestModelCallback()]
)

# Train the model
trainer.train()

# Save the model held by the trainer after training; with load_best_model_at_end
# this is the best (lowest eval_loss) checkpoint, not simply the last epoch.
trainer.save_model("6K_Pretraining") #<------------------------------------------------------------------------
