In [1]:
from pathlib import Path

import torch
import torchvision
import matplotlib.pyplot as plt

from utils import make_from_file, retrieve_setup


  from .autonotebook import tqdm as notebook_tqdm


In [11]:
def train(model, dataset, scheduler, from_file=None, optimizer=None, loss_fn=None, n_epochs=None, batch_size=None):
    """
    Trains a PyTorch model on the "train" split of ``dataset``.

    Parameters:
    ----------
    model: torch.nn.Module
        PyTorch model. It is trained in place and the same object is returned.

    dataset: dict
        Mapping whose "train" entry is a map-style dataset accepted by
        ``torch.utils.data.DataLoader``. When ``from_file`` is given, that entry
        must also expose a ``root`` attribute (e.g. a torchvision dataset); the
        last component of that path is passed to ``make_from_file`` as the
        dataset name.

    scheduler:
        Learning-rate scheduler. NOTE: currently unused by this function; it is
        accepted only so existing call sites keep working. Pass ``None``.

    from_file: path-like, optional
        Configuration file. When not ``None``, ``make_from_file`` is called with
        it, and every option below that was left as ``None`` is filled in from
        the values it returns.

    optimizer: torch.optim.Optimizer, optional
        Optimizer algorithm; ``zero_grad()`` and ``step()`` are called per batch.

    loss_fn: callable, optional
        Loss function, called as ``loss_fn(y_pred, y_batch)``; must return a
        tensor supporting ``.backward()``.

    n_epochs: int, optional
        Number of passes over the training data.

    batch_size: int, optional
        Mini-batch size of the shuffled training DataLoader.

    Returns:
    -------
    model: torch.nn.Module
        The same ``model`` object, after training.

    Raises:
    ------
    ValueError
        If any of ``optimizer``, ``loss_fn``, ``n_epochs`` or ``batch_size`` is
        still ``None`` after the configuration file (if any) has been applied.
    """

    # default options, taken from the configuration file when one is given
    if from_file is not None:
        # Path(...).name copes with trailing separators and pathlib.Path roots,
        # where splitting the raw string on "/" would yield "" or fail.
        dataset_name = Path(dataset["train"].root).name
        file_optimizer, file_loss_fn, file_n_epochs, file_batch_size = make_from_file(
            from_file, model, dataset_name
        )
        if optimizer is None:
            optimizer = file_optimizer
        if loss_fn is None:
            loss_fn = file_loss_fn
        if n_epochs is None:
            n_epochs = file_n_epochs
        if batch_size is None:
            batch_size = file_batch_size

    # Fail early with a clear message instead of an UnboundLocalError/TypeError
    # deep inside the training loop.
    missing = [
        name
        for name, value in (
            ("optimizer", optimizer),
            ("loss_fn", loss_fn),
            ("n_epochs", n_epochs),
            ("batch_size", batch_size),
        )
        if value is None
    ]
    if missing:
        raise ValueError(
            f"No value for {', '.join(missing)}: pass them explicitly "
            "or provide a 'from_file' configuration."
        )

    # data loader
    data_loader = torch.utils.data.DataLoader(dataset["train"], batch_size, shuffle=True)

    # training
    for epoch in range(n_epochs):
        # Re-enter training mode each epoch in case evaluation code switched
        # the model to eval mode in between.
        model.train()
        for X_batch, y_batch in data_loader:
            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model

In [12]:
# retrieve_setup returns the model and the dataset dict (with a "train" split)
# that train() consumes; optimizer, loss, epochs and batch size come from
# parameters.yml. train() does not use its scheduler argument yet, so pass
# None explicitly rather than a placeholder value.
model, dataset = retrieve_setup("MobileNetV3Small", "CIFAR10")
model_trained = train(model, dataset, scheduler=None, from_file="parameters.yml")


Files already downloaded and verified
Files already downloaded and verified
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torch.Size([4])
torch.Size([4, 10]) torc

False