# Install dependencies

In [1]:
%pip install pytorch_lightning
%pip install datasets
%pip install hydra-core
%pip install -U portalocker>=2.0.0
#%pip install -r requirements



General Imports

In [2]:
import os
import tarfile
from pathlib import Path
from typing import Optional, Callable, Tuple, Dict, List, cast

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
from PIL import Image
import pytorch_lightning as pl
import requests

from hydra import utils
from omegaconf import OmegaConf

from datasets import load_dataset, DatasetDict
import logging
import pickle
from transformers import AutoTokenizer

#import torchtext

# CCCN Implementation Path Finder





In [3]:
# There's an empty file in the dataset
PATHFINDER_BLACKLIST = {"pathfinder32/curv_baseline/imgs/0/sample_172.png"}


def pil_loader_grayscale(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        return Image.open(f).convert("L")


class PathFinderDataset(datasets.ImageFolder):
    """Path Finder dataset."""

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(
            root,
            loader=pil_loader_grayscale,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Override this so it doesn't call the parent's method."""
        return [], {}

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class)."""
        directory = Path(directory).expanduser()

        both_none = extensions is None and is_valid_file is None
        both_something = extensions is not None and is_valid_file is not None
        if both_none or both_something:
            raise ValueError(
                "Both extensions and is_valid_file cannot be None or not None at the same time"
            )

        if extensions is not None:
            def is_valid_file(x: str) -> bool:
                return datasets.folder.has_file_allowed_extension(
                    x, cast(Tuple[str, ...], extensions)
                )
        is_valid_file = cast(Callable[[str], bool], is_valid_file)

        path_list = sorted(
            list((directory / "metadata").glob("*.npy")),
            key=lambda path: int(path.stem),
        )
        if not path_list:
            raise FileNotFoundError(f"No metadata found at {str(directory)}")
        # Get the 'pathfinder32/curv_baseline' part of data_dir
        data_dir_stem = Path().joinpath(*directory.parts[-2:])
        instances = []
        for metadata_file in path_list:
            with open(metadata_file, "r") as f:
                for metadata in f.read().splitlines():
                    metadata = metadata.split()
                    image_path = Path(metadata[0]) / metadata[1]
                    if (
                        is_valid_file(str(image_path))
                        and str(data_dir_stem / image_path) not in PATHFINDER_BLACKLIST
                    ):
                        label = int(metadata[3])
                        instances.append((str(directory / image_path), label))
        return instances



class PathFinderDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir,
        batch_size,
        test_batch_size,
        data_type,
        num_workers,
        pin_memory,
        resolution,
        level="hard",
        val_split=0.1,
        test_split=0.1,
        **kwargs,
    ):
        super().__init__()

        assert resolution in [32, 64, 128, 256]
        assert level in ["easy", "intermediate", "hard"]

        level_dir = {
            "easy": "curv_baseline",
            "intermediate": "curv_contour_length_9",
            "hard": "curv_contour_length_14",
        }[level]

        # Save parameters to self
        data_dir = (
            data_dir + f"/lra_release/pathfinder32/curv_contour_length_14/lra_release/lra_release/pathfinder{resolution}/{level_dir}"
        )
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        self.resolution = resolution
        self.level = level

        self.val_split = val_split
        self.test_split = test_split

        # Determine data_type
        if data_type == "default":
            self.data_type = "image"
            self.data_dim = 2
        elif data_type == "sequence":
            self.data_type = data_type
            self.data_dim = 1
        else:
            raise ValueError(f"data_type {data_type} not supported.")

        # Determine sizes of dataset
        self.input_channels = 1
        self.output_channels = 2

        # Create transforms
        train_transform = [
            transforms.ToTensor(),
        ]
        # add augmentations
        if kwargs["augment"]:
            raise NotImplementedError

        self.train_transform = transforms.Compose(train_transform)

    def download_and_extract_lra_release(self, data_dir):
        url = "https://storage.googleapis.com/long-range-arena/lra_release.gz"
        local_filename = os.path.join(data_dir, "lra_release.gz")

        # Create data directory if it doesn't exist
        os.makedirs(data_dir, exist_ok=True)

        # Download the file
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(local_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        # Extract the tar.gz file
        with tarfile.open(local_filename, "r:gz") as tar:
            tar.extractall(path=data_dir)

        # Optionally, remove the tar.gz file after extraction
        os.remove(local_filename)

    def prepare_data(self):
        if not self.data_dir.is_dir():

            self.download_and_extract_lra_release(self.data_dir)
            # raise FileNotFoundError(
            #     f"""
            # Directory {self.data_dir} not found.
            # To get the dataset, download lra_release.gz from
            # https://github.com/google-research/long-range-arena,
            # then unzip it with tar -xvf lra_release.gz.
            # Then point data_dir to the directory that contains pathfinderX, where X is the
            # resolution (either 32, 64, 128, or 256).
            # """
            # )

    def setup(self, stage=None):
        if stage == "test" and hasattr(self, "dataset_test"):
            return
        # [2021-08-18] TD: I ran into RuntimeError: Too many open files.
        # https://github.com/pytorch/pytorch/issues/11201
        torch.multiprocessing.set_sharing_strategy("file_system")
        dataset = PathFinderDataset(self.data_dir, transform=self.train_transform)
        len_dataset = len(dataset)
        val_len = int(self.val_split * len_dataset)
        test_len = int(self.test_split * len_dataset)
        train_len = len_dataset - val_len - test_len
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            dataset,
            [train_len, val_len, test_len],
            generator=torch.Generator().manual_seed(getattr(self, "seed", 42)),
        )

    # we define a separate DataLoader for each of train/val/test
    def train_dataloader(self):
        train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
        )
        return train_dataloader

    def val_dataloader(self):
        val_dataloader = DataLoader(
            self.val_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
        return val_dataloader

    def test_dataloader(self):
        test_dataloader = DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
        return test_dataloader

    def on_before_batch_transfer(self, batch, dataloader_idx):
        if self.data_type == "sequence":
            # If sequential, flatten the input [B, C, Y, X] -> [B, C, -1]
            x, y = batch
            x_shape = x.shape
            # Flatten
            x = x.view(x_shape[0], x_shape[1], -1)
            batch = x, y
        return batch



# Path finder data set

In [4]:
class PathfinderDataset(torch.utils.data.Dataset):
    """Pathfinder dataset created from a list of images."""

    def __init__(self, data_dir, transform: Optional[Callable] = None) -> None:
        """
        Args:
            img_list (List[Tuple[str, int]]): List of tuples where each tuple contains
                an image path and its corresponding label.
            transform (Optional[Callable]): Optional transformation function or composition of transformations.
        """
        self.data_dir = data_dir
        self.img_list = self.create_imagelist()
        self.transform = transform

    def __len__(self) -> int:
        """Returns the number of samples in the dataset."""
        return len(self.img_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            Tuple[torch.Tensor, int]: A tuple where the first element is the image tensor
                and the second element is the label.
        """
        img_path, label = self.img_list[idx]
        img = Image.open(img_path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, label

    def create_imagelist(self) -> List[Tuple[str, int]]:

        # root dir where the image are placed
        directory = Path(self.data_dir).expanduser()

        # metadata path where we get the class_idx
        path_list = sorted(
            list((directory / "metadata").glob("*.npy")),
            key=lambda path: int(path.stem),
        )
        instances = []
        for metadata_file in path_list:
            with open(metadata_file, "r") as f:
                for metadata in f.read().splitlines():
                    metadata = metadata.split()
                    image_path = Path(metadata[0]) / metadata[1]
                    label = int(metadata[3])
                    instances.append((str(directory / image_path), label))
        return instances

# Pathfinder data module

In [5]:
class PathfinderDataModule(pl.LightningDataModule):
    def __init__(
        self,
        cfg,
        data_dir,
        batch_size: int = 32,
        test_batch_size: int = 32,
        data_type="default",
        resolution = "32",
        level="hard",
        val_split=0.1,
        test_split=0.1,
    ):
        super().__init__()

        level_dir = {
            "easy": "curv_baseline",
            "intermediate": "curv_contour_length_9",
            "hard": "curv_contour_length_14",
        }[level]

        # Save parameters to self
        data_dir = (
            data_dir + f"/lra_release/pathfinder{resolution}/{level_dir}"
        )
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size

        self.resolution = resolution
        self.level = level

        self.val_split = val_split
        self.test_split = test_split

        self.num_workers = 0  # for google colab training

        # Determine data_type
        if data_type == "default":
            self.data_type = "image"
            self.data_dim = 2
        elif data_type == "sequence":
            self.data_type = data_type
            self.data_dim = 1
        else:
            raise ValueError(f"data_type {data_type} not supported.")

        # Determine sizes of dataset
        self.input_channels = 1
        self.output_channels = 2

    def prepare_data(self):
        if not self.data_dir.is_dir():
            self.download_and_extract_lra_release(self.data_dir)

    def download_and_extract_lra_release(self, data_dir):
        url = "https://storage.googleapis.com/long-range-arena/lra_release.gz"
        local_filename = os.path.join(data_dir, "lra_release.gz")

        # Create data directory if it doesn't exist
        os.makedirs(data_dir, exist_ok=True)

        # Download the file
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(local_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        # Extract the tar.gz file
        with tarfile.open(local_filename, "r:gz") as tar:
            tar.extractall(path=data_dir)

        # Optionally, remove the tar.gz file after extraction
        os.remove(local_filename)

    def setup(self, stage=None):
        self._set_transform()
        #self._yaml_parameters()  # TODO set correct params

        self.dataset = PathfinderDataset(self.data_dir, transform=self.transform)
        # compute lengths

        len_dataset = len(self.dataset)
        val_len = int(self.val_split * len_dataset)
        test_len = int(self.test_split * len_dataset)
        train_len = len_dataset - val_len - test_len

        # splits
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            self.dataset,
            [train_len, val_len, test_len],
            generator=torch.Generator().manual_seed(getattr(self, "seed", 42)),
        )

    def _set_transform(self):

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

    def _yaml_parameters(self):
      pass

    def train_dataloader(self):
        train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )
        return train_dataloader

    def val_dataloader(self):
        val_dataloader = DataLoader(
            self.val_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return val_dataloader

    def test_dataloader(self):
        test_dataloader = DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return test_dataloader

    def on_before_batch_transfer(self, batch, dataloader_idx):
        if self.data_type == "sequence":
            x, y = batch
            x_shape = x.shape
            x = x.view(x_shape[0], x_shape[1], -1)
            batch = x, y
        return batch

# CCCN Implementation ListOps


In [6]:
import os
import pickle
import logging
from pathlib import Path

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

class ListOpsDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir,
        batch_size,
        test_batch_size,
        data_type,
        pin_memory,
        num_workers,
        # Default values taken from S4
        max_length=512,  # Ensure this matches the model's max length
        append_bos=False,
        append_eos=True,
        tokenizer_name="bert-base-uncased",
        **kwargs,
    ):
        super().__init__()
        self.data_dir = Path(data_dir) / "lra_release/pathfinder32/curv_contour_length_14/lra_release/listops-1000"
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        self.max_length = max_length
        self.append_bos = append_bos
        self.append_eos = append_eos

        self.tokenizer_name = tokenizer_name
        self.tokenizer = None

        # Determine data_type
        if data_type == "default":
            self.data_type = "sequence"
            self.data_dim = 1
        else:
            raise ValueError(f"data_type {data_type} not supported.")

        # Determine sizes of dataset
        self.input_channels = 1
        self.output_channels = 10

    def prepare_data(self):
        if self.get_cache_dir() is None:
            for split in ["train", "val", "test"]:
                split_path = str(self.data_dir / f"basic_{split}.tsv")
                print(split_path)
        else:  # Process the dataset and save it
            self.process_dataset()

    def setup(self, stage=None):
        if stage == "test" and hasattr(self, "dataset_test"):
            return
        dataset, self.tokenizer = self.process_dataset()
        self.vocab_size = len(self.tokenizer)

        dataset.set_format(type="torch", columns=["input_ids", "Target"])

        # Create all splits
        self.train_dataset, self.val_dataset, self.test_dataset = (
            dataset["train"],
            dataset["val"],
            dataset["test"],
        )

        def collate_batch(batch):
            xs, ys = zip(*[(data["input_ids"], data["Target"]) for data in batch])
            xs = torch.stack(
                [
                    torch.nn.functional.pad(
                        x,
                        [self.max_length - x.shape[-1], 0],
                        value=self.tokenizer.pad_token_id,
                    )
                    for x in xs
                ]
            )
            xs = xs.unsqueeze(1).float()
            ys = torch.tensor(ys)
            return xs, ys

        self.collate_fn = collate_batch

    def process_dataset(self):
        if self.get_cache_dir() is not None:
            return self._load_from_cache()

        dataset = load_dataset(
            "csv",
            data_files={
                "train": str(self.data_dir / "basic_train.tsv"),
                "val": str(self.data_dir / "basic_val.tsv"),
                "test": str(self.data_dir / "basic_test.tsv"),
            },
            delimiter="\t",
            keep_in_memory=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)

        # Adjust tokenizer for BOS and EOS tokens if needed
        if self.append_bos:
            self.tokenizer.add_special_tokens({'additional_special_tokens': ['<bos>']})
        if self.append_eos:
            self.tokenizer.add_special_tokens({'additional_special_tokens': ['<eos>']})

        tokenize = lambda example: {
            "tokens": self.tokenizer(
                example["Source"],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )['input_ids'].squeeze().tolist()  # Convert tensor to list
        }
        dataset = dataset.map(
            tokenize,
            remove_columns=["Source"],
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=self.num_workers,
        )

        def numericalize(example):
            tokens = (
                (self.tokenizer.convert_tokens_to_ids(['<bos>']) if self.append_bos else []) +
                example["tokens"] +
                (self.tokenizer.convert_tokens_to_ids(['<eos>']) if self.append_eos else [])
            )
            return {"input_ids": tokens}

        dataset = dataset.map(
            numericalize,
            remove_columns=["tokens"],
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=self.num_workers,
        )

        self._save_to_cache(dataset)
        return dataset, self.tokenizer

    def _save_to_cache(self, dataset):
        cache_dir = self.get_cache_dir()
        os.makedirs(cache_dir, exist_ok=True)
        logger = logging.getLogger(__name__)
        logger.info(f"Saving to cache at {cache_dir}")
        dataset.save_to_disk(cache_dir)
        with open(Path(cache_dir) / "tokenizer.pkl", "wb") as f:
            pickle.dump(self.tokenizer, f)

    def _load_from_cache(self):
        assert self.get_cache_dir().is_dir()
        logger = logging.getLogger(__name__)
        logger.info(f"Load from cache at {self.get_cache_dir()}")
        dataset = load_dataset(self.get_cache_dir())
        with open(Path(self.get_cache_dir()) / "tokenizer.pkl", "rb") as f:
            tokenizer = pickle.load(f)
        return dataset, tokenizer

    @property
    def _cache_dir_name(self):
        return f"max_length-{self.max_length}-append_bos-{self.append_bos}-append_eos-{self.append_eos}"

    def get_cache_dir(self):
        cache_dir = self.data_dir / self._cache_dir_name
        return cache_dir if cache_dir.is_dir() else None

    # We define a separate DataLoader for each of train/val/test
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
        )


# ListOps

In [18]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import torch
from pathlib import Path
from typing import Callable, Optional, List, Tuple
from transformers import AutoTokenizer
import numpy as np
from datasets import load_dataset

from collections import defaultdict
from typing import List, Dict

from collections import defaultdict
from typing import List



class ListOpsDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir,
        batch_size,
        test_batch_size,
        data_type,
        max_length=512,  # Ensure this matches the model's max length
        append_bos=False,
        append_eos=True,
        tokenizer_name="bert-base-uncased",
    ):
        super().__init__()
        self.data_dir = Path(data_dir) / "lra_release/pathfinder32/curv_contour_length_14/lra_release/listops-1000"
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = 7

        self.max_length = max_length
        self.append_bos = append_bos
        self.append_eos = append_eos

        self.tokenizer_name = tokenizer_name
        self.tokenizer = None

        # Determine data_type
        if data_type == "default":
            self.data_type = "sequence"
            self.data_dim = 1
        else:
            raise ValueError(f"data_type {data_type} not supported.")

        # Determine sizes of dataset
        self.input_channels = 1
        self.output_channels = 10

    def prepare_data(self):

        dataset = load_dataset(
            "csv",
            data_files={
                "train": str(self.data_dir / "basic_train.tsv"),
                "val": str(self.data_dir / "basic_val.tsv"),
                "test": str(self.data_dir / "basic_test.tsv"),
            },
            delimiter="\t",
            keep_in_memory=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)

        # Adjust tokenizer for BOS and EOS tokens if needed
        if self.append_bos:
            self.tokenizer.add_special_tokens({'additional_special_tokens': ['<bos>']})
        if self.append_eos:
            self.tokenizer.add_special_tokens({'additional_special_tokens': ['<eos>']})

        tokenize = lambda example: {
            "tokens": self.tokenizer(
                example["Source"],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )['input_ids'].squeeze().tolist()  # Convert tensor to list
        }
        dataset = dataset.map(
            tokenize,
            remove_columns=["Source"],
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=self.num_workers,
        )

        def numericalize(example):
            tokens = (
                (self.tokenizer.convert_tokens_to_ids(['<bos>']) if self.append_bos else []) +
                example["tokens"] +
                (self.tokenizer.convert_tokens_to_ids(['<eos>']) if self.append_eos else [])
            )
            return {"input_ids": tokens}

        dataset = dataset.map(
            numericalize,
            remove_columns=["tokens"],
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=self.num_workers,
        )

        self.dataset = dataset

    def setup(self, stage=None):
        self._set_transform()
        self._yaml_parameters()  # TODO set correct params

        self.vocab_size = len(self.vocab)
        self.dataset.set_format(type="torch", columns=["input_ids", "Target"])

        self.train_dataset, self.val_dataset, self.test_dataset = (
            self.dataset["train"],
            self.dataset["val"],
            self.dataset["test"],
        )

        def collate_batch(batch):
            input_ids = [data["input_ids"] for data in batch]
            labels = [data["Target"] for data in batch]

            pad_value = self.vocab.pad_index

            padded_input_ids = torch.nn.utils.rnn.pad_sequence(
                [torch.tensor(seq) for seq in input_ids], batch_first=True, padding_value=pad_value
            )

            if padded_input_ids.size(1) > self.max_length:
                padded_input_ids = padded_input_ids[:, -self.max_length:]  # truncate to max_length
            else:
                padding_size = self.max_length - padded_input_ids.size(1)
                padded_input_ids = torch.nn.functional.pad(
                    padded_input_ids,
                    (padding_size, 0),  # pad on the left
                    value=pad_value,
                )

            input_tensor = padded_input_ids.float()
            label_tensor = torch.tensor(labels)

            return input_tensor, label_tensor

        self.collate_fn = collate_batch

    def _set_transform(self):
        self.transform = transforms.Compose([transforms.ToTensor()])

    def _yaml_parameters(self):
      pass

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
        )


if __name__ == "__main__":
    pass


# CCCN Implementation IMDB


In [8]:
import os
import pickle
import logging
from pathlib import Path

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer

class IMDBDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir,
        batch_size,
        test_batch_size,
        data_type,
        pin_memory,
        num_workers,
        # Default values taken from S4
        max_length=512,  # Ensure this matches the model's max length
        tokenizer_name="bert-base-uncased",
        vocab_min_freq=15,
        append_bos=False,
        append_eos=True,
        val_split=0.0,
        **kwargs,
    ):
        super().__init__()

        # Save parameters to self
        self.data_dir = Path(data_dir) / "IMDB"
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        self.max_length = max_length
        self.tokenizer_name = tokenizer_name
        self.vocab_min_freq = vocab_min_freq
        self.append_bos = append_bos
        self.append_eos = append_eos
        self.val_split = val_split

        self.tokenizer = None
        self.cache_dir = self.get_cache_dir()

        # Determine data_type
        if data_type == "default":
            self.data_type = "sequence"
            self.data_dim = 1
        else:
            raise ValueError(f"data_type {data_type} not supported.")

        # Determine sizes of dataset
        self.input_channels = 1
        self.output_channels = 2

    def prepare_data(self):
        if self.cache_dir is None:  # Just download the dataset
            load_dataset("imdb", cache_dir=self.data_dir)
            self.process_dataset()

    def setup(self, stage=None):
        if stage == "test" and hasattr(self, "dataset_test"):
            return
        dataset, self.tokenizer = self.process_dataset()
        self.vocab_size = len(self.tokenizer)

        dataset.set_format(type="torch", columns=["input_ids", "label"])

        # Create all splits
        self.train_dataset, self.test_dataset = dataset["train"], dataset["test"]
        if self.val_split == 0.0:
            # Use test set as val set, as done in the LRA paper
            self.val_dataset = self.test_dataset
        else:
            train_val = self.train_dataset.train_test_split(
                test_size=self.val_split,
                seed=getattr(self, "seed", 42),
            )
            self.train_dataset, self.val_dataset = train_val["train"], train_val["test"]

        def collate_batch(batch):
            xs, ys = zip(*[(data["input_ids"], data["label"]) for data in batch])
            xs = torch.stack(
                [
                    torch.nn.functional.pad(
                        x,
                        [self.max_length - len(x), 0],
                        value=self.tokenizer.pad_token_id,
                    )
                    for x in xs
                ]
            )
            xs = xs.unsqueeze(1).float()
            ys = torch.tensor(ys)
            return xs, ys

        self.collate_fn = collate_batch

    def process_dataset(self):
        if self.get_cache_dir() is not None:
            return self._load_from_cache()

        dataset = load_dataset("imdb", cache_dir=self.data_dir)

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['<bos>', '<eos>']})

        def tokenize_function(example):
            encoding = self.tokenizer(
                example["text"],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            return {
                "input_ids": encoding['input_ids'].squeeze().tolist()  # Convert tensor to list
            }

        # Tokenize and map to dataset
        tokenized_datasets = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"],
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=self.num_workers,
        )

        self._save_to_cache(tokenized_datasets)
        return tokenized_datasets, self.tokenizer

    def _save_to_cache(self, dataset):
        cache_dir = self.get_cache_dir()
        os.makedirs(cache_dir, exist_ok=True)
        logger = logging.getLogger(__name__)
        logger.info(f"Saving to cache at {cache_dir}")
        dataset.save_to_disk(cache_dir)
        with open(Path(cache_dir) / "tokenizer.pkl", "wb") as f:
            pickle.dump(self.tokenizer, f)

    def _load_from_cache(self):
        assert self.get_cache_dir().is_dir()
        logger = logging.getLogger(__name__)
        logger.info(f"Load from cache at {self.get_cache_dir()}")
        dataset = DatasetDict.load_from_disk(self.get_cache_dir())
        with open(Path(self.get_cache_dir()) / "tokenizer.pkl", "rb") as f:
            tokenizer = pickle.load(f)
        return dataset, tokenizer

    @property
    def _cache_dir_name(self):
        return f"l_max-{self.max_length}-tokenizer-{self.tokenizer_name}-min_freq-{self.vocab_min_freq}-append_bos-{self.append_bos}-append_eos-{self.append_eos}"

    def get_cache_dir(self):
        cache_dir = self.data_dir / self._cache_dir_name
        return cache_dir if cache_dir.is_dir() else None

    # We define a separate DataLoader for each of train/val/test
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
        )


# IMDB

In [21]:
class IMDBDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir,
        batch_size,
        test_batch_size,
        data_type,
        max_length=4096,
        tokenizer_type="word",
        tokenizer_name="bert-base-uncased",
        vocab_min_freq=15,
        append_bos=False,
        append_eos=True,
        val_split=0.0,
    ):
        assert tokenizer_type in [
            "word",
            "char",
        ], f"tokenizer_type {tokenizer_type} not supported"

        super().__init__()

        # Save parameters to self
        self.data_dir = Path(data_dir) / "IMDB"
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = 7

        self.max_length = max_length
        self.tokenizer_type = tokenizer_type
        self.vocab_min_freq = vocab_min_freq
        self.append_bos = append_bos
        self.append_eos = append_eos
        self.val_split = val_split
        self.tokenizer_name = tokenizer_name

        # Determine data_type
        if data_type == "default":
            self.data_type = "sequence"
            self.data_dim = 1
        else:
            raise ValueError(f"data_type {data_type} not supported.")

        # Determine sizes of dataset
        self.input_channels = 1
        self.output_channels = 2

    def prepare_data(self):

        dataset = load_dataset("imdb", cache_dir=self.data_dir)

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['<bos>', '<eos>']})

        def tokenize_function(example):
            encoding = self.tokenizer(
                example["text"],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            return {
                "input_ids": encoding['input_ids'].squeeze().tolist()  # Convert tensor to list
            }

        # Tokenize and map to dataset
        tokenized_datasets = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"],
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=self.num_workers,
        )

        self.dataset = tokenized_datasets



    def setup(self, stage=None):

        self._set_transform()
        self._yaml_parameters()

        self.vocab_size = len(self.vocab)
        self.dataset.set_format(type="torch", columns=["input_ids", "label"])

        self.train_dataset, self.test_dataset = (
            self.dataset["train"],
            self.dataset["test"],
        )

        def collate_batch(batch):
            input_ids = [data["input_ids"] for data in batch]
            labels = [data["label"] for data in batch]

            pad_value = float(self.vocab["<pad>"])

            padded_input_ids = torch.nn.utils.rnn.pad_sequence(
                input_ids, batch_first=True, padding_value=pad_value
            )

            if padded_input_ids.size(1) > self.max_length:
                padded_input_ids = padded_input_ids[
                    :, -self.max_length :
                ]  # truncate to max_length
            else:
                # Pad to max_length on the left (if needed)
                padding_size = self.max_length - padded_input_ids.size(1)
                padded_input_ids = torch.nn.functional.pad(
                    padded_input_ids,
                    (padding_size, 0),  # pad on the left
                    value=pad_value,
                )

            input_tensor = padded_input_ids.unsqueeze(1).float()

            label_tensor = torch.tensor(labels)

            return input_tensor, label_tensor

        self.collate_fn = collate_batch

    def _set_transform(self):

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

    def _yaml_parameters(self):
        pass

    def train_dataloader(self):
        train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            collate_fn=self.collate_fn,
        )
        return train_dataloader

    def val_dataloader(self):
        val_dataloader = DataLoader(
            self.val_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )
        return val_dataloader

    def test_dataloader(self):
        test_dataloader = DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )
        return test_dataloader


if __name__ == "__main__":
    pass


# Main code run

In [22]:
# prompt: Generate the code to instantiate PathFinderDataModule

choice = "imdb"
if choice == "path_finder":
  dm = PathFinderDataModule(data_dir='./', batch_size=32, test_batch_size=32, data_type='default', num_workers=2, pin_memory=False, resolution=32,augment = False)
  dm.prepare_data()
  dm.setup()



  import matplotlib.pyplot as plt

  # prompt: Generate the code to instantiate PathFinderDataModule
  dm = PathfinderDataModule(data_dir='./', batch_size=32, test_batch_size=32)
  dm.prepare_data()
  dm.setup()
  # Fetch a sample from the training dataset
  sample_idx = 0  # Index of the sample you want to print
  sample = dm.train_dataset[sample_idx]

  # If the sample is a tuple (image, label)
  if isinstance(sample, tuple):
      image, label = sample
  else:
      image = sample
      label = None

  # Print label
  if label is not None:
      print(f"Label: {label}")

  # Print image
  # Assuming image is a PIL image or a tensor that can be converted to PIL
  if isinstance(image, torch.Tensor):
      image = transforms.ToPILImage()(image)

  plt.imshow(image)
  plt.title(f"Sample {sample_idx}")
  plt.axis('off')  # Hide axis
  plt.show()


elif choice == "listops":
  dm = ListOpsDataModule(data_dir='./', batch_size=32, test_batch_size=32, data_type='default')
  dm.prepare_data()
  dm.setup()
  # Retrieve and print a sample
  train_loader = dm.train_dataloader()

  # Get a batch of data
  for images, labels in train_loader:
      print(f"Batch of images shape: {images.shape}")
      print(f"Batch of labels: {labels}")

      # Print the first image and label in the batch
      print(f"First image tensor: {images[0]}")
      print(f"First label: {labels[0]}")

      # Break after first batch for demonstration
      break

elif choice == "imdb":
  dm = IMDBDataModule(data_dir='./', batch_size=32, test_batch_size=32, data_type='default')
  dm.prepare_data()
  dm.setup()
  # Retrieve and print a sample
  train_loader = dm.train_dataloader()

  # Get a batch of data
  for images, labels in train_loader:
      print(f"Batch of images shape: {images.shape}")
      print(f"Batch of labels: {labels}")

      # Print the first image and label in the batch
      print(f"First image tensor: {images[0]}")
      print(f"First label: {labels[0]}")

      # Break after first batch for demonstration
      break

Map (num_proc=7):   0%|          | 0/25000 [00:00<?, ? examples/s]

Map (num_proc=7):   0%|          | 0/25000 [00:00<?, ? examples/s]

Process ForkPoolWorker-15:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/multiprocess/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.10/dist-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/multiprocess/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/usr/local/lib/python3.10/dist-packages/datasets/utils/py_utils.py", line 678, in _write_generator_to_queue
    for i, result in enumerate(func(**kwargs)):
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 3581, in _map_single
    writer.write_batch(batch)
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_writer.py", line 569, in write_batch
    inferred_features[col] = typed_sequence.get_inferred_type()
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_writer.py", line 133, in get_infe

TimeoutError: 