<a href="https://colab.research.google.com/github/asxd-10/cis5300_project/blob/main/notebooks/BILSTM_section_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
# NOTE(review): none of the cells visible in this notebook read from or write to
# /content/drive (the data comes from the git clone and the checkpoint is saved
# to the working directory) — confirm whether this mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import os
# Report whether PyTorch can see a CUDA device and, if it can, the name of device 0.
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

CUDA available: True
GPU: Tesla T4


In [None]:
!pip install -q transformers datasets jsonlines scikit-learn
!pip install pytorch-crf

Collecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl.metadata (2.4 kB)
Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Installing collected packages: pytorch-crf
Successfully installed pytorch-crf-0.7.2


In [None]:
!git clone https://github.com/asxd-10/cis5300_project.git

Cloning into 'cis5300_project'...
remote: Enumerating objects: 186, done.[K
remote: Counting objects: 100% (186/186), done.[K
remote: Compressing objects: 100% (162/162), done.[K
remote: Total 186 (delta 93), reused 71 (delta 17), pack-reused 0 (from 0)[K
Receiving objects: 100% (186/186), 14.16 MiB | 6.07 MiB/s, done.
Resolving deltas: 100% (93/93), done.


In [None]:
import sys
# Add the cloned repository to the module search path.
# NOTE(review): no cell visible here imports anything from it — confirm it is needed.
sys.path.append('cis5300_project')

# Sanity check: list the top-level contents of the cloned repository.
print('Contents of cis5300_project directory:')
!ls -F cis5300_project/

def load_pubmed_rct(path):
    """
    Read a PubMed RCT file into a flat list of (label, sentence) pairs.

    Blank lines and lines starting with "###" are skipped. Every remaining line
    is split at its first tab into the label and the sentence text.

    Raises:
        ValueError: if a data line contains no tab.
    """
    pairs = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            # Keep only real data lines: non-empty and not a "###" header.
            if stripped and not stripped.startswith("###"):
                label, sentence = stripped.split("\t", 1)
                pairs.append((label, sentence))
    return pairs

def load_pubmed_rct_by_abstract(path):
    """
    Read a PubMed RCT file and group its sentences by abstract.

    A line starting with "###" marks the start of a new abstract; blank lines
    are ignored. Every other line is split at its first tab into a label and a
    sentence.

    Returns:
        A list with one (labels, sentences) tuple per abstract, where both
        elements are lists of equal length. Abstracts with no sentences are
        not emitted.

    Raises:
        ValueError: if a data line contains no tab.
    """
    abstracts = []
    labels, sents = [], []

    def _flush():
        # Store the abstract collected so far (if any) and start a fresh one.
        nonlocal labels, sents
        if sents:
            abstracts.append((labels, sents))
            labels, sents = [], []

    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            if stripped.startswith("###"):
                _flush()
                continue
            label, sentence = stripped.split("\t", 1)
            labels.append(label)
            sents.append(sentence)

    # The final abstract has no trailing "###" header to trigger a flush.
    _flush()
    return abstracts

print("Loading PubMed RCT data")


# Load each split grouped by abstract: every element is a (labels, sentences) tuple.
# The paths are relative to the working directory and point into the cloned repository.
train_abstracts = load_pubmed_rct_by_abstract('cis5300_project/data/pubmed_rct/train.txt')
dev_abstracts   = load_pubmed_rct_by_abstract('cis5300_project/data/pubmed_rct/dev.txt')
test_abstracts  = load_pubmed_rct_by_abstract('cis5300_project/data/pubmed_rct/test.txt')

# Report how many abstracts each split contains.
print(f"{len(train_abstracts)} training abstracts")
print(f"{len(dev_abstracts)} dev abstracts")
print(f"{len(test_abstracts)} test abstracts")

Contents of cis5300_project directory:
data/		     notebooks/  requirements.txt  src/
download_scifact.sh  README.md	 setup.sh
Loading PubMed RCT data
15000 training abstracts
2500 dev abstracts
2500 test abstracts


In [None]:
# Inspect the first training abstract: a (labels, sentences) tuple of two parallel lists.
print(train_abstracts[0])

(['OBJECTIVE', 'METHODS', 'METHODS', 'METHODS', 'METHODS', 'METHODS', 'RESULTS', 'RESULTS', 'RESULTS', 'RESULTS', 'RESULTS', 'CONCLUSIONS'], ['To investigate the efficacy of @ weeks of daily low-dose oral prednisolone in improving pain , mobility , and systemic low-grade inflammation in the short term and whether the effect would be sustained at @ weeks in older adults with moderate to severe knee osteoarthritis ( OA ) .', 'A total of @ patients with primary knee OA were randomized @:@ ; @ received @ mg/day of prednisolone and @ received placebo for @ weeks .', 'Outcome measures included pain reduction and improvement in function scores and systemic inflammation markers .', 'Pain was assessed using the visual analog pain scale ( @-@ mm ) .', 'Secondary outcome measures included the Western Ontario and McMaster Universities Osteoarthritis Index scores , patient global assessment ( PGA ) of the severity of knee OA , and @-min walk distance ( @MWD ) .', 'Serum levels of interleukin @ ( IL

In [None]:
# Preprocessing for the simple baseline: label-to-id mapping, lowercasing, and replacing '@' with <NUM>.

In [None]:
# Section label -> integer class id; a label's id is its position in the list below.
label2id = {
    name: idx
    for idx, name in enumerate(
        ["BACKGROUND", "OBJECTIVE", "METHODS", "RESULTS", "CONCLUSIONS"]
    )
}
# Inverse mapping: integer class id -> section label.
id2label = dict((idx, name) for name, idx in label2id.items())

In [None]:
from collections import defaultdict, Counter

def preprocess(text):
    """
    Normalize one sentence.

    Steps, in order: strip surrounding whitespace, lowercase, then replace every
    '@' with the token <NUM>. Lowercasing runs before the replacement, so the
    inserted <NUM> stays uppercase.
    """
    return text.strip().lower().replace("@", "<NUM>")

from collections import Counter

# Count every whitespace-separated token of the preprocessed training sentences.
word_counter = Counter(
    token
    for _, sents in train_abstracts
    for sent in sents
    for token in preprocess(sent).split()
)

# Ids 0 and 1 are reserved for padding and unknown words; every training token
# gets the next id in order of first occurrence.
# NOTE(review): with no frequency cut-off every training token is in the vocabulary,
# so <UNK> is never produced for training data — confirm this is intended.
word2idx = {"<PAD>": 0, "<UNK>": 1}
word2idx.update((word, i) for i, word in enumerate(word_counter, start=2))

# Inverse mapping: id -> word.
idx2word = {i: word for word, i in word2idx.items()}

VOCAB_SIZE = len(word2idx)
print("Vocabulary size:", VOCAB_SIZE)


#Tokenization + Vocabulary

def tokenize_sentence(sentence):
    """Split a sentence into tokens on runs of whitespace."""
    tokens = sentence.split()
    return tokens

def build_vocab(sentences):
    """
    Add every token of the given sentences to the module-level `word2idx`.

    sentences: list of lists of sentences (one inner list per abstract)

    Tokens already in `word2idx` keep their id; an unseen token receives the
    next free id, `len(word2idx)`. This assumes the existing ids are the
    contiguous range 0..len(word2idx)-1, which is how `word2idx` is built in
    this notebook.

    The previous implementation only evaluated `word2idx[word]`, which on a
    plain dict adds nothing and raises KeyError on the first unseen token.

    Note: only `word2idx` is updated. `idx2word` and `VOCAB_SIZE` are not
    refreshed here, so recompute them (and size the embedding accordingly)
    after calling this function.
    """
    for abstract in sentences:
        for sent in abstract:
            for word in tokenize_sentence(sent):
                # setdefault is a no-op for known words and registers unseen ones.
                word2idx.setdefault(word, len(word2idx))

# preprocessing to abstracts
def encode_abstracts(abstracts):
    """
    Preprocess the sentences and map the labels of every abstract.

    abstracts: list of tuples (labels, sentences)
    Returns:
        sentence_tokens: one list per abstract holding its preprocessed
            sentence strings (the sentences are NOT split into tokens here)
        label_ids: one list per abstract holding the integer id of each label
    """
    encoded = [
        ([preprocess(s) for s in sents], [label2id[name] for name in labels])
        for labels, sents in abstracts
    ]
    sentence_tokens = [pair[0] for pair in encoded]
    label_ids = [pair[1] for pair in encoded]
    return sentence_tokens, label_ids

Vocabulary size: 69734


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

def sentence_to_indices(sentence, word2idx, max_len=None):
    """
    Convert a sequence of tokens into vocabulary ids.

    Args:
        sentence: iterable of tokens (e.g. a list of words). Each element is
            looked up individually, so passing a plain string would look up
            its characters.
        word2idx: token -> id mapping; must contain "<UNK>", and "<PAD>" when
            padding is required.
        max_len: if None, the ids are returned at their natural length.
            Otherwise the result is right-padded with the "<PAD>" id or
            truncated to `max_len` entries.

    Returns:
        A list of integer ids; tokens missing from `word2idx` map to "<UNK>".
    """
    indices = [word2idx.get(w, word2idx["<UNK>"]) for w in sentence]
    # Compare against None explicitly: a plain truthiness test would treat
    # max_len=0 as "no limit" and return the untruncated sequence.
    if max_len is None:
        return indices
    if len(indices) < max_len:
        indices += [word2idx["<PAD>"]] * (max_len - len(indices))
        return indices
    return indices[:max_len]

class PubMedSentenceDataset(Dataset):
    """
    Dataset whose items are whole abstracts.

    Item `idx` is a pair (sentence_tensors, label_ids):
      * sentence_tensors: list with one LongTensor of length `max_sent_len`
        per sentence, holding word ids right-padded with the "<PAD>" id or
        truncated to `max_sent_len`;
      * label_ids: LongTensor with one label id per sentence.
    """

    def __init__(self, abstracts, label2id, word2idx, max_sent_len=100):
        """
        abstracts: list of (labels, sentences)
        label2id: label string -> integer id
        word2idx: token -> integer id; must contain "<PAD>" and "<UNK>"
        max_sent_len: fixed number of word ids stored per sentence
        """
        self.abstracts = abstracts
        self.label2id = label2id
        self.word2idx = word2idx
        self.max_sent_len = max_sent_len

    def __len__(self):
        """Number of abstracts."""
        return len(self.abstracts)

    def __getitem__(self, idx):
        """Encode abstract `idx` into (list of sentence tensors, label tensor)."""
        labels, sents = self.abstracts[idx]

        encoded = []
        for sent in sents:
            # Normalize, split on whitespace, and map tokens to ids (unknown -> <UNK>).
            ids = [
                self.word2idx.get(tok, self.word2idx["<UNK>"])
                for tok in preprocess(sent).split()
            ]
            shortfall = self.max_sent_len - len(ids)
            if shortfall > 0:
                # Too short: right-pad up to the fixed length.
                ids.extend([self.word2idx["<PAD>"]] * shortfall)
            else:
                # Long enough: keep only the first max_sent_len ids.
                ids = ids[:self.max_sent_len]
            encoded.append(torch.tensor(ids, dtype=torch.long))

        targets = torch.tensor([self.label2id[name] for name in labels], dtype=torch.long)
        return encoded, targets


# 3. Collate function for variable-length sequences
# def collate_fn(batch):
#     batch_x = [item[0] for item in batch]
#     batch_y = [item[1] for item in batch]
#     lengths = torch.tensor([len(x) for x in batch_x], dtype=torch.long)
#     padded_x = torch.nn.utils.rnn.pad_sequence(batch_x, batch_first=True, padding_value=word2idx["<PAD>"])
#     batch_y = torch.stack(batch_y)
#     mask = (padded_x != word2idx["<PAD>"]).to(torch.uint8)
#     return padded_x, batch_y, mask, lengths

def collate_fn(batch):
    """
    Merge a list of abstracts into padded batch tensors.

    batch: list of (list of sentence tensors, label_ids)

    Every abstract is padded to the largest sentence count in the batch:
    missing sentences become all-zero tensors, missing labels become -1, and
    the mask is False at padded positions.

    Returns:
        padded_sents:  (batch_size, seq_len, max_sent_len)
        padded_labels: (batch_size, seq_len)
        mask:          (batch_size, seq_len), dtype bool
    """
    longest = max(len(sample[0]) for sample in batch)

    sent_rows, label_rows, mask_rows = [], [], []
    for sent_tensors, label_tensor in batch:
        n_real = len(sent_tensors)
        n_pad = longest - n_real

        # An all-zero sentence with the same shape/dtype as a real one is the filler.
        zero_sentence = torch.zeros_like(sent_tensors[0])
        filler = [zero_sentence] * n_pad
        sent_rows.append(torch.stack(sent_tensors + filler))

        # -1 marks label positions that belong to padding sentences.
        label_rows.append(torch.cat([label_tensor, torch.full((n_pad,), -1)]))

        # True for the first n_real positions, False for the padded tail.
        mask_rows.append(torch.arange(longest) < n_real)

    return torch.stack(sent_rows), torch.stack(label_rows), torch.stack(mask_rows)

# Intended mini-batch size for the DataLoaders (one dataset item is one abstract).
BATCH_SIZE = 32

# Example usage (collate_fn pads the abstracts of a batch to a common sentence count):
# train_dataset = PubMedSentenceDataset(train_abstracts, label2id, word2idx, max_sent_len=100)
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

# dev_dataset = PubMedSentenceDataset(dev_abstracts, label2id, word2idx, max_sent_len=100)
# dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

print("Dataset and DataLoader ready.")


Dataset and DataLoader ready.


In [None]:
import torch
import torch.nn as nn
from torchcrf import CRF

class SentenceBiLSTM_CRF(nn.Module):
    """
    Sentence-level BiLSTM-CRF tagger.

    Each sentence is represented by the mean of its (non-padding) token
    embeddings. A bidirectional LSTM runs over the sentence sequence of each
    abstract, a linear layer turns its outputs into per-label emission scores,
    and a CRF layer scores (training) or decodes (inference) the label sequence.

    Args:
        vocab_size: number of rows of the embedding table.
        embed_dim: size of each token embedding.
        hidden_dim: total BiLSTM output size (hidden_dim // 2 per direction);
            must be even so that it matches the input size of the linear layer.
        num_labels: number of sentence labels.
        pad_idx: id of the padding token; positions holding it are excluded
            from the sentence average.

    Raises:
        ValueError: if `hidden_dim` is odd.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_labels, pad_idx):
        super().__init__()
        # An odd hidden_dim would give the BiLSTM hidden_dim - 1 output features
        # and only fail later, inside the linear layer, with a shape error.
        if hidden_dim % 2 != 0:
            raise ValueError(f"hidden_dim must be even, got {hidden_dim}")
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim//2, num_layers=1,
                            batch_first=True, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_dim, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, x, tags=None, mask=None):
        """
        x: (batch_size, seq_len, max_sent_len) word ids
        tags: (batch_size, seq_len) label ids, or None for inference
        mask: (batch_size, seq_len), marks real (non-padding) sentences

        Returns:
            If `tags` is given: the negated CRF score of `tags` (the training
            loss), reduced with 'mean'.
            Otherwise: the result of `CRF.decode`, one predicted label
            sequence per abstract.

        Raises:
            ValueError: if `x` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape (batch_size, seq_len, max_sent_len), got {tuple(x.shape)}"
            )

        embeds = self.embedding(x)  # (B, S, L, E)

        pad_idx = self.embedding.padding_idx
        if pad_idx is None:
            # No padding id configured: every position counts as a real token.
            sent_embeds = embeds.mean(dim=2)  # (B, S, E)
        else:
            # Average over real tokens only. A plain mean over all L positions
            # would also count padding and shrink each sentence vector by
            # (number of tokens / max_sent_len).
            token_mask = (x != pad_idx).unsqueeze(-1).to(embeds.dtype)  # (B, S, L, 1)
            # clamp avoids 0/0 for sentences that consist of padding only.
            token_counts = token_mask.sum(dim=2).clamp(min=1.0)  # (B, S, 1)
            sent_embeds = (embeds * token_mask).sum(dim=2) / token_counts  # (B, S, E)

        lstm_out, _ = self.lstm(sent_embeds)  # (B, S, H)
        emissions = self.hidden2tag(lstm_out)  # (B, S, num_labels)

        if tags is not None:
            # The mask restricts the CRF score to non-padded sentence positions.
            loss = -self.crf(emissions, tags, mask=mask, reduction='mean')
            return loss
        pred_tags = self.crf.decode(emissions, mask=mask)
        return pred_tags


# # Example usage
# VOCAB_SIZE = len(word2idx)
# EMBED_DIM = 100
# HIDDEN_DIM = 256
# NUM_LABELS = len(label2id)
# PAD_IDX = word2idx["<PAD>"]

# model = SentenceBiLSTM_CRF(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_LABELS, PAD_IDX)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

# print(model)


In [None]:
# train_sentences, train_labels = encode_abstracts(train_abstracts)
# dev_sentences, dev_labels = encode_abstracts(dev_abstracts)
# test_sentences, test_labels = encode_abstracts(test_abstracts)

In [None]:
# Wrap the train/dev abstracts in datasets and batch them with collate_fn.
# The batch size comes from the BATCH_SIZE constant instead of a repeated literal.
# Only the training loader shuffles.
train_dataset = PubMedSentenceDataset(train_abstracts, label2id, word2idx, max_sent_len=100)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

dev_dataset = PubMedSentenceDataset(dev_abstracts, label2id, word2idx, max_sent_len=100)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [None]:
# Hyperparameters and device. These names were previously defined only in
# commented-out example code, so this cell raised NameError on a fresh kernel
# (Restart & Run All). The values are the ones from that example.
SEED = 42
EMBED_DIM = 100
HIDDEN_DIM = 256
NUM_LABELS = len(label2id)
PAD_IDX = word2idx["<PAD>"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch before the model is created so that weight initialisation and
# DataLoader shuffling are repeatable across runs.
torch.manual_seed(SEED)

model = SentenceBiLSTM_CRF(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_LABELS, PAD_IDX)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
from sklearn.metrics import f1_score, accuracy_score

def evaluate(loader):
    """
    Score the module-level `model` on every batch of `loader`.

    The model is switched to eval mode and run without gradient tracking.
    Predictions and gold labels are compared only at positions the batch mask
    marks as real sentences.

    Returns:
        (accuracy, macro-averaged F1) over all real sentences.
    """
    model.eval()
    predicted, gold = [], []

    with torch.no_grad():
        for sent_batch, label_batch, sent_mask in loader:
            sent_batch = sent_batch.to(device)
            label_batch = label_batch.to(device)
            sent_mask = sent_mask.to(device)

            decoded = model(sent_batch, mask=sent_mask)  # list of lists

            # Number of real sentences in each abstract of the batch.
            n_valid = sent_mask.sum(dim=1).tolist()
            for path, row, n in zip(decoded, label_batch.tolist(), n_valid):
                predicted.extend(path[:n])
                gold.extend(row[:n])

    acc = accuracy_score(gold, predicted)
    macro_f1 = f1_score(gold, predicted, average='macro')
    return acc, macro_f1


In [None]:
# Train for a fixed number of epochs and keep the checkpoint with the best dev Macro-F1.
EPOCHS = 5
best_f1 = 0.0  # keep track of best dev Macro-F1

for epoch in range(EPOCHS):
    # evaluate() leaves the model in eval mode, so re-enable train mode every epoch.
    model.train()
    total_loss = 0

    for batch_x, batch_y, mask in train_loader:
        batch_x, batch_y, mask = batch_x.to(device), batch_y.to(device), mask.to(device)
        optimizer.zero_grad()
        # With tags supplied, the model returns the CRF training loss.
        loss = model(batch_x, tags=batch_y, mask=mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    avg_loss = total_loss / len(train_loader)

    # Evaluate on dev set
    acc, macro_f1 = evaluate(dev_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.4f}")
    print(f"Dev Accuracy: {acc:.4f}, Dev Macro-F1: {macro_f1:.4f}")

    # Save only on a strict improvement; if dev Macro-F1 never exceeds 0.0,
    # no checkpoint file is written.
    if macro_f1 > best_f1:
        best_f1 = macro_f1
        torch.save(model.state_dict(), "best_bilstm_crf_model.pt")
        print(f"Model saved with Macro-F1: {best_f1:.4f}")


Epoch 1/5, Loss: 6.8168
Dev Accuracy: 0.8457, Dev Macro-F1: 0.7709
Model saved with Macro-F1: 0.7709
Epoch 2/5, Loss: 3.6906
Dev Accuracy: 0.8707, Dev Macro-F1: 0.8075
Model saved with Macro-F1: 0.8075
Epoch 3/5, Loss: 2.9439
Dev Accuracy: 0.8871, Dev Macro-F1: 0.8242
Model saved with Macro-F1: 0.8242
Epoch 4/5, Loss: 2.5479
Dev Accuracy: 0.8920, Dev Macro-F1: 0.8344
Model saved with Macro-F1: 0.8344
Epoch 5/5, Loss: 2.2494
Dev Accuracy: 0.8969, Dev Macro-F1: 0.8337


In [None]:
# Restore the checkpoint written by the training loop (best dev Macro-F1).
# map_location places the tensors on the current device, so a checkpoint saved
# on a GPU also loads on a CPU-only runtime.
model.load_state_dict(torch.load("best_bilstm_crf_model.pt", map_location=device))
model.eval()

# Build the test loader the same way as the dev loader (no shuffling).
test_dataset = PubMedSentenceDataset(test_abstracts, label2id, word2idx, max_sent_len=100)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

test_acc, test_macro_f1 = evaluate(test_loader)
print(f"Test Accuracy: {test_acc:.4f}, Test Macro-F1: {test_macro_f1:.4f}")

Test Accuracy: 0.8878, Test Macro-F1: 0.8335
