Resources:

- [fastai/fastai/learner.py at master · fastai/fastai](https://github.com/fastai/fastai/blob/master/fastai/learner.py)
- [fastai - Callbacks](https://docs.fast.ai/callback.core.html#trainevalcallback)
- [fastai - Learner, Metrics, Callbacks](https://docs.fast.ai/learner.html#learner)

TODO:

- try with `opt=None` for save/load (pass to both save and load)
- with and without: `Learner.model.eval()`
- with and without: `Learner.model.cuda()` or `.cpu()`

Call order of `.predict`:

1. creates `dl` using `self.dls.test_dl`
2. calls `self.get_preds`
3. calls `self._do_epoch_validate`
4. sets `self.dl` and calls `self.all_batches` with `torch.no_grad()`
5. calls `self.one_batch` with data from `self.dl`
6. calls `self._do_one_batch`
7. calls `self.model(self.xb)`

Call order for training:

1. call `learn.fit`
2. `self.fit`
3. calls `self._do_fit`
4. calls `self._do_epoch` for each epoch
5. `self._do_epoch` calls `self._do_epoch_train` (which sets `self.dl` and calls `self.all_batches`) and then `self._do_epoch_validate`
6. from `self.all_batches` onward, the sequence is the same as in `.predict`

`get_preds` always runs under `torch.no_grad()` (via `_do_epoch_validate`, see step 4 above)

To consider:

- `Learner.validate`
- `Learner.one_batch` (set train mode with `TrainEvalCallback.training = ...`)
- `Learner.all_batches`
- manually handling transforms (probably never worth the effort)

In [None]:
from functools import partial
from math import radians
from pathlib import Path

from fastai.vision.all import *
from torch import cuda
from tqdm.notebook import tqdm

GPU_INDEX = 2  # GPU to run on; adjust for the machine at hand

# Only select the GPU if it actually exists; otherwise `set_device` would
# raise and abort the whole notebook.
if cuda.is_available() and GPU_INDEX < cuda.device_count():
    cuda.set_device(GPU_INDEX)
    print("Running on GPU:", cuda.current_device())
else:
    print(f"GPU {GPU_INDEX} not available; using the default device")

In [None]:
def print_accuracy_using_predict(learn, image_filenames, label_func=None):
    """Print accuracy from calling ``learn.predict`` on each image separately.

    Goes through the full fastai ``predict`` pipeline, one image per call.

    Args:
        learn: a trained fastai ``Learner``.
        image_filenames: paths of the images to evaluate.
        label_func: maps a filename to its ground-truth label. Defaults to
            ``y_from_filename`` with the notebook-level ``rotation_threshold``.

    An empty ``image_filenames`` prints ``nan%`` instead of raising.
    """
    if label_func is None:
        # Resolved at call time, so `rotation_threshold` must already be defined.
        label_func = partial(y_from_filename, rotation_threshold)

    num_images = len(image_filenames)
    num_correct = 0

    for image_filename in tqdm(image_filenames):
        # Suppress the progress bar / logging that each `predict` call would emit.
        with learn.no_bar(), learn.no_logging():
            prediction, _, _ = learn.predict(image_filename)
        num_correct += int(prediction == label_func(image_filename))

    accuracy = num_correct / num_images if num_images else float("nan")
    print(f"Accuracy using `.predict`: {accuracy:.2%}")


def print_accuracy_using_get_preds(learn, image_filenames):
    dl_test = learn.dls.test_dl(image_filenames, with_labels=True)
    # with learn.no_bar(), learn.no_logging():
    _, predictions, targets = learn.get_preds(dl=dl_test, with_decoded=True)
    accuracy = (predictions == targets).float().mean()
    print(f"Accuracy using `.dls.test_dl` and `.get_preds`: {accuracy:.2%}")


def print_accuracy_using_model(learn, image_filenames, device="cpu"):
    """Print accuracy from a manual forward pass through ``learn.model``.

    Batches come from ``learn.dls.test_dl`` (so the learner's transforms still
    apply). The forward pass, softmax and argmax decoding are done by hand
    instead of via ``predict``/``get_preds``.

    Args:
        learn: a trained fastai ``Learner``.
        image_filenames: paths of the images to evaluate; labels are derived
            by the learner's labelling function (``with_labels=True``).
        device: device to run inference on. Defaults to ``"cpu"``.

    The model's original device and train/eval mode are restored afterwards,
    so later calls on ``learn`` find the model where it was before.
    An empty ``image_filenames`` prints ``nan%`` instead of raising.
    """
    num_images = len(image_filenames)
    num_correct = 0
    # Class axis of the model output, as used by the learner's loss function.
    axis = learn.loss_func.axis

    dl = learn.dls.test_dl(image_filenames, with_labels=True)
    model = learn.model
    first_param = next(iter(model.parameters()), None)
    original_device = first_param.device if first_param is not None else None
    was_training = model.training

    model.eval()
    model.to(device)
    try:
        with torch.no_grad():
            for image_batch, label_batch in tqdm(dl):
                image_batch = image_batch.to(device)
                label_batch = label_batch.to(device)
                output_batch = model(image_batch)
                # Activate and decode along the same (class) axis.
                predicted = F.softmax(output_batch, dim=axis).argmax(dim=axis)
                num_correct += (predicted == label_batch).sum().item()
    finally:
        # Undo the side effects on the learner's shared model.
        if original_device is not None:
            model.to(original_device)
        model.train(was_training)

    accuracy = num_correct / num_images if num_images else float("nan")
    print(f"Accuracy using `.model`: {accuracy:.2%}")


def print_accuracy(learn, image_filenames):
    """Report accuracy on ``image_filenames`` through three inference paths.

    In order: per-image ``Learner.predict``, batched ``test_dl`` +
    ``get_preds``, and a manual forward pass through ``learn.model``.
    """
    reporters = (
        print_accuracy_using_predict,
        print_accuracy_using_get_preds,
        print_accuracy_using_model,
    )
    for report in reporters:
        report(learn, image_filenames)


def y_from_filename(rotation_threshold, filename) -> str:
    """Map an image filename to its direction label.

    The third underscore-separated field of the file stem encodes the angle,
    with ``p`` standing in for the decimal point. The angle is compared in the
    same units as ``rotation_threshold``:

    - strictly above ``rotation_threshold`` -> ``"left"``
    - strictly below ``-rotation_threshold`` -> ``"right"``
    - otherwise (boundaries included) -> ``"forward"``

    Example (with a small threshold such as ``radians(5)``):
        "path/to/file/001_000011_-1p50.png" --> "right"
    """
    angle_field = Path(filename).stem.split("_")[2]
    angle = float(angle_field.replace("p", "."))

    if angle > rotation_threshold:
        return "left"
    if angle < -rotation_threshold:
        return "right"
    return "forward"

In [None]:
data_path = Path("./data/")

# `get_image_files` returns a fastcore `L` (list-like), not a plain `list`.
image_filenames = get_image_files(data_path)
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not image_filenames:
    raise FileNotFoundError(f"No image files found under {data_path.resolve()}")

# Angles within +/- 5 degrees (expressed in radians) are labelled "forward".
rotation_threshold = radians(5)

label_func = partial(y_from_filename, rotation_threshold)

# NOTE: no `seed` is passed, so the random 80/20 train/valid split can differ between runs.
dls = ImageDataLoaders.from_name_func(
    data_path,
    image_filenames,
    label_func,
    valid_pct=0.2,
    shuffle=True,
    bs=64,
    item_tfms=Resize(224),
)

In [None]:
# ResNet-18 learner (fastai's `vision_learner` uses pretrained weights by default),
# trained for two epochs with the one-cycle schedule.
learn = vision_learner(dls, arch=resnet18, metrics=[accuracy])
learn.fit_one_cycle(n_epoch=2)

In [None]:
by_save_model_function_filename = "_by_save_model_function"


# Saves model state_dict [and optimizer state_dict] using torch.save
save_model(by_save_model_function_filename, learn.model, learn.opt)

# NOTE: requires dls
learn_by_save_model_function = vision_learner(dls, resnet18, metrics=accuracy)
# The optimizer state must go into the *new* learner's optimizer; create it
# first if it does not exist yet (mirrors what `Learner.load` does).
if learn_by_save_model_function.opt is None:
    learn_by_save_model_function.create_opt()

# Loads model state_dict [and optimizer state_dict] using torch.load
# and model.load_state_dict [and optimizer.load_state_dict]
load_model(
    by_save_model_function_filename,
    learn_by_save_model_function.model,
    learn_by_save_model_function.opt,
)

print_accuracy(learn_by_save_model_function, image_filenames)

In [None]:
by_save_method_filename = "_by_save_method"
learn.save(by_save_method_filename)

# NOTE: requires dls (probably good to use training dls for built-in transforms).
# `Learner.load` returns the learner itself, so construction and loading chain.
learn_by_save_method = vision_learner(dls, resnet18, metrics=accuracy).load(
    by_save_method_filename
)

print_accuracy(learn_by_save_method, image_filenames)

In [None]:
by_export_method_filename = "_by_export_method"
learn.export(by_export_method_filename)

# `export` writes relative to `learn.path` (the DataLoaders' root, `data_path`
# here), so the file is read back from `data_path`.
# `load_learner` unpickles the file: only load exports you trust.
export_file = data_path / by_export_method_filename
learn_by_export_method = load_learner(export_file)

print_accuracy(learn_by_export_method, image_filenames)

In [None]:
by_torch_save_function_filename = "_by_torch_save_function"

torch.save(learn.model.state_dict(), by_torch_save_function_filename)

# NOTE: requires dls (probably good to use training dls for built-in transforms)
learn_by_torch_save_function = vision_learner(dls, resnet18, metrics=accuracy)
# Map the saved tensors to CPU: `load_state_dict` copies them into the model's
# parameters wherever those live, and loading no longer requires the GPU the
# checkpoint was saved from.
state_dict = torch.load(by_torch_save_function_filename, map_location="cpu")
learn_by_torch_save_function.model.load_state_dict(state_dict)

print_accuracy(learn_by_torch_save_function, image_filenames)