In [12]:
from datasets import load_dataset, concatenate_datasets
import pandas as pd
import os
import lightning.pytorch as pl
from torch.utils.data import DataLoader
import pickle
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq

In [13]:
dataset_name = 'EdinburghNLP/xsum'
dataset = load_dataset(dataset_name,  trust_remote_code=True)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})


In [7]:
dataset['test']['document']

['Prison Link Cymru had 1,099 referrals in 2015-16 and said some ex-offenders were living rough for up to a year before finding suitable accommodation.\nWorkers at the charity claim investment in housing would be cheaper than jailing homeless repeat offenders.\nThe Welsh Government said more people than ever were getting help to address housing problems.\nChanges to the Housing Act in Wales, introduced in 2015, removed the right for prison leavers to be given priority for accommodation.\nPrison Link Cymru, which helps people find accommodation after their release, said things were generally good for women because issues such as children or domestic violence were now considered.\nHowever, the same could not be said for men, the charity said, because issues which often affect them, such as post traumatic stress disorder or drug dependency, were often viewed as less of a priority.\nAndrew Stevens, who works in Welsh prisons trying to secure housing for prison leavers, said the need for ac

In [8]:
dataset['test']['summary']

['There is a "chronic" need for more housing for prison leavers in Wales, according to a charity.',
 'A man has appeared in court after firearms, ammunition and cash were seized by police in Edinburgh.',
 'Four people accused of kidnapping and torturing a mentally disabled man in a "racially motivated" attack streamed on Facebook have been denied bail.',
 'West Brom have appointed Nicky Hammond as technical director, ending his 20-year association with Reading.',
 'The pancreas can be triggered to regenerate itself through a type of fasting diet, say US researchers.',
 'Since their impending merger was announced in January, there has been remarkably little comment about the huge proposed deal to combine Essilor and Luxottica.',
 'A "medal at any cost" approach created a "culture of fear" at British Cycling, says former rider Wendy Houvenaghel.',
 'Have you heard the one about the computer programmer who bought a failing comedy club in Texas and turned it into a million dollar a year bu

In [14]:
class T5SummarizationDataModule(pl.LightningDataModule):
    def __init__(self, model_name, dataset_name, max_length, 
                 batch_size, train_range, val_range, test_range, seed_num):
        super().__init__()
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.max_length = max_length
        self.batch_size = batch_size
        self.train_range = train_range
        self.val_range = val_range
        self.test_range = test_range
        self.seed_num = seed_num
        self.tokenizer = None
        self.data_collator = None
        self.train_dataset = 0
        self.val_dataset = 0
        self.test_dataset = 0
        self.cache_dir = f"./dataset_cache_{self.seed_num}"

    def prepare_data(self):
        # Downloading data, called only once on 1 GPU/TPU in distributed settings
        load_dataset(self.dataset_name,  trust_remote_code=True).shuffle(seed=self.seed_num)
        AutoTokenizer.from_pretrained(self.model_name)

    def setup(self, stage):
        # Setting up the data, called on every GPU/TPU in DDP
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.data_collator = DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model_name)
        
        # Load and preprocess the dataset
        if stage == 'fit' or stage is None:
            self.train_dataset = self._get_or_process_dataset('train')
            self.val_dataset = self._get_or_process_dataset('val')
        if stage == 'test' or stage is None:
            self.test_dataset = self._get_or_process_dataset('test')
            
    def _get_or_process_dataset(self, split):
        cache_file = os.path.join(self.cache_dir, f"{split}_{self.seed_num}.pkl")
        
        if os.path.exists(cache_file):
            print(f"Loading cached {split} dataset...")
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        
        print(f"Processing {split} dataset...")
        dataset = load_dataset(self.dataset_name,  trust_remote_code=True).shuffle(seed=self.seed_num)
        
        if split == 'train':
            data = dataset['train'].select(range(min(self.train_range, len(dataset['train']))))
        elif split in ['val', 'test']:
            temp1 = dataset['test']
            temp2 = dataset['validation']
            # concat the two splits
            temp = concatenate_datasets([temp1, temp2]).train_test_split(test_size=0.5, seed=self.seed_num, shuffle=True)
            if split == 'val':
                data = temp['train'].select(range(min(self.val_range, len(temp['train']))))
            else:
                data = temp['test'].select(range(min(self.test_range, len(temp['test']))))
        
        processed_dataset = self._preprocess_dataset(data)
        
        os.makedirs(self.cache_dir, exist_ok=True)
        with open(cache_file, 'wb') as f:
            pickle.dump(processed_dataset, f)
        
        return processed_dataset
    
    def _preprocess_dataset(self, dataset):
        return dataset.map(
            lambda x: self._preprocess_function(x),
            batched=True,
            remove_columns=dataset.column_names
        )
        
    def _preprocess_function(self, examples):
        prefix = "summarize: "
        inputs = [prefix + doc for doc in examples["document"]]
        model_inputs = self.tokenizer(inputs, padding="max_length", 
                                      truncation=True, max_length=self.max_length)
        labels = self.tokenizer(text_target=examples["summary"], 
                                padding="max_length", truncation=True, max_length=self.max_length)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.data_collator, shuffle=True, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.data_collator, drop_last=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.data_collator, drop_last=True)


