# SetFit SOTA for Bio Text Classification

SetFit is a great practical tool for few shot text classification, but did you know that you can fine-tune a vanilla SetFit for full-shot text classification and outperform models that were pre-trained from scratch using domain data?
Here we show such example in the Biological domain, where SetFit outperforms most of the models that were trained from scratch on Biological data, while being more efficient.

The following table summarizes the results of different models on the HoC* dataset. All of the biological models were first pre-trained using in-domain biological data and in addition were fine-tuned given the HoC training data in the BLUE benchmark. SetFit was not pre-trained using biological data, it is based on a general pre-trained sentence transformer model (MSFT's mpnet) and was solely fine-tuned on the HoC training data. As shown in the table, SetFit surpasses the Bio models and achieves comparable performance to the 347M BioGPT, which is the SOTA model for the Bio domain, while being 3x smaller: https://analyticsindiamag.com/microsoft-launches-biogpt-the-chatgpt-of-lifescience/

| **Model**               | **#params[M]** | **F1**  | **Pre-train Data**          |
|:-----------------------:|:-------:|:---------------:|:-----------------:|
|  **BioBERT[1]**|    110    |   81.5          | Bio     
|  **PubMedBERT[2]**|    340    |   82.7          | Bio   
|    **BioLinkBERT[3]**       |    340   |   84.9          | Bio     
|    **GPT-2**             |    355 |   81.8     | General
|    **BioGPT[4]**      |    347 |   85.1     | Bio
|       **SetFit**       |    105 |   **85.1** | General




Refrences:

[1] Domain-specific
language model pretraining for biomedical natural language
processing" https://arxiv.org/abs/2007.15779

[2] BioBERT: a pre-trained biomedical language representation
model for biomedical text mining" https://arxiv.org/abs/1901.08746

[3] LinkBERT: Pretraining Language Models with Document Links https://arxiv.org/abs/2203.15827

[4] BioGPT: Generative Pre-trained Transformer for Biomedical Text Generation and Mining" https://arxiv.org/abs/2210.10341

[5] Automatic semantic classification of scientific literature according to
the hallmarks of cancer. https://academic.oup.com/bioinformatics/article/32/3/432/1743783

[6]  An
evaluation of BERT and ELMo on ten benchmarking
datasets https://arxiv.org/abs/1906.05474


*HoC (the Hallmarks of Cancers corpus) consists of 1580
PubMed abstracts manually annotated at sentence level by
experts with ten currently known hallmarks of cancer [5]. We follow the same training/test split as in [6]

### SetFit Multilabel HoC

In [None]:
!pip install setfit==0.7.0

Load the HoC dataset

In [60]:
!wget https://github.com/ncbi-nlp/BLUE_Benchmark/releases/download/0.1/data_v0.1.zip
!unzip data_v0.1.zip

import pandas as pd
import numpy as np

# Read train/test files
test_df = pd.read_csv('/content/data/hoc/test.tsv', sep='\t')
train_df = pd.read_csv('/content/data/hoc/train.tsv', sep='\t')

--2025-07-19 17:16:51--  https://github.com/ncbi-nlp/BLUE_Benchmark/releases/download/0.1/data_v0.1.zip
Resolving github.com (github.com)... 20.27.177.113
Connecting to github.com (github.com)|20.27.177.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://release-assets.githubusercontent.com/github-production-release-asset/190591943/a5acee80-7695-11ea-99ea-8aeab034b689?sp=r&sv=2018-11-09&sr=b&spr=https&se=2025-07-19T18%3A03%3A01Z&rscd=attachment%3B+filename%3Ddata_v0.1.zip&rsct=application%2Foctet-stream&skoid=96c2d410-5711-43a1-aedd-ab1947aa7ab0&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skt=2025-07-19T17%3A02%3A13Z&ske=2025-07-19T18%3A03%3A01Z&sks=b&skv=2018-11-09&sig=Tf5FqMSwAOKJhSak3BV1iB3Y%2F4uB%2BZgyjGpcWGVSL8Q%3D&jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmVsZWFzZS1hc3NldHMuZ2l0aHVidXNlcmNvbnRlbnQuY29tIiwia2V5Ijoia2V5MSIsImV4cCI6MTc1Mjk0NTcxMSwibmJmIjoxNzUyOTQ1NDExLCJwYXRoIjoicmVsZWFzZWFzc2V0cHJvZHVjdGlvbi5

In [61]:
LABELS = ['activating invasion and metastasis', 'avoiding immune destruction',
          'cellular energetics', 'enabling replicative immortality', 'evading growth suppressors',
          'genomic instability and mutation', 'inducing angiogenesis', 'resisting cell death',
          'sustaining proliferative signaling', 'tumor promoting inflammation']

In [62]:
# Convert labels to hotvec multilabel format (similar to scikit-learn)
def hotvec_multilabel(true_df):
    data = {}

    for i in range(len(true_df)):
        true_row = true_df.iloc[i]

        key = true_row['index']

        data[key] = set()

        if not pd.isna(true_row['labels']):
            for l in true_row['labels'].split(','):
                data[key].add(LABELS.index(l))

    y_hotvec = []
    for k, (true) in data.items():
        t = [0] * len(LABELS)
        for i in true:
            t[i] = 1

        y_hotvec.append(t)

    y_hotvec = np.array(y_hotvec)

    return y_hotvec

### SetFit Multilabel

In [63]:
from datasets import Dataset
import evaluate
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers import SentenceTransformer

#model = SentenceTransformer(st_model, config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False}, trust_remote_code=True)

model = SetFitModel.from_pretrained(
    "BAAI/bge-small-en",
    # "all-MiniLM-L6-v2", "paraphrase-mpnet-base-v2 f1 36.3/84", "BAAI/bge-small-en 40.2/", "NeuML/bioclinical-modernbert-base-embeddings 74.7 81"

    multi_target_strategy="multi-output",     # one-vs-rest; multi-output; classifier-chain
    config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False}
)

