# Various templates for PyTorch applications

## Data prep

### Dataset
[Docs](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)

In [None]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Template map-style dataset pairing inputs with targets.

    Replace the storage/loading logic with whatever your data needs; a
    DataLoader only relies on ``__getitem__`` and ``__len__``.

    Args:
        inputs: Indexable, sized collection of model inputs (e.g. a tensor
            or a list). Defaults to empty so ``CustomDataset()`` still works.
        targets: Indexable, sized collection of targets, same length as
            ``inputs``.

    Raises:
        ValueError: if ``inputs`` and ``targets`` differ in length.
    """

    def __init__(self, inputs=(), targets=()):
        if len(inputs) != len(targets):
            raise ValueError(
                f"inputs and targets differ in length: {len(inputs)} != {len(targets)}"
            )
        self.inputs = inputs
        self.targets = targets
        # Kept as an attribute because the original template read `self.len`.
        self.len = len(inputs)

    def __getitem__(self, idx):
        # One (input, target) sample for index `idx`.
        return self.inputs[idx], self.targets[idx]

    def __len__(self):
        return self.len

### Dataloader
[Docs](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)

In [None]:
from torch.utils.data import DataLoader

# Shuffled training loader over CustomDataset. `bs` (batch size) is a
# placeholder that must be defined before this cell runs.
train_loader = DataLoader(
    dataset=CustomDataset(),
    batch_size=bs,
    shuffle=True,
    num_workers=4,  # number of worker subprocesses used for loading
    pin_memory=True,  # copy batches into pinned (page-locked) host memory
)

## Network

### nn.Sequential
For simple models where forward() is a sequence of `x = self.layer(x)` calls
<br>
[Docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.Sequential.html)

In [None]:
from torch import nn

class Model(nn.Sequential):
    """Template for a purely sequential network.

    Each ``nn.layer(params)`` entry is a placeholder for a concrete layer;
    replace them before instantiating the model. Layers run in the order
    listed.
    """

    def __init__(self):
        layers = (
            nn.layer(params),
            nn.layer(params),
            nn.layer(params),
        )
        super().__init__(*layers)

### nn.Module
For complex models whose forward() cannot be expressed as a plain sequence of layers (e.g. skip connections, branches, or multiple inputs)
<br>
[Demo](https://docs.pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html)
<br>
[Docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html)

In [None]:
from torch import nn

class Model(nn.Module):
    """Template for a network whose forward pass needs custom wiring.

    Register submodules in ``__init__`` and combine them in ``forward``.
    As written it has no layers and returns its input unchanged.
    """

    def __init__(self):
        super().__init__()
        # Register layers here, e.g. self.fc = nn.Linear(...).

    def forward(self, x):
        # Identity placeholder: replace with the real computation.
        out = x
        return out

## Training

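### Learning-rate schedules
Attach only one scheduler to a given optimizer: schedulers constructed on the same optimizer modify its learning rate and interfere with each other.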

In [None]:
import torch

# Attach exactly ONE scheduler to an optimizer. Building several on the same
# optimizer is not harmless: OneCycleLR rewrites each param group's
# `initial_lr`/`lr` when it is constructed, so a scheduler created after it
# would start from OneCycleLR's warm-up LR instead of the optimizer's LR.
# `optim`, `num_epochs` and `train_loader` are placeholders defined elsewhere.
schedule_name = "constant"  # one of: "one_cycle", "constant"

if schedule_name == "one_cycle":
    # Stepped once per batch, so it needs the number of batches per epoch.
    schedule = torch.optim.lr_scheduler.OneCycleLR(
        optim, steps_per_epoch=len(train_loader), epochs=num_epochs,
        max_lr=1e-3
    )
elif schedule_name == "constant":
    # factor=1.0 leaves the optimizer's LR unchanged.
    schedule = torch.optim.lr_scheduler.ConstantLR(
        optim,
        factor=1.0
    )
else:
    raise ValueError(f"Unknown schedule_name: {schedule_name!r}")

### Simple Training Loop with fastprogress Output Tables

In [None]:
from fastprogress.fastprogress import master_bar, progress_bar

def train(model, dataloader, optimizer, loss_fn, schedule, clip_grads=None):
    """Train ``model`` for ``num_epochs`` epochs, printing a per-epoch loss table.

    Reads two notebook-level globals: ``num_epochs`` (number of epochs) and
    ``device`` (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"`` or a ``torch.device``).
    The forward pass runs under ``torch.autocast``; on CUDA the loss is scaled
    with a GradScaler so low-precision gradients do not underflow.

    Args:
        model: Module to train; assumed to already live on ``device``.
        dataloader: Iterable yielding ``(x, y)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar loss.
        schedule: LR scheduler, stepped once per batch.
        clip_grads: If not None, max gradient norm passed to ``clip_grad_norm_``.

    A KeyboardInterrupt (e.g. interrupting the kernel) stops training early
    and prints a message instead of raising.
    """
    import torch

    # autocast expects a bare device type ("cuda", "cpu"), not "cuda:0" or a torch.device.
    device_type = torch.device(device).type
    # Loss scaling is only needed for CUDA fp16 autocast; when disabled the scaler is a pass-through.
    scaler = torch.amp.GradScaler(device_type, enabled=device_type == "cuda")

    model.train()
    mb = master_bar(range(num_epochs))
    mb.write(["epoch", "train loss"], table=True)
    try:
        for epoch in mb:
            total_loss = 0.0
            n_batches = 0
            for x, y in progress_bar(dataloader, parent=mb):

                x, y = x.to(device), y.to(device)

                # Forward
                with torch.autocast(device_type):
                    z = model(x)
                    loss = loss_fn(z, y)

                # Backward
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                if clip_grads is not None:
                    # Unscale first so the clipping threshold applies to the true gradients.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grads)
                scaler.step(optimizer)
                scaler.update()
                schedule.step()

                # Accumulate as a detached fp32 tensor to avoid a host sync every batch.
                total_loss += loss.detach().float()
                n_batches += 1

            mean_loss = float(total_loss) / max(n_batches, 1)
            mb.write([f"{epoch:<5}", f"{mean_loss:.4f}"], table=True)

    except KeyboardInterrupt:
        print("Training cancelled. Exiting...")