To implement any kind of neural network in PyTorch, we must phrase the problem as an optimization problem.

In [6]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import torch.utils.data
from tqdm.autonotebook import tqdm 
#from idlmam import * 

In [7]:
# Helper function: move an object (or everything it contains) to a compute device

def moveTo(obj, device):
    """
    Recursively move ``obj`` (or the objects it contains) to ``device``.

    obj: the python object to move to a device, or to move its contents to a
        device. Lists, tuples (including namedtuples), sets and dicts are
        traversed recursively; for dicts both keys and values are moved. Any
        other object that has a ``.to`` attribute (e.g. a tensor or an
        ``nn.Module``) has ``obj.to(device)`` called on it; everything else is
        returned unchanged.
    device: the compute device to move objects to

    Returns a new list/tuple/set/dict holding the moved contents (namedtuples
    keep their own type; other list/tuple/dict subclasses come back as the
    plain built-in type), the result of ``obj.to(device)``, or ``obj`` itself.
    """
    if isinstance(obj, list):
        return [moveTo(x, device) for x in obj]
    elif isinstance(obj, tuple):
        moved = [moveTo(x, device) for x in obj]
        # Rebuild namedtuples with their own type so field access
        # (e.g. ``batch.inputs``) still works downstream; a plain
        # ``tuple(...)`` would silently drop the field names.
        if hasattr(obj, "_fields") and hasattr(obj, "_make"):
            return obj._make(moved)
        return tuple(moved)
    elif isinstance(obj, set):
        return {moveTo(x, device) for x in obj}
    elif isinstance(obj, dict):
        return {moveTo(key, device): moveTo(value, device) for key, value in obj.items()}
    elif hasattr(obj, "to"):
        return obj.to(device)
    else:
        return obj

In [8]:
# The code for the simple training loop 

def train_simple_network(model, loss_func, training_loader, epochs=20, device="mps", lr=0.001):
    """Train ``model`` in place with plain stochastic gradient descent.

    model: the PyTorch module to train; it is moved to ``device`` and put in
        training mode.
    loss_func: callable ``loss_func(predictions, labels)`` returning a
        single-element loss tensor (``.backward()`` and ``.item()`` are called
        on it).
    training_loader: iterable yielding ``(inputs, labels)`` batches, e.g. a
        ``torch.utils.data.DataLoader``.
    epochs: number of full passes over ``training_loader``.
    device: compute device to train on. The default ``"mps"`` requires Apple's
        Metal (MPS) backend; pass ``"cpu"`` or ``"cuda"`` on other machines.
    lr: learning rate for the SGD optimizer (default 0.001, the value that was
        previously hard-coded).
    """
    # Place the model on its compute device *before* building the optimizer:
    # the PyTorch docs recommend moving a model before constructing the
    # optimizer that holds references to its parameters.
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    for epoch in tqdm(range(epochs), desc="Epochs"):  # one pass over the data per epoch
        model = model.train()  # training mode (affects layers like dropout, if present)
        running_loss = 0.0  # sum of the per-batch losses for this epoch

        for inputs, labels in tqdm(training_loader, desc="Batch", leave=False):
            # Batches may be tensors or nested containers of tensors.
            inputs = moveTo(inputs, device)
            labels = moveTo(labels, device)

            # Clear gradients accumulated by the previous batch's backward pass.
            optimizer.zero_grad()

            y_hat = model(inputs)  # forward pass: f_Theta(x_i)
            loss = loss_func(y_hat, labels)
            loss.backward()  # gradients of the loss w.r.t. the parameters
            optimizer.step()  # apply the SGD parameter update
            running_loss += loss.item()