#"NeuML/pubmedbert-base-embeddings",
#"BAAI/bge-base-en
#"NeuML/bioclinical-modernbert-base-embeddings
#nomic-ai/modernbert-embed-base
#Qwen/Qwen3-Embedding-0.6B
#NeuML/pubmedbert-base-splade # model = SparseEncoder("neuml/pubmedbert-base-splade")

#model = SetFitModel.from_pretrained(model_id, multi_target_strategy="one-vs-rest")

multilabel_f1_metric = evaluate.load("f1", "multilabel")
multilabel_accuracy_metric = evaluate.load("accuracy", "multilabel")

# f1/accuracy sentence level
def compute_metrics(y_pred, y_test):
    return {
        "f1": multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"],
        "accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
    }

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


In [64]:
eval_dataset = Dataset.from_dict({"text": test_df['sentence'], "label": hotvec_multilabel(test_df)})
train_dataset = Dataset.from_dict({"text": train_df['sentence'], "label": hotvec_multilabel(train_df)})

In [65]:
train_dataset_sampled_random = train_dataset.shuffle(seed=1).select(range(1000))

In [None]:
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric=compute_metrics,
    num_iterations=5
)

In [None]:
trainer.train()
metrics = trainer.evaluate()
print(metrics)

Sci kit classifier

In [67]:
model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")

In [None]:
x_train = train_dataset_sampled_random['text']
y_train = train_dataset_sampled_random['label']

x_test = eval_dataset['text']
y_test = eval_dataset['label']

X_train = model.encode(x_train)
X_test = model.encode(x_test)

In [52]:
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier


# Initialize Logistic Regression model
OneVsRest_classifier = OneVsRestClassifier(LogisticRegression(solver='liblinear')) # Wrap LogisticRegression in OneVsRestClassifier

# Train the classifier
# X_train should be your features (e.g., sentence embeddings)
# y_train should be your multi-hot encoded labels
OneVsRest_classifier.fit(X_train, y_train)

# Predict on the test data
y_pred = OneVsRest_classifier.predict(X_test)

