In [None]:
import torch
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import wandb

In [None]:
# Checkpoint names follow the pattern molberto_<filename>_<samples_count>.
# NOTE(review): the counts presumably denote pre-training corpus sizes — confirm.
filename = 'ecfp0'
samples_count1, samples_count2 = '10M', '2M'
model_name1, model_name2 = (
    f'molberto_{filename}_{samples_count1}',
    f'molberto_{filename}_{samples_count2}',
)

In [None]:
# Name of the label column: rows without a value in it are dropped below, and
# it is the column handed to the loss as the class label during training.
target = 'NR-AR'
# molecular_properties = ['Molecular Weight', 'Bioactivities', 'AlogP', 'Polar Surface Area', 'CX Acidic pKa', 'CX Basic pKa']

### Upload and Split Dataset

In [None]:
# Load the raw table; later cells read its 'ecfp' column (fingerprints) and the
# `target` column (labels).
dataframe = pd.read_csv("tox21-ecfp.csv")

In [None]:
# dataframe = dataframe.drop(columns=['Unnamed: 0', 'Smiles', 'ecfp2', 'ecfp3'])

In [None]:
def preprocess_data_dataset(df, column):
    """Rewrite `column` in place as space-separated fingerprint strings.

    Each cell of `column` is expected to hold the textual form of a Python
    list/tuple literal (e.g. "['12', '34']"); it is replaced by the elements
    joined with single spaces ("12 34").

    Args:
        df: DataFrame to modify in place. Any index is supported because the
            column is reassigned positionally.
        column: Name of the column holding the serialized fingerprints.

    Returns:
        None. `df` is mutated.

    Raises:
        ValueError / SyntaxError: if a cell is not a valid Python literal
            (for example when the column has already been preprocessed).
    """
    # Local import keeps this cell self-contained.
    import ast

    fingerprints = []
    for raw_value in tqdm(df[column]):
        # literal_eval only accepts literals, so text coming from the CSV can
        # never execute arbitrary code the way eval() could.
        tokens = ast.literal_eval(raw_value)
        # str() lets integer identifiers be joined as well as string ones.
        fingerprints.append(' '.join(str(token) for token in tokens))

    # A single positional assignment avoids mixing positional reads with
    # label-based writes, which went wrong on a non-default index.
    df[column] = fingerprints

In [None]:
# Convert the serialized fingerprint lists in 'ecfp' to space-separated strings
# (mutates `dataframe` in place).
# NOTE(review): re-running this cell on already-converted data raises, because
# the converted strings are no longer list literals.
preprocess_data_dataset(dataframe, 'ecfp')

In [None]:
# Inspect the frame after the fingerprint conversion.
dataframe

In [None]:
# Keep only rows that have a label in the `target` column and renumber the
# index from 0.
dataframe = dataframe.dropna(subset=[target]).reset_index(drop=True)

In [None]:
# Keep the first 3500 labelled rows (positional slice).
# NOTE(review): 3500 is a magic number — consider a named constant in the config cell.
dataframe = dataframe[0:3500]

In [None]:
# print('Percentage on NaNs:')
# dataframe.isna().mean()

In [None]:
# rows_with_nans = dataframe['Molecular Weight'].isna() | \
#                  dataframe['Bioactivities'].isna() | \
#                  dataframe['AlogP'].isna() | \
#                  dataframe['Polar Surface Area'].isna() | \
#                  dataframe['CX Acidic pKa'].isna() | \
#                  dataframe['CX Basic pKa'].isna()
# print(f'Count of rows without NaNs: {dataframe.shape[0] - dataframe.loc[rows_with_nans].shape[0]}')

In [None]:
# remove 2 last properties to reduce NaN counts
# molecular_properties = ['Molecular Weight', 'Bioactivities', 'AlogP', 'Polar Surface Area']
# dataframe = dataframe.drop(columns=['CX Acidic pKa', 'CX Basic pKa'])

In [None]:
# drop NaN's
# dataframe = dataframe.dropna().reset_index(drop=True)

In [None]:
# dataframe

In [None]:
from datasets import Dataset, DatasetDict

# Wrap the pandas frame as a `datasets.Dataset`, then hold out 20% of the rows
# (fixed seed, so the split is reproducible).
dataset = Dataset.from_pandas(dataframe)
train_testvalid = dataset.train_test_split(test_size=0.2, seed=15)

# Cut the held-out 20% in half: one half for validation, one half for test.
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=15)

# Final proportions: 80% train, 10% validation, 10% test.
dataset = DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'validation': test_valid['train']})

dataset

### Tokenize Data

In [None]:
from transformers import AutoTokenizer
# Load the tokenizer stored with the first checkpoint.
# NOTE(review): the classifier below loads its encoder from `model_name2`;
# this assumes both checkpoints share one vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained(model_name1)