In [15]:
data_module = T5SummarizationDataModule(
    dataset_name='EdinburghNLP/xsum',
    model_name='t5-small',
    max_length=512,
    batch_size=32,
    train_range=20000,
    val_range=5000,
    test_range=5000,
    seed_num=42
)

data_module.prepare_data()
data_module.setup(stage='fit')

Loading cached train dataset...
Processing val dataset...


Map: 100%|██████████| 5000/5000 [00:03<00:00, 1445.30 examples/s]


In [17]:
data_module.setup(stage='validation')

In [None]:
dataloader = next(iter(data_module.train_dataloader()))


In [None]:
data_module.tokenizer.batch_decode(dataloader['input_ids'][0])

['summarize',
 ':',
 'A',
 'tiny',
 'satellite',
 'made',
 'by',
 'Glasgow',
 '-',
 'based',
 'C',
 'ly',
 'de',
 'will',
 'be',
 'used',
 'for',
 '',
 'a',
 'mission',
 'to',
 'create',
 '"',
 'col',
 'd',
 '',
 'atom',
 's',
 '"',
 'in',
 'space',
 '.',
 'The',
 '6',
 'U',
 'Cub',
 'e',
 'S',
 'at',
 'will',
 'carry',
 'quantum',
 '-',
 'based',
 'technology',
 'developed',
 'by',
 'sensor',
 'specialist',
 'Tele',
 'd',
 'y',
 'n',
 'e',
 '',
 'e',
 '2',
 'v',
 'and',
 'the',
 'University',
 'of',
 'Birmingham',
 '.',
 'No',
 'date',
 'has',
 'yet',
 'been',
 'set',
 'for',
 'the',
 'Cold',
 'Atom',
 'Space',
 'Pay',
 'load',
 '(',
 'C',
 'a',
 'spa',
 ')',
 'mission',
 '.',
 'It',
 'hopes',
 'to',
 'replicate',
 'lab',
 'experiments',
 'that',
 'have',
 'shown',
 'cold',
 '',
 'atom',
 's',
 'can',
 'be',
 'used',
 'as',
 '"',
 'ul',
 'tra',
 '-',
 'sensitive',
 'sensors',
 '"',
 'capable',
 'of',
 'mapping',
 'tiny',
 'changes',
 'in',
 'the',
 'strength',
 'of',
 'gravity',
 'acr

In [42]:
data_module.tokenizer.batch_decode(dataloader['labels'][0])

['Mini',
 'a',
 'ture',
 'satellite',
 'maker',
 'C',
 'ly',
 'de',
 'Space',
 'has',
 '',
 'teamed',
 'up',
 'with',
 '',
 'a',
 'tech',
 'con',
 'glomer',
 'ate',
 'on',
 '',
 'a',
 'project',
 'that',
 '',
 'aims',
 'to',
 'create',
 '"',
 'a',
 'new',
 'wave',
 '"',
 'of',
 'space',
 'applications',
 '.',
 '</s>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
