In [None]:
# Import outside dependencies
import copy

import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Import in house dependencies
from uwyo_models import ann
import uwyo_dataloader as dl
import uwyo_trainer as trainer

In [None]:
"""
    mean_std(loader)
    A function to compute the mean and standard deviation of an image dataset
    
    inputs:
     - loader (pytorch dataloader): The image dataloader to use for computations
    outputs:
     - mean (tensor): The mean along each channel of the image dataset
     - std (tensor): The standard deviation along each channel of the image dataset
"""
def mean_std(loader):
    images, labels = next(iter(loader))
    mean, std = images.mean([0,2,3]), images.std([0,2,3])
    return mean, std

In [None]:
# Rescaling parameters for input images
image_width = 40
image_height = 30

# Seed torch's global RNG before anything random happens, so that the random
# torchvision augmentations and (presumably) weight initialisation and any
# torch-based dataset splitting are repeatable under Restart & Run All.
# GPU kernels may still introduce some non-determinism.
random_seed = 0
torch.manual_seed(random_seed);

In [None]:
# Describe the network once as a keyword dictionary, then hand it to the
# in-house `ann` wrapper, which can create most model structures.
cnn_config = dict(
    model_type='cnn',
    inputs=1,                         # presumably one input channel (grayscale images) — confirm in uwyo_models
    outputs=1,
    neurons=[4],
    # NOTE(review): three activations for one conv block + one hidden layer + output? verify against uwyo_models
    activations=['relu', 'relu', 'sigmoid'],
    linear_batch_normalization=False,
    linear_dropout=None,
    cnn_type='2d',
    channels=[8],
    image_width=image_width,
    image_height=image_height,
    kernels=(11, 11),
    strides=None,
    paddings=None,
    pooling='maxpool2d',
    cnn_batch_normalization=True,
    cnn_dropout=0.1,
)

model_creator = ann(name='detector')
model_creator.create_model(**cnn_config)

cnn_model = model_creator.model

In [None]:
# Compute the mean and std of the dataset for z-normalization.
# The images are preprocessed exactly like the model's inputs (resize -> grayscale
# -> tensor, as in the train/test pipelines), so `normalize` standardizes what the
# network actually sees. Resizing also lets the default collate function batch
# images whose native sizes differ.
transform = transforms.Compose([transforms.Resize([image_height, image_width]),
                                transforms.Grayscale(),
                                transforms.ToTensor()])
# NOTE(review): batch_size=245 presumably covers the whole dataset in one batch — confirm.
loader = DataLoader(datasets.ImageFolder('Data/', transform=transform), batch_size=245)
mean, std = mean_std(loader)
normalize = transforms.Normalize(mean, std)
print(f'Mean : {mean} | STD : {std}')

In [None]:
# Create a specific training transformer to make training more complex.
# transforms.ColorJitter() with no arguments uses brightness=contrast=saturation=hue=0,
# which makes it an identity transform. Non-zero brightness/contrast ranges are given
# so the augmentation actually happens. Saturation and hue are left at 0 because the
# image has already been converted to grayscale at that point.
# The 0.2 ranges are a starting point — tune them for this dataset.
jitter_brightness = 0.2
jitter_contrast = 0.2
transform_train = transforms.Compose([transforms.Resize([image_height, image_width]),
                                      transforms.Grayscale(),
                                      transforms.ColorJitter(brightness=jitter_brightness,
                                                             contrast=jitter_contrast),
                                      transforms.RandomPerspective(),
                                      transforms.ToTensor(),
                                      normalize])

# Create a specific testing transformer to perform image conversion but without augmentation
transform_tests = transforms.Compose([transforms.Resize([image_height, image_width]),
                                      transforms.Grayscale(),
                                      transforms.ToTensor(),
                                      normalize])

In [None]:
# Load the image dataset with the in-house loader (uwyo_dataloader), using the
# augmenting transform for training and the plain transform for evaluation.
# Presumably returns train/valid/test loaders plus the class labels — see uwyo_dataloader.
batch_size = 32
load_options = {
    'path': 'Data',
    'batch_size': batch_size,
    'image_width': image_width,
    'image_height': image_height,
    'transform_train': transform_train,
    'transform_test': transform_tests,
    'valid': True,
    'display': True,
}
train, valid, tests, labels = dl.load_images(**load_options)

In [None]:
# Train the model; the trainer hands back the training history and the trained model.
num_epochs = 100          # presumably the number of epochs — confirm against uwyo_trainer.train
train_threshold = 0.85    # decision threshold forwarded as `thresh`
history, cnn_model = trainer.train(cnn_model, num_epochs, train, valid, thresh=train_threshold)

In [None]:
# Evaluate the trained model on the held-out test split.
test_threshold = 0.85     # same decision threshold as used for training
acc = trainer.test(cnn_model, tests, labels, thresh=test_threshold, verbose=1)

In [None]:
# Visualise the metrics recorded in `history` during training.
plot_training_curves = trainer.plot_history
plot_training_curves(history)

In [None]:
# Save the model as TorchScript for deployment.
# A scripted module keeps the `training` flag of the module it was scripted from.
# The model was built with dropout and batch norm, and trainer.train may hand it
# back in train mode. An eval-mode *copy* is therefore scripted, so inference with
# the saved file uses fixed batch-norm statistics and no dropout. Copying leaves
# `cnn_model` itself untouched, so re-running earlier cells stays safe.
export_model = copy.deepcopy(cnn_model).eval()
model_script = torch.jit.script(export_model)
model_script.save('line_follower.pt')