# The tokenizer's length limit lives in `model_max_length`; the previous
# `model_max_len` spelling only created an unused attribute and had no effect.
tokenizer.model_max_length = 512

In [None]:
def tokenize(batch):
  """Tokenize the 'ecfp' strings of a batch, truncated/padded to 512 tokens.

  Args:
      batch: Mapping with an "ecfp" entry holding the fingerprint strings.

  Returns:
      Whatever the module-level `tokenizer` returns for those strings.
  """
  encoding_options = dict(truncation=True, max_length=512, padding='max_length')
  fingerprints = batch["ecfp"]
  return tokenizer(fingerprints, **encoding_options)

# Apply `tokenize` to every split, passing batches of rows at a time, and
# display the resulting dataset.
tokenized_dataset = dataset.map(tokenize, batched=True)
tokenized_dataset

In [None]:
# Expose only the model inputs and the label column, returned as torch tensors.
columns = ["input_ids", "attention_mask", target]  # last entry: our labels
print(columns)
tokenized_dataset.set_format('torch', columns=columns)

from transformers import DataCollatorWithPadding
# Collate function handed to the DataLoaders; it builds batches with the help
# of the tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

### Create Transformer Model

In [None]:
from transformers import AutoModel, AutoConfig

class MolecularPropertiesClassification(torch.nn.Module):
    """Two-class classifier: a frozen pre-trained encoder plus a small trainable head.

    The encoder's parameters are frozen, so only `linear1` and `linear2`
    receive gradient updates. The head consumes the encoder output at sequence
    position 0 and produces two logits per example.
    """

    def __init__(self, model_name1, model_name2):
        """Load the encoder and build the classification head.

        Args:
            model_name1: Currently unused (the first encoder branch is
                disabled); kept so existing call sites keep working.
            model_name2: Checkpoint name or path passed to
                `AutoConfig.from_pretrained` / `AutoModel.from_pretrained`.
        """
        super().__init__()

        config2 = AutoConfig.from_pretrained(model_name2)
        self.transformer2 = AutoModel.from_pretrained(model_name2, config=config2)
        # Swap the encoder's pooling layer for a no-op; the head below works
        # on the raw hidden states instead.
        self.transformer2.pooler = torch.nn.Identity()
        # Freeze every encoder weight so only the head is trained.
        # NOTE(review): the encoder still follows model.train()/eval(), so any
        # dropout inside it stays active during training — confirm intended.
        for param in self.transformer2.parameters():
            param.requires_grad = False

        # Size the head from the checkpoint's config rather than assuming 768,
        # so checkpoints with a different width work too.
        hidden_size = config2.hidden_size
        self.linear1 = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        self.linear2 = torch.nn.Linear(hidden_size, 2, bias=True)

    def forward(self, input_ids=None, attention_mask=None):
        """Return the two class logits for each sequence in the batch.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Attention mask tensor forwarded to the encoder.

        Returns:
            Tensor of shape (batch, 2) with unnormalized class scores.
        """
        outputs2 = self.transformer2(input_ids=input_ids, attention_mask=attention_mask)
        # First element of the encoder output, indexed as (batch, seq, hidden).
        last_hidden_state2 = outputs2[0]

        # The hidden state at position 0 summarizes the sequence; the slice is
        # already (batch, hidden), so no reshape is needed.
        first_token_state = last_hidden_state2[:, 0, :]
        first_linear_out = self.linear1(first_token_state)
        logits = self.linear2(torch.sigmoid(first_linear_out))

        return logits
        

### Create PyTorch DataLoader

In [None]:
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(
    tokenized_dataset['train'], shuffle=True, batch_size=64, collate_fn=data_collator
)

# Validation is read in a fixed order: shuffling brings no benefit at
# evaluation time, and because the validation loss is averaged per batch, a
# changing batch composition would add run-to-run noise to the reported value.
eval_dataloader = DataLoader(
    tokenized_dataset['validation'], shuffle=False, batch_size=64, collate_fn=data_collator
)

In [None]:
# Preferred GPU ordinal on the original training host.
preferred_cuda_index = 5
if torch.cuda.is_available():
    # Hosts with fewer GPUs than the preferred ordinal fall back to GPU 0
    # instead of failing later when the model is moved to the device.
    cuda_index = preferred_cuda_index if preferred_cuda_index < torch.cuda.device_count() else 0
    device = torch.device("cuda", index=cuda_index)
else:
    device = torch.device('cpu')

# Build the classifier and move it to the selected device. Only `model_name2`
# is actually loaded; the constructor currently ignores `model_name1`.
model = MolecularPropertiesClassification(model_name1, model_name2).to(device)

In [None]:
# Show the module tree of the classifier.
model

In [None]:
from transformers import get_scheduler