micro_f1 = multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"]
# weighted", "macro"
print(f"Micro F1 Score: {micro_f1}")

Micro F1 Score: 0.001962708537782139


In [53]:
from sklearn.multioutput import ClassifierChain
from sklearn.linear_model import LogisticRegression
import numpy as np

# Initialize ClassifierChain with a base classifier
# You can use LogisticRegression, SGDClassifier, or other binary classifiers
# The order of the chain can be randomized or fixed
base_classifier = LogisticRegression(solver='liblinear')
chain_classifier = ClassifierChain(base_classifier, random_state=42) # random_state for reproducibility of chain order

# Train the ClassifierChain model
chain_classifier.fit(X_train, y_train)

# Predict on the test data
y_pred = chain_classifier.predict(X_test)

micro_f1 = multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"]
# weighted", "macro"
print(f"Micro F1 Score: {micro_f1}")

Micro F1 Score: 0.001962708537782139


### Evaluation of BLUE's HoC F1 (abstract level)

Support functions refactored from https://github.com/ncbi-nlp/BLUE_Benchmark
can be downloaded at https://github.com/ncbi-nlp/BLUE_Benchmark/releases/tag/0.1

In [30]:
def divide(x, y):
    return np.true_divide(x, y, out=np.zeros_like(x, dtype=np.float64), where=y != 0)

def get_p_r_f_arrary(test_predict_label, test_true_label):
    num, cat = test_predict_label.shape
    acc_list = []
    prc_list = []
    rec_list = []
    f_score_list = []
    for i in range(num):
        label_pred_set = set()
        label_gold_set = set()

        for j in range(cat):
            if test_predict_label[i, j] == 1:
                label_pred_set.add(j)
            if test_true_label[i, j] == 1:
                label_gold_set.add(j)

        uni_set = label_gold_set.union(label_pred_set)
        intersec_set = label_gold_set.intersection(label_pred_set)

        tt = len(intersec_set)
        if len(label_pred_set) == 0:
            prc = 0
        else:
            prc = tt / len(label_pred_set)

        acc = tt / len(uni_set)

        rec = tt / len(label_gold_set)

        if prc == 0 and rec == 0:
            f_score = 0
        else:
            f_score = 2 * prc * rec / (prc + rec)

        acc_list.append(acc)
        prc_list.append(prc)
        rec_list.append(rec)
        f_score_list.append(f_score)

    mean_prc = np.mean(prc_list)
    mean_rec = np.mean(rec_list)
    f_score = divide(2 * mean_prc * mean_rec, (mean_prc + mean_rec))
    return mean_prc, mean_rec, f_score

def eval_hoc(true_df, pred_df):
    data = {}

    assert len(true_df) == len(pred_df), \
        f'Gold line no {len(true_df)} vs Prediction line no {len(pred_df)}'

    for i in range(len(true_df)):
        true_row = true_df.iloc[i]
        pred_row = pred_df.iloc[i]
        assert true_row['index'] == pred_row['index'], \
            'Index does not match @{}: {} vs {}'.format(i, true_row['index'], pred_row['index'])

        key = true_row['index'][:true_row['index'].find('_')]
        if key not in data:
            data[key] = (set(), set())

        if not pd.isna(true_row['labels']):
            for l in true_row['labels'].split(','):
                data[key][0].add(LABELS.index(l))

        if not pd.isna(pred_row['labels']):
            for l in pred_row['labels'].split(','):
                data[key][1].add(LABELS.index(l))

    assert len(data) == 315, 'There are 315 documents in the test set: %d' % len(data)

    y_test = []
    y_pred = []
    for k, (true, pred) in data.items():
        t = [0] * len(LABELS)
        for i in true:
            t[i] = 1

        p = [0] * len(LABELS)
        for i in pred:
            p[i] = 1

        y_test.append(t)
        y_pred.append(p)

    y_test = np.array(y_test)
    y_pred = np.array(y_pred)

    r, p, f1 = get_p_r_f_arrary(y_pred, y_test)
    print('Precision: {:.1f}'.format(p*100))
    print('Recall   : {:.1f}'.format(r*100))
    print('F1       : {:.1f}'.format(f1*100))

