# Model Training


In [None]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Hugging Face checkpoint used for the backbone; alternative checkpoints kept below for reference.
# model_name = "dmis-lab/biobert-base-cased-v1.1"
# model_name = "nlpie/distil-biobert"
model_name = "distilbert-base-uncased"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Model Parameters

In [None]:
# test = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=15)
# for name, p in test.named_parameters():
#     print(name)
# test.config.hidden_size

### Preparing Dataset


In [None]:
# Load the augmented, labelled training table; later cells read its
# "Analysis", "Attribute" and "Standard names" columns.
augmented_data = pd.read_csv(filepath_or_buffer="augmented_labelled_data.csv")
augmented_data.shape

(3081, 3)

In [None]:
class BMSDataset(Dataset):
    """Map-style dataset of tokenized text pairs with parent and child class ids.

    ``dataset[idx]`` returns ``{"encodings": {field: tensor}, "parent_label": tensor,
    "child_label": tensor}``; ``DataLoader``'s default collate stacks these into a
    dict of batched tensors with the same keys.
    """

    def __init__(self, encodings, parent_labels, child_labels, device=None):
        """Store the tokenizer output and both label sequences.

        Args:
            encodings: mapping from tokenizer field name (e.g. ``"input_ids"``) to a
                per-example sequence, such as the tokenizer's ``BatchEncoding``.
            parent_labels: per-example parent class ids.
            child_labels: per-example child class ids.
            device: device for the returned tensors. ``None`` (default) falls back to
                the notebook-level ``device`` global, as before.

        Raises:
            ValueError: if ``parent_labels`` or any encoding field has a different
                length than ``child_labels``.
        """
        n_examples = len(child_labels)
        lengths = {"parent_labels": len(parent_labels), **{key: len(val) for key, val in encodings.items()}}
        mismatched = {name: size for name, size in lengths.items() if size != n_examples}
        if mismatched:
            raise ValueError(f"expected {n_examples} examples (len(child_labels)), got {mismatched}")

        self.encodings = encodings
        self.parent_labels = parent_labels
        self.child_labels = child_labels
        self.device = device

    def __getitem__(self, idx):
        """Return the tensors for example ``idx`` on the configured device."""
        target = self.device if self.device is not None else device
        # Slice out this example only; returning the full tokenizer output here
        # would hand every example's encodings to the collate function.
        encoding = {key: torch.tensor(val[idx], device=target) for key, val in self.encodings.items()}
        return {
            "encodings": encoding,
            "parent_label": torch.tensor(self.parent_labels[idx], device=target),
            "child_label": torch.tensor(self.child_labels[idx], device=target),
        }

    def __len__(self):
        """Number of examples (length of the child label sequence)."""
        return len(self.child_labels)

# Standard names grouped by their parent (coarse) class id. Groups are listed in
# ascending parent-id order (the tuple position is the id), so flattening them
# yields the same key order the child ids are enumerated from below.
parent_label2id = {
    label: parent_id
    for parent_id, group in enumerate((
        (  # 0: CE-SDS / SDS-PAGE
            "ce-sds (non-reduced) hhl",
            "ce-sds (non-reduced) purity",
            "ce-sds (reduced) purity",
            "sds-page (non-reduced) purity",
            "sds-page (reduced) purity",
        ),
        (  # 1: IEF / icIEF / CEX / AEX peaks
            "ief acidic peaks",
            "ief basic peaks",
            "ief main peak",
            "icief acidic peaks",
            "icief basic peaks",
            "icief main peak",
            "cex acidic peaks",
            "cex basic peaks",
            "cex main peak",
            "aex acidic peaks",
            "aex basic peaks",
            "aex main peak",
        ),
        (  # 2: RP-HPLC / SE-HPLC / SE-UPLC
            "rp-hplc purity",
            "se-hplc hmw",
            "se-hplc lmw",
            "se-hplc monomer",
            "se-uplc hmw",
            "se-uplc lmw",
            "se-uplc monomer",
        ),
        (  # 3: particulate matter
            "particulate-matter >= 10-um",
            "particulate-matter >= 25-um",
        ),
        (  # 4: potency / binding
            "potency by cell-based bioassay",
            "potency by binding elisa",
            "spr binding activity",
        ),
        ("ph",),  # 5
        ("protein concentration (a280)",),  # 6
        ("polysorbate 80",),  # 7
        ("irrelevant",),  # 8
    ))
    for label in group
}

# Every standard name is its own child class, numbered in the order listed above.
child_label2id = dict(zip(parent_label2id, range(len(parent_label2id))))

# Inverse lookup: child id -> standard name.
child_id2label = dict(enumerate(parent_label2id))

# Resolve every row's "Standard names" value (case-insensitively) to its parent id
# and to its child id; a name missing from the mappings raises KeyError.
parent_labels, child_labels = (
    [label2id[name.lower()] for name in augmented_data["Standard names"].tolist()]
    for label2id in (parent_label2id, child_label2id)
)


