In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import json

In [2]:
class SentenceLabelDataset(Dataset):
    def __init__(self, jsonl_path):
        self.data = []
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                item = json.loads(line)
                if len(item['sentences']) == len(item['labels']) and len(item['sentences']) > 0:
                    self.data.append(item)

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            'pmid': item.get('pmid', ''),
            'sentences': item['sentences'],  # list of str
            'labels': torch.tensor(item['labels'], dtype=torch.float)  # shape: (num_sent,)
        }

# collate: just batch multiple abstracts as list (no padding now)
def collate_fn(batch):
    return batch  # return list of dicts

In [3]:
dataset = SentenceLabelDataset("training_data.jsonl")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# test
for batch in dataloader:
    for sample in batch:
        print("PMID:", sample['pmid'])
        print("Sentences:", sample['sentences'])
        print("Labels:", sample['labels'])
    break

PMID: 28982088
Sentences: ['This work presents an integrated and multi-step approach for the recovery and/or application of the lignocellulosic fractions from corncob in the production of high value added compounds as xylo-oligosaccharides, enzymes, fermentable sugars, and lignin in terms of biorefinery concept.', 'For that, liquid hot water followed by enzymatic hydrolysis were used.', 'Liquid hot water was performed using different residence times (10-50min) and holding temperature (180-200°C), corresponding to severities (log(R']
Labels: tensor([0., 0., 0.])
PMID: 1733571
Sentences: ['The proposed intermediate steps in the relationship between a diet-dependent increase in colonic bile acids and proliferation of colonic cells were studied in rats.', 'Male Wistar rats were fed diets supplemented with increasing amounts of steroids to increase the bile acid concentration of the colon.', 'After 2 weeks, in vivo colonic proliferation was measured using tritiated thymidine incorporation i

In [4]:
from sentence_transformers import SentenceTransformer
import torch

class SentenceEncoder:
    def __init__(self, model_name="michiyasunaga/BioLinkBERT-base", device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(model_name)
        self.model = self.model.to(self.device)

    def encode(self, sentence_list):
        """
        Input: a list of sentences (str), length N
        Output: a Tensor (N, hidden_dim)
        """
        embeddings = self.model.encode(
            sentence_list, 
            convert_to_tensor=True, 
            device=self.device,
            show_progress_bar=False
        )
        return embeddings  # shape: (N, hidden_dim)

In [5]:
encoder = SentenceEncoder()

# get one abstract
for batch in dataloader:
    sample = batch[0]  # the first abstract
    sents = sample["sentences"]
    labels = sample["labels"]
    
    embeddings = encoder.encode(sents)
    print("sentence count：", len(sents))
    print("embedding shape:", embeddings.shape)  # (num_sentences, hidden_dim)
    break

No sentence-transformers model found with name michiyasunaga/BioLinkBERT-base. Creating a new one with mean pooling.


config.json:   0%|          | 0.00/559 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/433M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/379 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/433M [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/225k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/447k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

sentence count： 2
embedding shape: torch.Size([2, 768])


In [6]:
import torch.nn as nn
import torch

class SentenceClassifier(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=256, num_layers=1, dropout=0.2):
        super(SentenceClassifier, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, bidirectional=True)
        
        # Attention layer
        self.attention = nn.Linear(hidden_dim * 2, 1)  # output shape: (batch, sent, 1)

        # Binary classifier per sentence
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, 1)
        )

    def forward(self, sentence_embeddings, attention_mask=None):
        """
        sentence_embeddings: Tensor (num_sents, hidden_dim) — one abstract
        attention_mask: Optional (num_sents,) — 1 for valid sentence, 0 for padding
        """
        # input shape: batch = 1
        x = sentence_embeddings.unsqueeze(0)  # (1, num_sents, input_dim)

        lstm_out, _ = self.lstm(x)  # (1, num_sents, hidden_dim*2)

        # Attention weight calculation
        attn_scores = self.attention(lstm_out).squeeze(-1)  # (1, num_sents)
        attn_weights = torch.softmax(attn_scores, dim=1)    # (1, num_sents)

        weighted = lstm_out.squeeze(0)  # (num_sents, hidden_dim*2)

        logits = self.classifier(weighted).squeeze(-1)  # (num_sents,)
        probs = torch.sigmoid(logits)  # binary probability

        return probs, attn_weights.squeeze(0)  # Return predicted and attention weights

In [7]:
model = SentenceClassifier(input_dim=768).to("cuda")

with torch.no_grad():
    sample_embeds = encoder.encode(sample["sentences"]).to("cuda")
    labels = sample["labels"].to("cuda")

    probs, attn = model(sample_embeds)

    print("Predictive Probability:", probs)
    print("Real label:", labels)
    print("attention weight:", attn)

Predictive Probability: tensor([0.4806, 0.4686], device='cuda:0')
Real label: tensor([0., 0.], device='cuda:0')
attention weight: tensor([0.4934, 0.5066], device='cuda:0')


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score

def train_one_epoch(model, dataloader, encoder, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        for sample in batch:
            sents = sample['sentences']
            labels = sample['labels'].to(device)

            embeddings = encoder.encode(sents).to(device)  # (num_sents, 768)
            preds, _ = model(embeddings)  # (num_sents,)

            if preds.shape != labels.shape:
                print("Shape mismatch:", preds.shape, labels.shape)
                continue

            loss = loss_fn(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

    return total_loss / len(dataloader)

@torch.no_grad()
def evaluate(model, dataloader, encoder, loss_fn, device):
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0

    for batch in dataloader:
        for sample in batch:
            sents = sample['sentences']
            labels = sample['labels'].to(device)

            embeddings = encoder.encode(sents).to(device)
            preds, _ = model(embeddings)

            loss = loss_fn(preds, labels)
            total_loss += loss.item()

            bin_preds = (preds > 0.5).long().cpu().tolist()
            all_preds.extend(bin_preds)
            all_labels.extend(labels.long().cpu().tolist())

    f1 = f1_score(all_labels, all_preds)
    return total_loss / len(dataloader), f1

In [14]:
import os

def train_loop(model, encoder, train_loader, val_loader, device, epochs=5, lr=2e-5, save_path="best_model.pt"):
    loss_fn = nn.BCELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    best_f1 = 0.0

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        train_loss = train_one_epoch(model, train_loader, encoder, optimizer, loss_fn, device)
        val_loss, val_f1 = evaluate(model, val_loader, encoder, loss_fn, device)

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val F1: {val_f1:.4f}")

        # Save best
        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), save_path)
            print(f"Best model saved at epoch {epoch+1} with F1: {val_f1:.4f}")

In [15]:
import random
from torch.utils.data import Subset

def split_dataset(dataset, val_ratio=0.1, seed=42):
    total = len(dataset)
    indices = list(range(total))
    random.seed(seed)
    random.shuffle(indices)

    split = int(total * (1 - val_ratio))
    train_indices = indices[:split]
    val_indices = indices[split:]

    train_set = Subset(dataset, train_indices)
    val_set = Subset(dataset, val_indices)

    return train_set, val_set

def build_dataloaders(jsonl_path, batch_size=4, val_ratio=0.1):
    dataset = SentenceLabelDataset(jsonl_path)
    train_set, val_set = split_dataset(dataset, val_ratio=val_ratio)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    return train_loader, val_loader

In [16]:
train_loader, val_loader = build_dataloaders(
    jsonl_path="training_data.jsonl",
    batch_size=4,
    val_ratio=0.1
)

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SentenceClassifier().to(device)
encoder = SentenceEncoder(device=device)

train_loop(
    model=model,
    encoder=encoder,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=10,
    lr=2e-5,
    save_path="best_finding_model.pt"
)

No sentence-transformers model found with name michiyasunaga/BioLinkBERT-base. Creating a new one with mean pooling.


Epoch 1/10
Train Loss: 2.0581 | Val Loss: 1.7789 | Val F1: 0.8415
Best model saved at epoch 1 with F1: 0.8415
Epoch 2/10
Train Loss: 1.6356 | Val Loss: 1.6113 | Val F1: 0.8467
Best model saved at epoch 2 with F1: 0.8467
Epoch 3/10
Train Loss: 1.5037 | Val Loss: 1.5396 | Val F1: 0.8506
Best model saved at epoch 3 with F1: 0.8506
Epoch 4/10
Train Loss: 1.4223 | Val Loss: 1.4969 | Val F1: 0.8625
Best model saved at epoch 4 with F1: 0.8625
Epoch 5/10
Train Loss: 1.3654 | Val Loss: 1.4930 | Val F1: 0.8527
Epoch 6/10
Train Loss: 1.3158 | Val Loss: 1.4551 | Val F1: 0.8672
Best model saved at epoch 6 with F1: 0.8672
Epoch 7/10
Train Loss: 1.2710 | Val Loss: 1.4172 | Val F1: 0.8684
Best model saved at epoch 7 with F1: 0.8684
Epoch 8/10
Train Loss: 1.2335 | Val Loss: 1.4205 | Val F1: 0.8677
Epoch 9/10
Train Loss: 1.1989 | Val Loss: 1.4199 | Val F1: 0.8673
Epoch 10/10
Train Loss: 1.1641 | Val Loss: 1.4179 | Val F1: 0.8672
