In [None]:
import torch
from torch.utils.data import Dataset, DataLoader


class TextClassificationCollator:
    """Collate ``{"text", "label"}`` samples into one tokenized batch.

    The callable is meant to be used as a ``collate_fn``: it receives a list
    of sample dicts, tokenizes all texts in one call, and returns a dict with
    ``input_ids``, ``attention_mask`` and ``labels`` (plus the raw ``text``
    list when ``with_text`` is true).
    """

    def __init__(self, tokenizer, max_length, with_text=True):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.with_text = with_text

    def __call__(self, samples):
        # Split the list of sample dicts into two parallel lists.
        texts = list(map(lambda sample: sample["text"], samples))
        labels = list(map(lambda sample: sample["label"], samples))

        # Options forwarded to the tokenizer: pad within the batch, truncate
        # to ``max_length`` and request PyTorch tensors.
        tokenizer_options = {
            "padding": True,
            "truncation": True,
            "return_tensors": "pt",
            "max_length": self.max_length,
        }
        encoded = self.tokenizer(texts, **tokenizer_options)

        batch = {key: encoded[key] for key in ("input_ids", "attention_mask")}
        batch["labels"] = torch.tensor(labels, dtype=torch.long)

        # The raw strings are only attached on request.
        if self.with_text:
            batch["text"] = texts

        return batch


class TextClassificationDataset(Dataset):
    """Dataset that pairs each text with the label at the same index."""

    def __init__(self, texts, labels):
        self.texts, self.labels = texts, labels

    def __len__(self):
        # The size is taken from the texts alone; labels are not consulted.
        return len(self.texts)

    def __getitem__(self, item):
        # The text is coerced to ``str``; the label is passed through as-is.
        return dict(text=str(self.texts[item]), label=self.labels[item])

In [None]:
import random
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split


def read_text(fn, encoding="utf-8"):
    """Read a two-column, tab-delimited corpus file.

    Each non-blank line must look like ``<label>\\t<text>``. Lines are
    stripped of surrounding whitespace before parsing and blank lines are
    skipped. Only the first tab separates the columns, so any further tabs
    stay inside the text field.

    Args:
        fn: Path of the file to read.
        encoding: Text encoding used to decode the file. Set explicitly so
            the result does not depend on the platform's locale default.

    Returns:
        A ``(labels, texts)`` tuple of two equally long lists of strings.

    Raises:
        ValueError: If a non-blank line contains no tab separator.
    """
    labels, texts = [], []

    with open(fn, "r", encoding=encoding) as f:
        # Iterate lazily instead of loading the whole file with readlines().
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue

            # First column is the label, everything after the first tab is
            # the text.
            label, sep, text = line.partition("\t")
            if not sep:
                raise ValueError(
                    f"{fn}:{line_no}: expected a tab between label and text"
                )

            labels.append(label)
            texts.append(text)

    return labels, texts


def get_loaders(
    fn,
    tokenizer,
    valid_ratio=0.2,
    batch_size=16,
    max_length=256,
    random_state=None,
):
    """Build train/validation ``DataLoader``s from a tab-delimited corpus.

    The file is read with ``read_text``, string labels are mapped to integer
    indices, and the samples are split into train and validation sets with a
    shuffled split stratified on the label.

    Args:
        fn: Path of the corpus file passed to ``read_text``.
        tokenizer: Tokenizer handed to ``TextClassificationCollator``.
        valid_ratio: Passed to ``train_test_split`` as ``test_size``.
        batch_size: Batch size used by both loaders.
        max_length: Maximum sequence length given to the collator.
        random_state: Forwarded to ``train_test_split``; pass an int to make
            the split reproducible.

    Returns:
        ``(train_loader, valid_loader, index_to_label)`` where
        ``index_to_label`` maps each integer class index back to its
        original label string.
    """
    # Get list of labels and list of texts.
    labels, texts = read_text(fn)

    # Sort the distinct labels before numbering them: iterating a set of
    # strings has no stable order across interpreter runs, which would give
    # the same label a different index each time.
    index_to_label = dict(enumerate(sorted(set(labels))))
    label_to_index = {label: index for index, label in index_to_label.items()}

    # Convert label text to integer value.
    labels = [label_to_index[label] for label in labels]

    train_texts, valid_texts, train_labels, valid_labels = train_test_split(
        texts,
        labels,
        shuffle=True,
        test_size=valid_ratio,
        stratify=labels,
        random_state=random_state,
    )

    # The collator only stores its constructor arguments, so one instance
    # can serve both loaders.
    collator = TextClassificationCollator(tokenizer, max_length=max_length)

    # Only the training loader reshuffles its samples.
    train_loader = DataLoader(
        TextClassificationDataset(train_texts, train_labels),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
    )
    valid_loader = DataLoader(
        TextClassificationDataset(valid_texts, valid_labels),
        batch_size=batch_size,
        collate_fn=collator,
    )

    return train_loader, valid_loader, index_to_label


# Load the pretrained cased BERT tokenizer used to encode every batch.
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

# Turn the raw TSV corpus into train/validation loaders, holding out the
# given ratio of samples for validation.
train_loader, valid_loader, index_to_label = get_loaders(
    fn="./data/media.bert.train.tsv",
    tokenizer=tokenizer,
    valid_ratio=0.2,
)

# Report the size of both splits on a single line.
print("|train| =", len(train_loader.dataset), end=" ")
print("|valid| =", len(valid_loader.dataset))