# `transformers.AdamW` is deprecated and no longer shipped by recent
# transformers releases, so the PyTorch implementation is used instead.
# `eps` and `weight_decay` are pinned to the former transformers defaults so
# the update rule stays the same as before. Only trainable parameters are
# passed; the frozen encoder weights never receive gradients.
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=5e-5,
    eps=1e-6,
    weight_decay=0.0,
)

num_epoch = 100

# The scheduler is stepped once per training batch, so it needs the total
# number of optimizer steps over all epochs.
num_training_steps = num_epoch * len(train_dataloader)

# Linear decay of the learning rate to zero, without warm-up.
lr_scheduler = get_scheduler(
    'linear',
    optimizer = optimizer,
    num_warmup_steps = 0,
    num_training_steps = num_training_steps,
)

# Expects (batch, 2) logits and integer class indices.
loss_func = torch.nn.CrossEntropyLoss()

In [None]:
# Start the Weights & Biases run that the training loop logs to.
# NOTE(review): the project name reads "efcp" while the rest of the notebook
# says "ecfp" — possibly a typo, but renaming would move runs to another
# project, so confirm before changing. `config` is empty, so no
# hyper-parameters are recorded with the run.
wandb.init(
    project="efcp_transformer",
    name="SingleTransformer with LinearClassifier tox21 training",
    config={}
)

### Training

In [None]:
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

# One tick per optimizer step / per validation batch, across all epochs.
progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epoch * len(eval_dataloader)))

# Batch entries that are forwarded to the model.
MODEL_INPUT_KEYS = ('input_ids', 'attention_mask')


def _log_epoch_metrics(split, mean_loss, true_labels, pred_labels, step):
    """Log the mean loss and classification metrics of one epoch to W&B.

    Args:
        split: Suffix of the metric names ("train" or "validation").
        mean_loss: Loss averaged over the batches of the epoch.
        true_labels: 1-D array of reference class indices.
        pred_labels: 1-D array of predicted class indices.
        step: W&B step (the epoch number).
    """
    # NOTE(review): with average='micro' on a single-label task, precision,
    # recall and F1 all equal accuracy; average='binary' would report the
    # positive class instead. Kept as-is so logged curves stay comparable.
    wandb.log(
        {
            f"loss/{split}": mean_loss,
            f"accuracy/{split}": accuracy_score(true_labels, pred_labels),
            f"f1/{split}": f1_score(true_labels, pred_labels, average='micro'),
            f"precision/{split}": precision_score(true_labels, pred_labels, average='micro'),
            f"recall/{split}": recall_score(true_labels, pred_labels, average='micro'),
        },
        step=step,
    )


for epoch in range(num_epoch):
    # ---------------- training pass ----------------
    model.train()
    total_pred_labels = []
    total_true_labels = []
    epoch_loss = 0
    for batch in train_dataloader:
        input_batch = {k: v.to(device) for k, v in batch.items() if k in MODEL_INPUT_KEYS}
        # .long() converts on whatever device the tensor lives on; the former
        # torch.cuda.LongTensor cast failed when running on CPU.
        true_labels = batch[target].to(device).view(-1).long()

        logits = model(**input_batch)

        loss = loss_func(logits.view(-1, 2), true_labels)
        loss.backward()
        epoch_loss += loss.item()

        pred_labels = torch.argmax(logits, dim=-1)
        total_pred_labels.append(pred_labels)
        total_true_labels.append(true_labels)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar_train.update(1)

    total_pred_labels = torch.cat(total_pred_labels).cpu().detach().numpy()
    total_true_labels = torch.cat(total_true_labels).cpu().detach().numpy()

    _log_epoch_metrics(
        "train", epoch_loss / len(train_dataloader), total_true_labels, total_pred_labels, epoch
    )

    # ---------------- validation pass ----------------
    model.eval()
    total_pred_labels = []
    total_true_labels = []
    epoch_loss = 0
    # No gradients are needed anywhere in the validation pass.
    with torch.no_grad():
        for batch in eval_dataloader:
            input_batch = {k: v.to(device) for k, v in batch.items() if k in MODEL_INPUT_KEYS}
            true_labels = batch[target].to(device).view(-1).long()

            logits = model(**input_batch)
            loss = loss_func(logits.view(-1, 2), true_labels)
            epoch_loss += loss.item()

            pred_labels = torch.argmax(logits, dim=-1)
            total_pred_labels.append(pred_labels)
            total_true_labels.append(true_labels)

            progress_bar_eval.update(1)

    total_pred_labels = torch.cat(total_pred_labels).cpu().detach().numpy()
    total_true_labels = torch.cat(total_true_labels).cpu().detach().numpy()

    _log_epoch_metrics(
        "validation", epoch_loss / len(eval_dataloader), total_true_labels, total_pred_labels, epoch
    )

In [None]:
# Close the Weights & Biases run opened before training.
wandb.finish()

In [None]:
# Ask PyTorch to release its cached, currently unused GPU memory.
torch.cuda.empty_cache()