# Utils

In [None]:
#| default_exp utils

In [None]:
#| export
def fix_notebook_widgets():
    """
    Bugfix: fastprogress bars not showing in VSCode notebooks.

    Monkey-patches ``IPython.display.DisplayHandle.update`` globally for the running
    kernel. Instead of updating the existing output in place, the patched method clears
    the current cell output (``clear_output(wait=True)``) and re-displays the object
    through the handle. Calling this function more than once is harmless: each call
    simply re-assigns the same replacement method.

    Taken from https://github.com/microsoft/vscode-jupyter/issues/13163
    """
    from IPython.display import clear_output, DisplayHandle

    def update_patch(self, obj, **kwargs):
        # Accept and forward keyword arguments (e.g. metadata) so callers using the
        # original ``DisplayHandle.update(obj, **kwargs)`` signature keep working.
        clear_output(wait=True)
        self.display(obj, **kwargs)

    DisplayHandle.update = update_patch

In [None]:
#| export
import torch
from torch import nn

from fastai.vision.all import *

class Threshold(nn.Module):
    """Two-class classifier for scalar inputs driven by a single learnable cut-off ``t``.

    The output stacks ``x - t`` and ``t - x`` along a new last axis. An argmax over
    that axis therefore yields class 0 when ``x >= t`` (a tie goes to the first column)
    and class 1 when ``x < t``.
    """
    def __init__(self) -> None:
        super().__init__()
        # The only learnable parameter: the threshold, starting at zero.
        self.t = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        shifted = x - self.t
        # Column 0 scores "at/above threshold", column 1 scores "below threshold".
        return torch.stack((shifted, shifted.neg()), dim=-1)

In [None]:
threshold = Threshold()
# In-place edits to a Parameter have to happen outside autograd tracking.
with torch.no_grad():
    threshold.t.fill_(3)

# With t = 3: inputs 0, 1, 2 are below the cut-off (class 1), inputs 3..9 are not (class 0).
expected = torch.tensor([1, 1, 1, 0, 0, 0, 0, 0, 0, 0])
test_eq(threshold(torch.arange(10)).argmax(1), expected)

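Because this model has a single parameter (the threshold `t`), we can fit it to data without iterative, gradient-based optimization: we simply try a grid of candidate thresholds and keep the one with the highest accuracy: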

In [None]:
#| export
@patch
def fit(self: Threshold, x, y, candidates=None):
    """Pick the threshold that maximizes the empirical accuracy on ``(x, y)``.

    Brute-force search with no gradient tracking. Each candidate is written into
    ``self.t`` in turn, the model is evaluated on ``x``, and ``accuracy`` is computed
    against ``y``. On return ``self.t`` holds the best candidate.

    Args:
        x: Inputs fed to the model.
        y: Targets; class 1 corresponds to inputs *below* the threshold (e.g. ``x < 3``).
        candidates: Iterable of threshold values to try. Defaults to
            ``np.arange(0.0, 4.0, 0.01)``, so only thresholds in ``[0, 4)`` are
            considered unless a custom grid is passed (needed for data on another scale).

    Returns:
        ``(threshold, accuracy)`` as Python floats. If several candidates tie for the
        best accuracy, the first one in ``candidates`` is kept.

    Raises:
        ValueError: If ``candidates`` is empty.
    """
    if candidates is None:
        candidates = np.arange(0.0, 4.0, 0.01)

    with torch.no_grad():
        def accuracy_for_threshold(t):
            self.t[0] = t
            return accuracy(self(x), y)

        best = max(((t, accuracy_for_threshold(t)) for t in candidates),
                   key=lambda p: p[1], default=None)
        if best is None:
            raise ValueError("`candidates` must contain at least one threshold")

        best_t, accuracy_score = best
        # The search loop left the *last* candidate in self.t; restore the best one.
        self.t[0] = best_t
        return self.t.item(), accuracy_score.item()


In [None]:
threshold = Threshold()
# Deterministic data containing every value 0..9 ten times. Random draws could (rarely)
# miss the value 2, in which case any threshold in (1, 3] is optimal and the fit would
# settle near 1, far from 3.
x = torch.arange(10).repeat(10)
chosen_threshold, fit_accuracy = threshold.fit(x, x < 3)

# Labels are exactly "x < 3", so a perfect split exists and must be found;
# every threshold in (2, 3] achieves it.
test_eq(fit_accuracy, 1.0)
assert 2 < chosen_threshold <= 3

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # write the `#| export` cells of this notebook out to the package