In [None]:
# !pip install numpy
# !pip install torch
# !pip install transformers
# !pip install tensorflow
# !pip install wandb

: 

In [None]:
import torch
from tqdm.auto import tqdm
import wandb
from importlib import reload
import pandas as pd

In [None]:
# Authenticate with Weights & Biases without embedding the API key in the notebook.
# wandb.login() picks the key up from the WANDB_API_KEY environment variable or the
# local credentials file, and otherwise prompts for it interactively.
# SECURITY: the key that used to be hard-coded on this line has been exposed and
# must be revoked in the W&B account settings.
wandb.login()

In [None]:
# Dataset size tag: read_data() loads the file `out-<samples_count>.csv`.
samples_count = '100k'
# Name passed to model.save_pretrained() at the end of the notebook.
model_name = f'd-bert_{samples_count}'

In [None]:
# Register `progress_apply` on pandas objects so long column transforms show a progress bar.
tqdm.pandas()


def read_data():
    """Load `out-{samples_count}.csv` and parse its `descriptors` column.

    The file name is built from the notebook-level `samples_count` variable.
    Every cell of the `descriptors` column is parsed from its string form into
    a Python literal; cells that cannot be parsed are reported on stdout and
    replaced with ``None``.

    Returns:
        pandas.DataFrame: the loaded frame with `descriptors` parsed.
    """
    import ast

    def string_to_array(input_string):
        """Parse one cell with `ast.literal_eval`; return None if it is not a valid literal."""
        try:
            # literal_eval only accepts Python literals, so no arbitrary code is executed
            return ast.literal_eval(input_string)
        except (SyntaxError, ValueError) as e:
            print(f"Error parsing the string: {e}")
        return None

    frame = pd.read_csv(f'out-{samples_count}.csv')
    frame['descriptors'] = frame['descriptors'].progress_apply(string_to_array)
    return frame

In [None]:
# Load the CSV and parse the descriptor strings (progress bar is shown while parsing).
data = read_data()
# Display the first parsed descriptor entry as a sanity check.
data['descriptors'][0]

In [None]:
# Project-local module; reload() picks up edits to it without restarting the kernel.
import shifter as sh
reload(sh)

# NOTE(review): Shifter is defined in the shifter module, which is not visible in this notebook.
shifter = sh.Shifter()

In [None]:
# Run the shifter over every molecule's descriptor entry.
# NOTE(review): the return value is discarded, so shift() presumably mutates its
# argument in place — if so, re-running this cell shifts the data a second time.
# Confirm against the shifter module.
for descriptors_of_substructures in data['descriptors']:
    shifter.shift(descriptors_of_substructures)

In [None]:
# Largest descriptor value in the whole dataset (never below 0).
# The RobertaConfig cell further down uses `maximum + 1` as the vocabulary size.
descriptor_values = (
    value
    for mol in data['descriptors']
    for substr in mol
    if not (substr == '$')  # '$' entries are skipped
    for descriptor in substr
    for value in descriptor
)
maximum = max(0, max(descriptor_values, default=0))
maximum # vocab_size

In [None]:
# Project-local tokenizer module; reload() picks up edits to it without restarting the kernel.
# Later cells call tokenizer.tokenize(...) on the descriptor column.
import tokenizer as tokenizer
reload(tokenizer)
# tokenized_descriptors = tokenizer.tokenize(data['descriptors'], max_length=513)

In [None]:
def mlm(tensor, mask_token_id=4, mask_prob=.15, special_token_ids=(0, 1, 2)):
    """Randomly replace tokens with the mask token for masked-language modelling.

    Each position is selected independently with probability `mask_prob`;
    positions holding one of `special_token_ids` are never selected. Selected
    positions are overwritten with `mask_token_id`.

    The tensor is modified IN PLACE and also returned, so pass a clone if the
    unmasked ids are still needed (e.g. as labels).

    Args:
        tensor: integer tensor of token ids (any shape).
        mask_token_id: id written into the selected positions.
        mask_prob: per-token selection probability.
        special_token_ids: ids that must never be masked
            (by default 0 <s>, 1 <pad>, 2 </s>).

    Returns:
        The same tensor object, with the selected positions masked.
    """
    # one uniform random number per token, on the same device as the ids
    rand = torch.rand(tensor.shape, device=tensor.device)
    maskable = rand < mask_prob
    for special_id in special_token_ids:
        maskable &= tensor != special_id
    # boolean-mask assignment handles all rows at once; no per-row loop is needed
    tensor[maskable] = mask_token_id
    return tensor

