In [23]:
import os

import torch
from transformers import AutoTokenizer
from datasets import load_metric, load_dataset
import numpy as np

In [17]:
# Single tokenizer instance shared by every preprocessing step in this notebook.
tokenizer = AutoTokenizer.from_pretrained('t5-small')
# NOTE(review): the t5-small tokenizer presumably already has its own pad token;
# reusing EOS as pad means padded positions carry the EOS id. Kept as-is so
# previously produced data stays comparable.
tokenizer.pad_token = tokenizer.eos_token
# No explicit length caps: with truncation=True the tokenizer then falls back to
# its own limit (presumably model_max_length — confirm for this version).
max_input_length = max_target_length = None
def preprocess_function(examples):
    """Tokenize a batch of source/target pairs for seq2seq training.

    Meant to be passed to ``Dataset.map(..., batched=True)``, where ``examples``
    maps each column name to the list of values in the current batch.

    Args:
        examples: batch mapping with an ``'inputs'`` column (source texts) and a
            ``'target'`` column (reference texts).

    Returns:
        The tokenizer output for the source texts, plus a ``'labels'`` entry
        holding the target token ids. Nothing is padded; truncation limits come
        from the module-level ``max_input_length`` / ``max_target_length``,
        which are read at call time.
    """
    # Plain-list copies of the two text columns.
    source_texts = list(examples['inputs'])
    target_texts = list(examples['target'])

    encoded = tokenizer(
        source_texts,
        max_length=max_input_length,
        truncation=True,
        padding=False,
        add_special_tokens=True,
    )

    # Target-side tokenization. NOTE(review): `as_target_tokenizer` is
    # deprecated in newer transformers releases in favor of `text_target=`.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(
            target_texts,
            max_length=max_target_length,
            truncation=True,
            padding=False,
            add_special_tokens=True,
        )

    encoded["labels"] = target_encoding["input_ids"]
    return encoded

class OTTersDataset(torch.utils.data.Dataset):
    """Map-style torch dataset over eagerly tokenized (input, target) pairs.

    All texts are tokenized once in ``__init__`` with ``padding=True`` (pad to
    the longest sequence in the call), so every item's fields have a common
    length and batches can be stacked by torch's default collate function.
    """

    def __init__(self, tokenizer, data):
        """Tokenize all inputs and targets up front.

        Args:
            tokenizer: HF-style tokenizer. Its ``pad_token`` is overwritten with
                ``eos_token`` (a side effect on the caller's object, kept for
                backward compatibility).
            data: column mapping (dict, ``datasets.Dataset``, ...) with an
                ``'inputs'`` column and a target column named either
                ``'targets'`` or ``'target'`` (the latter is the CSV schema
                used by ``preprocess_function``).
        """
        tokenizer.pad_token = tokenizer.eos_token
        self.encodings = tokenizer(data['inputs'], padding=True, truncation=True)
        target_texts = self._target_texts(data)
        with tokenizer.as_target_tokenizer():
            self.targets = tokenizer(target_texts, padding=True, truncation=True)

    @staticmethod
    def _target_texts(data):
        """Return the target column, preferring 'targets' and falling back to 'target'."""
        try:
            return data['targets']
        except KeyError:
            return data['target']

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of 1-D tensors, including ``labels``."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Labels must be a tensor like the other fields: default_collate would
        # otherwise transpose a batch of Python lists position-wise instead of
        # stacking it into a (batch, seq_len) tensor.
        item['labels'] = torch.tensor(self.targets['input_ids'][idx])
        # NOTE(review): padded label positions hold the pad id (== EOS id here),
        # not -100, so they contribute to the loss. Masking by id would also hide
        # the real EOS; use the target attention_mask if this matters.
        return item

    def __len__(self):
        """Number of tokenized examples."""
        return len(self.encodings['input_ids'])


def read_data(data_dir):
    """Load the train/dev/test CSV splits stored under ``data_dir``.

    Each split is read from ``<data_dir>/<split>/text.csv``. Train and dev are
    tokenized with ``preprocess_function`` and have their raw ``inputs`` /
    ``target`` columns removed. Test is returned untokenized, with its raw columns.

    Args:
        data_dir: directory containing ``train``, ``dev`` and ``test`` subfolders.

    Returns:
        ``(train, dev, test)`` datasets.
    """
    loaded = {}
    for split_name in ('train', 'dev', 'test'):
        split_dir = os.path.join(data_dir, split_name)
        # A local CSV load yields a DatasetDict; its single split is keyed 'train'.
        split_dict = load_dataset(split_dir, data_files=['text.csv'])
        if split_name == 'test':
            loaded[split_name] = split_dict['train']
            continue
        tokenized = split_dict.map(
            preprocess_function,
            batched=True,
            remove_columns=['inputs', 'target'],
        )
        loaded[split_name] = tokenized['train']
    return loaded['train'], loaded['dev'], loaded['test']

In [39]:
# Out-of-domain data directory, relative to the notebook's working directory.
dataset_dir = '../data/out_of_domain'
# train/dev come back tokenized (tokenizer fields + 'labels'); test keeps raw text columns.
train_dataset, eval_dataset, test_dataset = read_data(dataset_dir) 

Using custom data configuration train-38be866dbae812fe


Downloading and preparing dataset csv/train to /home/jacky/.cache/huggingface/datasets/csv/train-38be866dbae812fe/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 2616.53it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 326.66it/s]


Dataset csv downloaded and prepared to /home/jacky/.cache/huggingface/datasets/csv/train-38be866dbae812fe/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 436.72it/s]
100%|██████████| 3/3 [00:00<00:00,  7.45ba/s]
Using custom data configuration dev-fe834cd7bb2c5fba


Downloading and preparing dataset csv/dev to /home/jacky/.cache/huggingface/datasets/csv/dev-fe834cd7bb2c5fba/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 1419.87it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 434.01it/s]


Dataset csv downloaded and prepared to /home/jacky/.cache/huggingface/datasets/csv/dev-fe834cd7bb2c5fba/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 528.98it/s]
100%|██████████| 2/2 [00:00<00:00, 29.09ba/s]
Using custom data configuration test-e401684cd5ad938c


Downloading and preparing dataset csv/test to /home/jacky/.cache/huggingface/datasets/csv/test-e401684cd5ad938c/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 2964.17it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 781.94it/s]


Dataset csv downloaded and prepared to /home/jacky/.cache/huggingface/datasets/csv/test-e401684cd5ad938c/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 318.89it/s]


In [40]:
# Fraction of train+dev source sequences that are at most 50 tokens long.
np.mean([len(ids) <= 50 for split in (train_dataset, eval_dataset) for ids in split['input_ids']])

1.0

In [41]:
# Fraction of train+dev label sequences that are at most 75 tokens long.
np.mean([len(ids) <= 75 for split in (train_dataset, eval_dataset) for ids in split['labels']])

1.0