In [6]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger

Debugging neural networks can be very difficult. PyTorch Lightning provides a few tools to help with this.

Some useful ones (among others) that we will explore are the following `Trainer` flags:
 - `fast_dev_run=True` (this will not save checkpoints or log anything; instead, it runs a single batch through training, validation and test, "touching" every part of your code just to make sure it runs end to end)
 - `overfit_batches=int|float` (this trains on the same small subset of your training data over and over in an attempt to overfit it. If the model is not able to overfit a single batch of data, there is likely a problem with the model or the training loop.) Set it to an integer (e.g. `1`) to use that many batches, or to a float (e.g. `0.1`) to use that fraction of the training data (here 10%).
 - `num_sanity_val_steps=int` (this runs the validation loop on the first n batches of the validation set before training starts. This is useful to make sure the validation loop works as expected, rather than discovering a bug only after a full training epoch)

Set a breakpoint in the code using the following command (in Python 3.7+ the built-in `breakpoint()` does the same thing):
    import pdb; pdb.set_trace()

In [7]:
import pdb

In [12]:
def generate_cont_xor_data(num_points, generator=None):
    """Generate a random XOR dataset with two continuous features.

    Each point is drawn uniformly from [0, 1) in both coordinates. A
    coordinate counts as "on" when it exceeds 0.5, and the label is the XOR
    of the two resulting booleans: 1 when exactly one coordinate is above
    0.5, otherwise 0.

    Args:
        num_points: Number of samples to generate.
        generator: Optional ``torch.Generator`` used for sampling. Pass a
            seeded generator to make the dataset reproducible; ``None``
            (the default) keeps the original behaviour of using PyTorch's
            global RNG.

    Returns:
        A tuple ``(x, y)`` where ``x`` is a float tensor of shape
        ``(num_points, 2)`` and ``y`` is a ``long`` tensor of shape
        ``(num_points,)`` holding 0/1 labels.
    """
    x = torch.rand(num_points, 2, generator=generator)
    # Threshold each coordinate at 0.5, then XOR them to get the class label.
    y = torch.logical_xor(x[:, 0] > 0.5, x[:, 1] > 0.5).long()
    return x, y

# Quick sanity check of the generator before wrapping it in a Dataset:
# 1000 points with 2 features each, and one 0/1 label per point.
x, y = generate_cont_xor_data(1000)
assert x.shape == (1000, 2) and y.shape == (1000,)
assert set(y.unique().tolist()) <= {0, 1}

class XORDataset(Dataset):
    """Map-style dataset of continuous-XOR samples.

    All points and labels are generated once, in ``__init__``, by
    ``generate_cont_xor_data`` and kept in memory as ``self.x`` / ``self.y``.
    Each item is a dict with a ``"point"`` entry and a ``"label"`` entry.
    """

    def __init__(self, num_points):
        # Draw the whole dataset up front; indexing below just slices it.
        points, labels = generate_cont_xor_data(num_points)
        self.x = points
        self.y = labels

    def __getitem__(self, index):
        # Returning a dict lets the DataLoader batch each field by key.
        point = self.x[index]
        label = self.y[index]
        return dict(point=point, label=label)

    def __len__(self):
        # One sample per generated point.
        return len(self.x)
    
# Peek at a single batch to confirm the shapes the DataLoader produces.
loader = DataLoader(XORDataset(10000), batch_size=32)
# next(iter(...)) fetches exactly one batch (a dict of batched tensors) and
# raises StopIteration instead of silently printing nothing if it is empty.
item = next(iter(loader))
points, labels = item["point"], item["label"]
print(points.shape, labels.shape)

torch.Size([32, 2]) torch.Size([32])


In [None]:
# Let's create a simple model so we can see the debugging process in action.

class XORModel(pl.LightningModule):
    """This model predicts the XOR of two real inputs (continuous XOR)"""