In [None]:
def tokenize_descriptors(data, start, end = -1):
    """Tokenize a slice of `data['descriptors']` and build MLM training tensors.

    Args:
        data: frame with a `descriptors` column.
        start: first row of the slice.
        end: end of the slice (exclusive). NOTE: with the default of -1 the
            slice is `[start:-1]`, which leaves out the last row; pass
            `len(data)` to include it.

    Returns:
        (input_ids, mask, labels): `labels` are the token ids returned by the
        tokenizer, `mask` is the tokenizer's attention mask, and `input_ids`
        is a copy of `labels` passed through `mlm()`.
    """
    encoded = tokenizer.tokenize(data['descriptors'][start:end], max_length=512)

    labels = torch.tensor(encoded['input_ids'])
    mask = torch.tensor(encoded['attention_mask'])
    # mask ~15% of tokens to create inputs; clone first so the labels stay unmasked
    input_ids = mlm(labels.detach().clone())

    return input_ids, mask, labels

In [None]:
# Smoke test: tokenize a single row before processing the full dataset.
input_ids, mask, labels = tokenize_descriptors(data, 0, 1)

In [None]:
# Split by row position: first 80,000 rows for training, next 10,000 for
# validation, everything after that for testing.
train_input_ids, train_mask, train_labels = tokenize_descriptors(data, 0, 80000)
validation_input_ids, validation_mask, validation_labels = tokenize_descriptors(data, 80000, 90000)
# Pass the end explicitly: relying on the default `end=-1` slices `[90000:-1]`
# and silently drops the last row of the dataset.
test_input_ids, test_mask, test_labels = tokenize_descriptors(data, 90000, len(data))

In [None]:
# Shapes of the three splits (the third line previously printed the validation shape twice).
print(train_input_ids.shape)
print(validation_input_ids.shape)
print(test_input_ids.shape)

In [None]:
# Display the label token ids (the unmasked tokenizer output) of the first training sample.
train_labels[0]

In [None]:
# torch.save(input_ids, 'molberto_training/input_ids.pt')
# torch.save(mask, 'molberto_training/attention_mask.pt')
# torch.save(labels, 'molberto_training/labels.pt')

# del input_ids, mask, labels

In [None]:
# input_ids = torch.load('molberto_training/input_ids.pt')
# mask = torch.load('molberto_training/attention_mask.pt')
# labels = torch.load('molberto_training/labels.pt')