In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Encode each row as a text pair (Analysis, Attribute); every pair is truncated or
# padded to exactly 512 tokens so all examples share one length.
encodings = tokenizer(
    text=augmented_data["Analysis"].tolist(),
    text_pair=augmented_data["Attribute"].tolist(),
    truncation=True,
    padding="max_length",
    max_length=512,
)
encodings.keys()

dict_keys(['input_ids', 'attention_mask'])

In [None]:
# Bundle the tokenized pairs with both label levels for the DataLoader.
train_dataset = BMSDataset(
    encodings=encodings,
    parent_labels=parent_labels,
    child_labels=child_labels,
)

### Model Definition

In [None]:
class BMSModel(torch.nn.Module):
    """Hierarchical classifier with a coarse "parent" head and a fine "child" head.

    The parent head is the sequence-classification head of a Hugging Face model.
    The child head is an MLP over the backbone's last-layer first-token embedding
    concatenated with the parent logits, so the child prediction is conditioned on
    the parent scores.
    """

    def __init__(self, model_name: str, parent_class_count: int, child_class_count: int,
                 num_tunable_layers: int = 5):
        """Build the backbone + parent head, freeze lower layers, and build the child MLP.

        Args:
            model_name: Hugging Face model id or path for ``from_pretrained``.
            parent_class_count: number of parent (coarse) classes.
            child_class_count: number of child (fine-grained) classes.
            num_tunable_layers: number of top transformer layers to fine-tune. The
                embeddings and all lower layers stay frozen. The default of 5 reproduces
                the previous "layers 7-11 of a 12-layer BERT" setup.
        """
        super().__init__()

        self.parent_classifier = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=parent_class_count)
        # Have every backbone forward pass also return per-layer hidden states;
        # the child head reads the last one.
        self.parent_classifier.config.output_hidden_states = True
        embedding_dim = self.parent_classifier.config.hidden_size

        self._freeze_lower_layers(num_tunable_layers)

        # Child head input: text embedding concatenated with the parent logits.
        self.child_classifier = nn.Sequential(
            nn.Linear(embedding_dim + parent_class_count, 384),
            nn.ReLU(),
            nn.Linear(384, 192),
            nn.ReLU(),
            nn.Linear(192, 96),
            nn.ReLU(),
            nn.Linear(96, child_class_count)
        )

    def _freeze_lower_layers(self, num_tunable_layers: int) -> None:
        """Freeze the embeddings and all but the top ``num_tunable_layers`` layers.

        Layer indices are parsed from parameter names containing ``.layer.<i>.``.
        BERT uses ``bert.encoder.layer.<i>`` and DistilBERT uses
        ``distilbert.transformer.layer.<i>``. The base model's pooler stays
        trainable. So does every parameter outside the base model, i.e. the task
        heads such as ``classifier`` and DistilBERT's newly initialised
        ``pre_classifier``.

        The previous hard-coded ``bert.*`` prefixes matched nothing in DistilBERT,
        which left its backbone and the random ``pre_classifier`` frozen.
        """
        import re

        model = self.parent_classifier
        layer_pattern = re.compile(r"\.layer\.(\d+)\.")
        first_tunable = max(model.config.num_hidden_layers - num_tunable_layers, 0)
        base_prefix = f"{model.base_model_prefix}."
        for name, param in model.named_parameters():
            if not name.startswith(base_prefix):
                param.requires_grad = True  # task head(s) on top of the base model
                continue
            match = layer_pattern.search(name)
            if match is not None:
                param.requires_grad = int(match.group(1)) >= first_tunable
            else:
                # Embeddings stay frozen; a pooler (BERT-style) is fine-tuned.
                param.requires_grad = ".pooler." in name

    def _logits(self, encodings):
        """Run the backbone once and return ``(parent_logits, child_logits)``."""
        parent_outputs = self.parent_classifier(**encodings)
        # Last-layer hidden state at position 0 (the [CLS] slot for BERT-style tokenizers).
        parent_embedding = parent_outputs.hidden_states[-1][:, 0, :]
        child_logits = self.child_classifier(torch.cat([parent_embedding, parent_outputs.logits], dim=1))
        return parent_outputs.logits, child_logits

    @torch.no_grad()
    def predict(self, encodings):
        """Predict parent and child classes without tracking gradients.

        Returns:
            ``(parent_class, parent_class_prob, child_class, child_class_prob)``. The
            classes are per-example argmax tensors. The probabilities are Python floats.

        NOTE: each probability is the max softmax value over the *whole batch*, so it
        is only a per-example confidence when the batch holds a single example.
        """
        parent_logits, child_logits = self._logits(encodings)

        parent_class = torch.argmax(parent_logits, dim=1)
        parent_class_prob = torch.max(torch.softmax(parent_logits, dim=1)).item()

        child_class = torch.argmax(child_logits, dim=1)
        child_class_prob = torch.max(torch.softmax(child_logits, dim=1)).item()

        return parent_class, parent_class_prob, child_class, child_class_prob

    def forward(self, encodings, parent_label, child_label):
        """Return the summed cross-entropy loss of the parent and child heads."""
        parent_logits, child_logits = self._logits(encodings)
        parent_loss = F.cross_entropy(parent_logits, parent_label)
        child_loss = F.cross_entropy(child_logits, child_label)
        return parent_loss + child_loss