#### Evaluate on test data

In [54]:
#test_predict_label = trainer.model.predict(test_df['sentence'])
x_TEST = test_df['sentence']
X_TEST = model.encode(x_TEST)
test_predict_label = chain_classifier.predict(X_TEST)


In [55]:
# Convert hotvec multilabel to actual labels
num, cat = test_predict_label.shape
sentence_list = []
for i in range(num):
    sentence_set = set()
    for j in range(cat):
        if test_predict_label[i, j] == 1:
            sentence_set.add(LABELS[j])
    sentence_list.append(','.join(sentence_set))

# Reformat for HoC evaluation
pred_df = test_df
pred_df = pred_df.assign(labels = sentence_list)
pred_df['labels'] = pred_df['labels'].replace({'':np.nan})
test_df['labels'] = test_df['labels'].replace({'':np.nan})

#### Evaluate F1 (abstract level)

In [40]:
#bge_base
eval_hoc(test_df, pred_df)

Precision: 41.2
Recall   : 50.7
F1       : 45.4


In [28]:
# "NeuML/bioclinical-modernbert-base-embeddings
eval_hoc(test_df, pred_df)

Precision: 72.5
Recall   : 73.0
F1       : 72.8


In [34]:
#NeuML/pubmedbert-base-embeddings
eval_hoc(test_df, pred_df)

Precision: 73.3
Recall   : 76.4
F1       : 74.8


In [46]:
#omic-ai/modernbert-embed-base
eval_hoc(test_df, pred_df)

Precision: 44.0
Recall   : 52.7
F1       : 48.0


In [55]:
#omic-ai/modernbert-embed-base
eval_hoc(test_df, pred_df)

Precision: 87.2
Recall   : 84.7
F1       : 86.0


In [56]:
#NeuML/pubmedbert-base-embeddings.  defualt num_itr only 1000 samples
eval_hoc(test_df, pred_df)

Precision: 0.3
Recall   : 0.3
F1       : 0.3


 [seed-42, itr=defualt, 5], seed = 42, seed=0, seed=1

 "NeuML/pubmedbert-base-embeddings",74.7 79.1 81.7 77.3 = 79.4 (85.2) 54.5 sck

"BAAI/bge-base-en 45.4 66 63.8 65.5 = 65.1 (82.5) 0

"NeuML/bioclinical-modernbert-base-embeddings 73.2 78.7 75.9 73.9 = 76.2 (81) 55 sck

nomic-ai/modernbert-embed-base  48 58.7 55.7 57.9 = 57.4 (82.5) 0.3

Qwen/Qwen3-Embedding-0.6B 69.9 64.5 65.4 71.7 = 67.2 (83.3)


emilyalsentzer/Bio_ClinicalBERT  63.1 x x

In [None]:
from sklearn.multioutput import ClassifierChain
from sklearn.linear_model import LogisticRegression
import numpy as np

# Initialize ClassifierChain with a base classifier
# You can use LogisticRegression, SGDClassifier, or other binary classifiers
# The order of the chain can be randomized or fixed
base_classifier = LogisticRegression(solver='liblinear')
chain_classifier = ClassifierChain(base_classifier, random_state=42) # random_state for reproducibility of chain order

# Train the ClassifierChain model
chain_classifier.fit(X_train, y_train)

# Predict on the test data
# The predict method returns multi-hot encoded predictions
y_pred_chain = chain_classifier.predict(X_test)

# Display the shape and some examples of the predictions
print("Shape of multi-hot encoded predictions from ClassifierChain:", y_pred_chain.shape)
print("Example multi-hot encoded predictions from ClassifierChain:")
display(y_pred_chain[:5])

# You can then evaluate y_pred_chain using multi-label metrics
# For example, using the compute_metrics function:
# metrics = compute_metrics(y_pred_chain, y_test)
# print(metrics)