### dataset and dataloader

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict of indexable tensors.

    `encodings` maps a field name (it must contain 'input_ids') to a tensor
    whose first dimension indexes the samples. Item `i` is a dict with the
    same keys holding the i-th entry of every tensor.
    """

    def __init__(self, encodings):
        # keep a reference only; the tensors are not copied
        self.encodings = encodings

    def __len__(self):
        # number of samples = first dimension of the 'input_ids' tensor
        input_ids = self.encodings['input_ids']
        return input_ids.shape[0]

    def __getitem__(self, i):
        sample = {}
        for key, tensor in self.encodings.items():
            sample[key] = tensor[i]
        return sample

In [None]:
# Wrap each split's tensors under the keyword names the model's forward() is called with.
train_dataset = Dataset({
    'input_ids': train_input_ids,
    'attention_mask': train_mask,
    'labels': train_labels,
})
validation_dataset = Dataset({
    'input_ids': validation_input_ids,
    'attention_mask': validation_mask,
    'labels': validation_labels,
})
test_dataset = Dataset({
    'input_ids': test_input_ids,
    'attention_mask': test_mask,
    'labels': test_labels,
})

In [None]:
# Number of samples per batch for the train, validation and test loaders.
batch_size = 128

In [None]:
# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
# Fixed order; drop_last=True discards a final incomplete batch.
validation_loader = torch.utils.data.DataLoader(
    validation_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
)
# Fixed order; the final batch may be smaller than batch_size.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

Next we build our model. We first need to create a RoBERTa config object, which describes the features we want to initialize our RoBERTa model with.

In [None]:
from transformers import RobertaConfig

# Architecture of the model that is trained from scratch below.
config = RobertaConfig(
    # ids 0..maximum must all be representable.
    # NOTE(review): assumes no token id produced by tokenizer.tokenize() exceeds
    # `maximum`, and that the mask id 4 written by mlm() is within range — confirm.
    vocab_size=maximum + 1,
    # NOTE(review): presumably the 512-token max_length plus RoBERTa's
    # 2-position padding offset — confirm.
    max_position_embeddings=514,
    hidden_size=768,
    num_attention_heads=12,
    num_hidden_layers=6,
    type_vocab_size=1
)

Then we import and initialize a RoBERTa model with a language modeling head.

In [None]:
from transformers import RobertaForMaskedLM

# Built from the config only, so the weights are freshly initialized (no pretrained checkpoint is loaded).
model = RobertaForMaskedLM(config)

And now we move on to training. First we set up GPU/CPU usage.

In [None]:
import os
# Debugging aid: make CUDA kernel launches synchronous so an error is reported
# at the call that caused it.
# NOTE(review): this has to be set before CUDA is initialized to take effect — confirm
# nothing above has touched the GPU yet.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Displays whether a CUDA device is usable.
torch.cuda.is_available()

In [None]:
# Use GPU #4 when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): the GPU index is hard-coded; on a machine that exposes fewer than
# five GPUs the model.to(device) call below is expected to fail — confirm the target host.
device = torch.device('cuda', index=4) if torch.cuda.is_available() else torch.device('cpu')
# and move our model over to the selected device
model.to(device)

Activate the training mode of our model, and initialize our optimizer (AdamW: Adam with decoupled weight decay, which helps reduce the chance of overfitting).

In [None]:

# activate training mode (enables dropout)
model.train()
# initialize optimizer
# `transformers.AdamW` is deprecated and no longer importable from recent
# transformers releases, so use the PyTorch implementation instead.
# eps and weight_decay are set explicitly to the defaults of the old
# transformers optimizer so the optimization behaviour is unchanged.
optim = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)

In [None]:
# Start the W&B run that receives the training and validation metrics.
# NOTE(review): the run name hard-codes "(100k)" instead of using `samples_count`,
# so it goes stale if the dataset size tag is changed.
wandb.init(
    project="bert_transformer",
    name="RobertaForMLM on molecular descriptors training (100k)",
    config=config
)

Now we move onto the training loop.

In [None]:
# Sets the same environment variable as the os.environ assignment made earlier in the notebook.
%env CUDA_LAUNCH_BLOCKING=1

In [None]:
from tqdm import tqdm  # for our progress bar
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

epochs = 10
step = 0


def _log_metrics(split, loss_value, label_ids, logits, step):
    """Log loss and token-level metrics of one batch to wandb as `<metric>/<split>`.

    Args:
        split: suffix of the metric names ('train' or 'validation').
        loss_value: loss of the batch as a Python float.
        label_ids: CPU tensor with the label token ids of the batch.
        logits: model output; the prediction is the argmax over the last axis.
        step: wandb step the values are attached to.
    """
    true_labels = label_ids.numpy().flatten()
    # argmax over the vocabulary axis; softmax is monotonic, so it is not needed
    # (and normalizing over dim=1, the sequence axis, would distort the argmax).
    pred_labels = logits.detach().argmax(dim=-1).cpu().numpy().flatten()
    wandb.log(
        {
            f"loss/{split}": loss_value,
            f"accuracy/{split}": accuracy_score(true_labels, pred_labels),
            f"f1/{split}": f1_score(true_labels, pred_labels, average='micro'),
            f"precision/{split}": precision_score(true_labels, pred_labels, average='micro'),
            f"recall/{split}": recall_score(true_labels, pred_labels, average='micro'),
        },
        step=step,
    )


# One validation batch is evaluated after every training batch; the iterator is
# restarted whenever the validation loader is exhausted.
validation_iterator = iter(validation_loader)
for epoch in tqdm(range(epochs)):
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        # ---- training step ----
        model.train()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # NOTE(review): the labels hold every token id (no ignore index), so the
        # loss and the metrics cover all positions, not only the masked ones.
        outputs = model(input_ids, attention_mask=attention_mask,
                        labels=labels)
        train_loss = outputs.loss

        # write down loss and metrics
        _log_metrics('train', train_loss.item(), batch['labels'], outputs.logits, step)

        train_loss.backward()
        optim.step()
        optim.zero_grad()

        # ---- validation step ----
        model.eval()  # disable dropout while evaluating
        with torch.no_grad():
            try:
                validation_batch = next(validation_iterator)
            except StopIteration:
                # start another pass over the same loader (same batch size as before)
                validation_iterator = iter(validation_loader)
                validation_batch = next(validation_iterator)

            val_outputs = model(
                validation_batch['input_ids'].to(device),
                attention_mask=validation_batch['attention_mask'].to(device),
                labels=validation_batch['labels'].to(device),
            )
            # compare against the labels of the validation batch, not the training batch
            _log_metrics('validation', val_outputs.loss.item(),
                         validation_batch['labels'], val_outputs.logits, step)

        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=train_loss.item())
        # NOTE(review): this only leaves the inner loop; once the threshold is
        # passed, every later epoch still runs one batch before breaking again.
        if step > 2500:
            break
        # NOTE(review): `batch` is a dict, so len(batch) is its number of keys (3),
        # not the number of samples — confirm the intended step unit.
        step += len(batch)


In [None]:
# Close the training run so the test metrics go to a separate run.
wandb.finish()

In [None]:
# Start a second W&B run that receives the test metrics.
# NOTE(review): the run name hard-codes "(100k)" instead of using `samples_count`.
wandb.init(
    project="bert_transformer",
    name="RobertaForMLM on molecular descriptors testing (100k)",
    config=config
)

In [None]:
step = 0

# Evaluate on the held-out test split: no gradients, dropout disabled.
model.eval()
with torch.no_grad():
    loop = tqdm(test_loader, leave=True)
    for batch in loop:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask,
                        labels=labels)
        loss = outputs.loss
        true_labels = batch['labels'].numpy().flatten()
        # argmax over the vocabulary axis; softmax is monotonic, so it is not needed
        # (and normalizing over dim=1, the sequence axis, would distort the argmax).
        pred_labels = outputs.logits.argmax(dim=-1).cpu().numpy().flatten()

        # write down loss and metrics
        wandb.log(
            {
                "loss/test": loss.item(),
                "accuracy/test": accuracy_score(true_labels, pred_labels),
                "f1/test": f1_score(true_labels, pred_labels, average='micro'),
                "precision/test": precision_score(true_labels, pred_labels, average='micro'),
                "recall/test": recall_score(true_labels, pred_labels, average='micro'),
            },
            step=step,
        )

        # label the bar as the test pass instead of reusing the stale training `epoch`
        loop.set_description('Test')
        loop.set_postfix(loss=loss.item())
        # NOTE(review): `batch` is a dict, so len(batch) is its number of keys (3),
        # not the number of samples — confirm the intended step unit.
        step += len(batch)

wandb.finish()

In [None]:
# The test cell above already calls wandb.finish(); this repeats the call.
wandb.finish()

In [None]:
# Write the trained model to the directory named by `model_name` (f'd-bert_{samples_count}').
model.save_pretrained(model_name)

In [None]:
# Release cached GPU memory that PyTorch is holding but no longer using.
torch.cuda.empty_cache()

In [None]:
# Diagnostics: number of visible CUDA devices and the index of the current one.
print(torch.cuda.device_count())
print(torch.cuda.current_device())

In [None]:
# Scratch expression: builds a device object for GPU #1 and displays it; the result is not assigned.
torch.device('cuda', index=1)