In [2]:
from helpers import train, evaluate, load_data_loaders, epoch_time
import torch
from matplotlib import pyplot as plt
from tqdm.notebook import trange
from functools import partial
import time
import torch.optim as optim
import torch.nn as nn
from models import LeNet
%matplotlib inline

Explore data

In [3]:
# # # get some random training images
# dataiter = iter(train_dataloader)
# t_x, t_y = next(dataiter)
# fig, axs = plt.subplots(2, 4, figsize = (16, 8))
# for ax in axs.flatten():
#     t_x, t_y = next(dataiter)
#     ax.imshow(t_x[0].permute(1, 2, 0))
#     ax.set_title('Severity {}'.format(int(t_y[0])))

# len(axs.flatten())

# Training

In [4]:
def main_train(epochs, learning_rate, batch_size=1, num_workers=0, config=None):
    # num_classes = 5

    model = LeNet()

    if torch.backends.mps.is_available():
        device = "mps"
        model = nn.DataParallel(model)
    else:
        device = "cpu"

    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    
    train_dataloader, valid_dataloader = load_data_loaders(batch_size=batch_size, num_workers=num_workers)

    best_valid_loss = float('inf')

    train_losses = []
    validation_losses = []
    train_accs = []
    validation_accs = []

    for epoch in trange(epochs, desc="Epochs"):

        start_time = time.monotonic()

        train_loss, train_acc = train(model, train_dataloader, optimizer, criterion, device)

        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # if epoch % 5 == 0:
        valid_loss, valid_acc = evaluate(model, valid_dataloader, criterion, device)

        validation_losses.append(valid_loss)
        validation_accs.append(valid_acc)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'conv-model.pt')

        end_time = time.monotonic()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

    # tune.report(loss=(valid_loss), accuracy=valid_acc*100)

In [5]:
learning_rate = 1e-3
EPOCHS = 1
main_train(EPOCHS, learning_rate, batch_size=32, num_workers=8)

8408 images found of 35126 total


Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

Training:   0%|          | 0/204 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/73 [00:00<?, ?it/s]

Epoch: 01 | Epoch Time: 4m 33s
	Train Loss: inf | Train Acc: 70.74%
	 Val. Loss: 0.963 |  Val. Acc: 69.44%


PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)