In [5]:
!pip install datasets

"""
DATASET STEP: Cleaning and storing splits (train, validation, and test)
"""
from datasets import load_dataset

# Load the test split from Hugging Face
dataset = load_dataset("ccdv/pubmed-summarization")

# ✅ Function to clean and load into a list of (article, abstract) tuples
def clean_pubmed_dataset(dataset, limit=None):
    """Collect cleaned (article, abstract) pairs from one dataset split.

    Each field is stripped of surrounding whitespace and has its newlines
    replaced by single spaces. Rows whose article or abstract is missing,
    empty, or whitespace-only are skipped.

    Args:
        dataset: Iterable of rows supporting ``row['article']`` and
            ``row['abstract']`` lookups.
        limit: Maximum number of pairs to return. ``None`` (the default)
            means no limit; ``0`` or a negative value yields an empty list.

    Returns:
        A list of ``(article, abstract)`` string tuples, in input order.
    """
    data = []
    # Treat 0 as a real limit instead of falling through to "unlimited".
    if limit is not None and limit <= 0:
        return data

    for row in dataset:
        # Normalise first, then test for emptiness, so that whitespace-only
        # fields are rejected instead of being stored as '' pairs.
        article = (row['article'] or '').strip().replace('\n', ' ')
        abstract = (row['abstract'] or '').strip().replace('\n', ' ')
        if not article or not abstract:
            continue

        data.append((article, abstract))
        if limit is not None and len(data) >= limit:
            break
    return data

# Run the same cleaning routine over each split, in train/validation/test order
train_data, val_data, test_data = [
    clean_pubmed_dataset(dataset[split_name], limit=None)
    for split_name in ('train', 'validation', 'test')
]

# Concatenate the three cleaned splits into a single list (train first)
pubmed_data = [*train_data, *val_data, *test_data]

# Report how many (article, abstract) pairs survived cleaning
print(f"Train Samples: {len(train_data)}")
print(f"Validation Samples: {len(val_data)}")
print(f"Test Samples: {len(test_data)}")
print(f"Total Combined Samples: {len(pubmed_data)}")

# Show the start of the first pair: 500 chars of article, 300 of abstract
print("\nSample Article (input):")
print(pubmed_data[0][0][:500])

print("\nSample Abstract (target):")
print(pubmed_data[0][1][:300])


Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.w

README.md:   0%|          | 0.00/3.80k [00:00<?, ?B/s]

train-00000-of-00005.parquet:   0%|          | 0.00/210M [00:00<?, ?B/s]

train-00001-of-00005.parquet:   0%|          | 0.00/208M [00:00<?, ?B/s]

train-00002-of-00005.parquet:   0%|          | 0.00/207M [00:00<?, ?B/s]

train-00003-of-00005.parquet:   0%|          | 0.00/211M [00:00<?, ?B/s]

train-00004-of-00005.parquet:   0%|          | 0.00/210M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/59.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/58.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/119924 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/6633 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6658 [00:00<?, ? examples/s]

Train Samples: 117232
Validation Samples: 6633
Test Samples: 6658
Total Combined Samples: 130523

Sample Article (input):
a recent systematic analysis showed that in 2011 , 314 ( 296 - 331 ) million children younger than 5 years were mildly , moderately or severely stunted and 258 ( 240 - 274 ) million were mildly , moderately or severely underweight in the developing countries .   in iran a study among 752 high school girls in sistan and baluchestan showed prevalence of 16.2% , 8.6% and 1.5% , for underweight , overweight and obesity , respectively .   the prevalence of malnutrition among elementary school aged ch

Sample Abstract (target):
background : the present study was carried out to assess the effects of community nutrition intervention based on advocacy approach on malnutrition status among school - aged children in shiraz , iran.materials and methods : this case - control nutritional intervention has been done between 2008 and


In [19]:

"""
TOKENIZER STEP
"""
from transformers import T5Tokenizer

# Maximum token counts for article inputs and abstract targets
MAX_INPUT_LEN, MAX_TARGET_LEN = 512, 128

