In [5]:
import gzip
import os
import pickle
import random
import re
import time
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import wget
from torch.utils.data import DataLoader, Dataset

# Seed Python's `random` module and torch's default generator so a fresh
# Restart & Run All reproduces the same results.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' suppresses the returned torch.Generator repr

<torch._C.Generator at 0x10cdcf770>

In [6]:
IMDB_URL = 'http://dlvu.github.io/data/imdb.{}.pkl.gz'
IMDB_FILE = 'imdb.{}.pkl.gz'

# Special-token strings of the vocabulary; UNK replaces indices dropped by `voc` truncation.
PAD, START, END, UNK = '.pad', '.start', '.end', '.unk'

def load_imdb(final=False, val=5000, seed=0, voc=None, char=False):
    """
    Load the IMDb classification data, downloading it to the working directory if absent.

    The pickle is expected to hold ``(sequences, labels, i2w, w2i)``, where
    ``sequences``/``labels`` are dicts with 'train' and 'test' entries.

    Args:
        final: If True, return the train and test splits. Otherwise hold out ``val``
            randomly chosen training examples as a validation split and return
            (train, validation) instead; the test split is not returned.
        val: Number of training examples moved to the validation split
            (unused when ``final`` is True).
        seed: Seed for choosing the validation examples. A private ``random.Random``
            is used, so the global ``random`` module state is not modified.
        voc: If given and smaller than the vocabulary, keep only the first ``voc``
            entries of ``i2w`` and map every index ``>= voc`` to the UNK index.
        char: Load the character-level file instead of the word-level one.

    Returns:
        ``(x_a, y_a), (x_b, y_b), (i2w, w2i), 2`` where ``(x_b, y_b)`` is the test
        split if ``final`` else the validation split, and 2 is the class count.

    Raises:
        KeyError: If ``voc`` truncation removes the UNK token from the vocabulary.
        ValueError: If ``val`` is negative or larger than the training set.
    """
    cst = 'char' if char else 'word'

    imdb_url = IMDB_URL.format(cst)
    imdb_file = IMDB_FILE.format(cst)

    if not os.path.exists(imdb_file):
        # Pin the output path so the download lands exactly where it is read below.
        wget.download(imdb_url, out=imdb_file)

    # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
    with gzip.open(imdb_file) as file:
        sequences, labels, i2w, w2i = pickle.load(file)

    if voc is not None and voc < len(i2w):
        i2w = i2w[:voc]
        w2i = {w: i for i, w in enumerate(i2w)}

        # Indices at or beyond `voc` no longer exist in the truncated vocabulary.
        unk = w2i[UNK]
        sequences = {
            key: [[s if s < voc else unk for s in seq] for seq in seqs]
            for key, seqs in sequences.items()
        }

    if final:
        return (sequences['train'], labels['train']), (sequences['test'], labels['test']), (i2w, w2i), 2

    # Make a validation split. A private RNG yields the same sample as seeding the
    # global one would, without resetting the caller's global `random` state.
    rng = random.Random(seed)
    val_ind = set(rng.sample(range(len(sequences['train'])), k=val))

    x_train, y_train = [], []
    x_val, y_val = [], []
    for i, (s, l) in enumerate(zip(sequences['train'], labels['train'])):
        if i in val_ind:
            x_val.append(s)
            y_val.append(l)
        else:
            x_train.append(s)
            y_train.append(l)

    return (x_train, y_train), (x_val, y_val), (i2w, w2i), 2

In [7]:
# Word-level data with a held-out validation split (load_imdb's default `val` size).
(
    (x_train, y_train),
    (x_val, y_val),
    (i2w, w2i),
    numcls,
) = load_imdb(final=False)

In [8]:
def pad_and_convert(sequences: List[List[int]], w2i: Dict[str, int],
                    max_length: Optional[int] = None) -> torch.Tensor:
    """
    Right-pad a batch of index sequences to a common length and stack them into a tensor.

    Args:
        sequences (List[List[int]]): A batch of sequences, each a list (or tuple) of
            integer token indices.
        w2i (Dict[str, int]): Vocabulary mapping tokens to indices; must contain '.pad'.
        max_length (Optional[int]): Length to pad every sequence to. If None, the length
            of the longest sequence in the batch is used (0 for an empty batch).

    Returns:
        torch.Tensor: A ``torch.long`` tensor of shape (len(sequences), max_length).
        An empty batch yields a (0, max_length) tensor rather than a 1-D one.

    Raises:
        ValueError: If '.pad' is missing from ``w2i`` or a sequence is longer than
            ``max_length``.
    """
    pad_idx = w2i.get('.pad')
    if pad_idx is None:
        raise ValueError("The padding token '.pad' is not found in the w2i dictionary.")

    if max_length is None:
        # default=0 keeps an empty batch from raising inside max().
        max_length = max((len(seq) for seq in sequences), default=0)

    if not sequences:
        # torch.tensor([]) would be 1-D; build the documented 2-D shape explicitly.
        return torch.empty((0, max_length), dtype=torch.long)

    padded_sequences = []
    for seq in sequences:
        padding_needed = max_length - len(seq)
        if padding_needed < 0:
            raise ValueError("A sequence is longer than the specified max_length.")
        # list(...) so tuple inputs concatenate correctly with the padding list.
        padded_sequences.append(list(seq) + [pad_idx] * padding_needed)

    return torch.tensor(padded_sequences, dtype=torch.long)

def create_batches(sequences: List[List[int]], labels: List[int],
                   batch_size: int, w2i: Dict[str, int]) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """
    Split the data into consecutive batches, padding each batch independently.

    Batches keep the input order (no shuffling); the final batch holds the remainder
    and may be smaller than ``batch_size``. Each batch is padded only to the length of
    its own longest sequence (``pad_and_convert`` is called without ``max_length``).

    Args:
        sequences (List[List[int]]): List of all sequences.
        labels (List[int]): Corresponding label for each sequence.
        batch_size (int): Number of samples per batch; must be >= 1.
        w2i (Dict[str, int]): Dictionary mapping words to their integer indices.

    Returns:
        List[Tuple[torch.Tensor, torch.Tensor]]: One ``(padded_sequences, labels)``
        pair of ``torch.long`` tensors per batch; empty list for empty input.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``sequences`` and
            ``labels`` differ in length (they would otherwise be silently misaligned).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")
    if len(sequences) != len(labels):
        raise ValueError(
            f"sequences and labels must have the same length "
            f"({len(sequences)} != {len(labels)})."
        )

    batches = []
    for start_idx in range(0, len(sequences), batch_size):
        end_idx = start_idx + batch_size  # slicing clamps the final, partial batch
        padded_sequences = pad_and_convert(sequences[start_idx:end_idx], w2i)
        labels_tensor = torch.tensor(labels[start_idx:end_idx], dtype=torch.long)
        batches.append((padded_sequences, labels_tensor))

    return batches

In [9]:
# Pre-build padded (inputs, labels) tensor batches for the training and validation splits.
batch_size = 64
train_batches, val_batches = (
    create_batches(split_x, split_y, batch_size, w2i)
    for split_x, split_y in ((x_train, y_train), (x_val, y_val))
)