In [None]:
import gc

# Drop the previous model before building a new one. The optimizer must go too:
# it holds references to the model's parameters and their optimizer state, so
# deleting only the model would not release any GPU memory. Both deletions are
# guarded so the cell also runs on a fresh kernel where neither exists yet.
if "bms_model" in globals():
    del bms_model
if "optimizer" in globals():
    del optimizer
gc.collect()
torch.cuda.empty_cache()

### Training

In [None]:
# Class counts are the number of distinct ids in each label mapping.
parent_class_count = len(set(parent_label2id.values()))
child_class_count = len(set(child_label2id.values()))
epochs = 5
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
bms_model = BMSModel(model_name, parent_class_count, child_class_count)
bms_model.to(device)

# Only parameters left trainable by BMSModel's freezing logic go to the optimizer.
# NOTE(review): 2e-3 is high for fine-tuning transformer layers (BERT-style
# fine-tuning commonly uses ~2e-5 to 5e-5); consider a lower rate for this group.
trainable_parent_params = list(filter(lambda p: p.requires_grad, bms_model.parent_classifier.parameters()))
optimizer = torch.optim.AdamW(
    [
        {"params": trainable_parent_params, "lr": 2e-3},
        {"params": bms_model.child_classifier.parameters(), "lr": 2e-3},
    ]
)
del trainable_parent_params



model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from tqdm.notebook import tqdm

# Early stopping: stop once the summed epoch loss has changed by less than
# `loss_threshold` for `patience` consecutive epochs.
previous_epoch_loss = float('inf')
loss_threshold = 1e-1
patience = 2  # Consecutive epochs with minimal loss change before stopping
no_improvement = 0

# Hugging Face `from_pretrained` returns the backbone in eval mode, and the
# inference cell calls `eval()`; enable dropout etc. before (re-)training.
bms_model.train()

for epoch in range(epochs):
    epoch_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch")
    # BMSDataset items are dicts, so each collated batch is a dict of batched tensors.
    # The batch_* names avoid overwriting the global `encodings` / label lists.
    for batch in progress_bar:
        batch_encodings = {key: val.to(device) for key, val in batch["encodings"].items()}
        batch_parent_labels = batch["parent_label"].to(device)
        batch_child_labels = batch["child_label"].to(device)

        optimizer.zero_grad()
        # Calling the module runs BMSModel.forward, which returns parent + child loss.
        loss = bms_model(batch_encodings, batch_parent_labels, batch_child_labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss}")

    if abs(previous_epoch_loss - epoch_loss) < loss_threshold:
        no_improvement += 1
        if no_improvement >= patience:
            print(f"Stopping training early at epoch {epoch + 1} due to minimal loss change in {patience} epochs.")
            break
    else:
        no_improvement = 0
    previous_epoch_loss = epoch_loss

Epoch 1/5:   0%|          | 0/49 [00:00<?, ?batch/s]

Epoch 1/5, Loss: 90.66763436794281


Epoch 2/5:   0%|          | 0/49 [00:00<?, ?batch/s]

Epoch 2/5, Loss: 85.50908434391022


Epoch 3/5:   0%|          | 0/49 [00:00<?, ?batch/s]

Epoch 3/5, Loss: 85.2687349319458


Epoch 4/5:   0%|          | 0/49 [00:00<?, ?batch/s]

Epoch 4/5, Loss: 84.72935473918915


Epoch 5/5:   0%|          | 0/49 [00:00<?, ?batch/s]

Epoch 5/5, Loss: 84.7046914100647


In [None]:
# Persist only the parameter/buffer tensors (state_dict), not the module object itself.
torch.save(obj=bms_model.state_dict(), f="attempt_1.pt")

In [None]:
ana = "D_95007196"
attr = "PH"

# Tokenize the single (analysis, attribute) pair the same way as the training data.
test_enc = tokenizer(ana, attr, max_length=512, padding="max_length", truncation=True, return_tensors="pt")
# return_tensors="pt" already yields tensors; just move them. Re-wrapping with
# torch.tensor() copies them and emits a UserWarning.
test_enc = {key: val.to(device) for key, val in test_enc.items()}

bms_model.eval()

# forward() requires labels and returns a loss; inference goes through predict().
with torch.no_grad():
    result = bms_model.predict(test_enc)
print(
    f" parent_class: {result[0]}\n parent_class_prob: {result[1]}\n"
    f" child_class: {result[2]} ({child_id2label[result[2].item()]})\n child_class_prob: {result[3]}"
)

 parent_class: tensor([1], device='cuda:0')
 parent_class_prob: 0.38182011246681213



  test_enc = {key: torch.tensor(val).to(device) for key, val in test_enc.items()}