# Tokenizer belonging to the pretrained "t5-small" checkpoint
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Function to tokenize
def tokenize_pubmed_data(pairs, tokenizer, max_input_len=512, max_target_len=128):
    """Tokenize a batch of (article, abstract) pairs in one call.

    Args:
        pairs: Sequence whose items hold the article at index 0 and the
            abstract at index 1.
        tokenizer: Callable tokenizer; it is invoked with ``max_length``,
            ``truncation=True``, ``padding="max_length"`` and
            ``return_tensors="pt"``, and must expose ``pad_token_id``.
        max_input_len: Length the articles are truncated/padded to.
        max_target_len: Length the abstracts are truncated/padded to.

    Returns:
        The tokenizer output for the articles with an added ``"labels"``
        entry containing the abstract token IDs, where padding positions
        are replaced by -100.
    """
    inputs = [x[0] for x in pairs]
    targets = [x[1] for x in pairs]

    model_inputs = tokenizer(
        inputs,
        max_length=max_input_len,
        truncation=True,
        padding="max_length",
        return_tensors="pt"
    )

    labels = tokenizer(
        targets,
        max_length=max_target_len,
        truncation=True,
        padding="max_length",
        return_tensors="pt"
    )

    label_ids = labels["input_ids"]
    pad_id = tokenizer.pad_token_id
    if pad_id is not None:
        # Mask padding so it does not contribute to the loss (-100 is the
        # ignore index); this matches what PubMedSummaryDecoderDataset does
        # for its per-item labels.
        label_ids = label_ids.masked_fill(label_ids == pad_id, -100)

    model_inputs["labels"] = label_ids
    return model_inputs
# Show the pad token ID; PubMedSummaryDecoderDataset replaces this ID with
# -100 in its labels.
print("Tokenizer pad token ID:", tokenizer.pad_token_id)

Tokenizer pad token ID: 0


In [16]:
import torch
from torch.utils.data import Dataset

class PubMedSummaryDecoderDataset(Dataset):
    """Map-style dataset over (article, abstract) pairs.

    Pairs are stored untokenized; each lookup tokenizes the requested pair
    and returns fixed-length tensors for the article and the abstract.
    """

    def __init__(self, input_pairs, tokenizer, max_input_len=512, max_target_len=128):
        """Keep references to the pairs, the tokenizer and both length caps."""
        self.data = input_pairs
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

    def __len__(self):
        """Number of stored (article, abstract) pairs."""
        return len(self.data)

    def _encode(self, text, max_length):
        """Tokenize one text, truncating/padding it to exactly ``max_length``."""
        return self.tokenizer(
            text,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for pair ``idx``."""
        article, abstract = self.data[idx]

        # Article first, then abstract
        source = self._encode(article, self.max_input_len)
        target = self._encode(abstract, self.max_target_len)

        # squeeze(0) drops the leading dimension of the returned tensors
        labels = target["input_ids"].squeeze(0)
        # Replace padding in the labels with -100 so it is masked from the loss
        is_padding = labels == self.tokenizer.pad_token_id
        labels[is_padding] = -100

        return {
            "input_ids": source["input_ids"].squeeze(0),
            "attention_mask": source["attention_mask"].squeeze(0),
            "labels": labels,
        }



In [17]:
# Build the training set from the train split only. Passing the combined
# `pubmed_data` list here would fold the validation and test examples into
# training, so any later evaluation on those splits would be contaminated.
train_dataset = PubMedSummaryDecoderDataset(
    train_data,
    tokenizer,
    max_input_len=MAX_INPUT_LEN,
    max_target_len=MAX_TARGET_LEN,
)

# Preview one entry
sample = train_dataset[0]
print("Input IDs shape:", sample["input_ids"].shape)
print("Labels shape:", sample["labels"].shape)

Input IDs shape: torch.Size([512])
Labels shape: torch.Size([128])


In [18]:
"""
DATALOADER STEP
"""
from torch.utils.data import DataLoader

# ✅ Collate function to batch and stack tensors
def collate_fn(batch):
    """Stack per-example tensors into batch tensors.

    Args:
        batch: List of dicts, each holding 'input_ids', 'attention_mask'
            and 'labels' tensors.

    Returns:
        A dict with the same three keys, where each value is the
        ``torch.stack`` of that field across the examples in ``batch``.
    """
    field_names = ('input_ids', 'attention_mask', 'labels')
    return {
        name: torch.stack([example[name] for example in batch])
        for name in field_names
    }

# Number of examples per batch
BATCH_SIZE = 8

# Shuffled loader over the training dataset, batching with collate_fn
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

# Pull one batch and report the stacked tensor shapes
batch = next(iter(train_loader))
print("Batch input_ids shape:", batch["input_ids"].shape)
print("Batch labels shape:", batch["labels"].shape)

Batch input_ids shape: torch.Size([8, 512])
Batch labels shape: torch.Size([8, 128])
