In [None]:
%pip install torch
%pip install pytorch-lightning

In [2]:
import random
import math
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

In [None]:
# Base Data Module for PyTorch Lightning

class BaseDataModule(pl.LightningDataModule):
    """Lightning data module over an in-memory dataset with a train/val split.

    The dataset is produced by :meth:`get_dataset`, shuffled once at
    construction time, and then cut into a leading training part and a
    trailing validation part. Subclasses only need to override
    :meth:`get_dataset`.
    """

    def __init__(self, batch_size=32, split=0.8, *args, **kwargs):
        """Build, shuffle and split the dataset.

        Args:
            batch_size: Batch size passed to both data loaders.
            split: Fraction of the samples that goes to the training set;
                the remaining samples form the validation set.
            *args: Forwarded unchanged to :meth:`get_dataset`.
            **kwargs: Forwarded unchanged to :meth:`get_dataset`.
        """
        super().__init__()

        self.ds_x, self.ds_y = self.get_dataset(*args, **kwargs)

        # One shared permutation keeps every input aligned with its label
        # while randomizing which samples land in train vs. validation.
        order = np.random.permutation(self.ds_x.shape[0])
        self.ds_x = self.ds_x[order]
        self.ds_y = self.ds_y[order]

        self.batch_size = batch_size

        # From here on ``self.split`` is the integer index of the first
        # validation sample, not the fraction that was passed in.
        self.split = int(self.ds_x.shape[0] * split)

    def get_dataset(self, *args, **kwargs):
        """Return the ``(inputs, targets)`` pair for this data module.

        Both values must expose ``shape[0]`` and accept indexing with an
        integer permutation array, as ``__init__`` relies on both.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_dataset()"
        )

    def _paired_loader(self, xs, ys):
        """Wrap aligned inputs and targets in a DataLoader of (x, y) samples."""
        return torch.utils.data.DataLoader(
            list(zip(xs, ys)), batch_size=self.batch_size
        )

    def train_dataloader(self):
        """Return a DataLoader over the samples before the split index."""
        return self._paired_loader(
            self.ds_x[0:self.split], self.ds_y[0:self.split]
        )

    def val_dataloader(self):
        """Return a DataLoader over the samples from the split index onward."""
        return self._paired_loader(
            self.ds_x[self.split:], self.ds_y[self.split:]
        )
    
class ReverseDataModule(BaseDataModule):
    """Synthetic task: map a sequence of digits to the same digits reversed."""

    def get_dataset(self, cnt=1000, seq_len=6):
        """Generate random digit sequences together with their reversals.

        Args:
            cnt: Number of sequences to generate.
            seq_len: Number of digits in each sequence.

        Returns:
            A pair ``(sequences, reversed_sequences)`` of integer arrays,
            each of shape ``(cnt, seq_len)`` with values in 0-9.
        """
        sequences = np.random.randint(0, 10, size=(cnt, seq_len))

        # A negative-stride slice is only a view on ``sequences``; copying it
        # yields an independent array of shape (cnt, seq_len) laid out in
        # normal row-major order.
        targets = sequences[:, ::-1].copy()
        return sequences, targets

# dataset idea from https://github.com/karpathy/minGPT/blob/master/play_math.ipynb

class AdditionDataModule(BaseDataModule):
    """Two-digit addition expressed as a digit-sequence prediction task."""

    def get_dataset(self):
        """Enumerate every ``i + j`` for ``0 <= i, j < 100`` as digit tokens.

        Each problem is encoded as the 7-token row
        ``[i tens, i ones, j tens, j ones, sum hundreds, sum tens, sum ones]``.

        Returns:
            A pair ``(inputs, targets)`` of integer arrays, both of shape
            ``(10000, 6)``. ``inputs`` holds tokens 0-5 of each row and
            ``targets`` holds tokens 1-6, so ``targets[:, k]`` is the token
            that follows ``inputs[:, k]`` in the row.
        """
        rows = []
        for i in range(100):
            for j in range(100):
                total = i + j
                rows.append([
                    i // 10, i % 10,
                    j // 10, j % 10,
                    total // 100, (total // 10) % 10, total % 10,
                ])

        ds = np.array(rows)

        # Targets are the row shifted left by one token. Using the single
        # column ds[:, 1] would make the label the ones digit of i, which is
        # already input column 1, so the sum digits would never be predicted.
        return ds[:, 0:6], np.copy(ds[:, 1:])