In [None]:
# default_exp vision.metrics

# Vision.Metrics

> Metrics for tracking the progress of self-supervised pretraining. Given a labeled validation set, they estimate the quality of the learned representations during training.

In [None]:
#export
from fastai.vision.all import *

In [None]:
#export
class KNNProxyMetric(Callback):
    """k-NN@1 accuracy on the validation set, as a proxy for representation quality.

    During each validation pass:
      - the raw inputs of every batch are run through `self.model.encoder`;
      - the resulting embeddings are buffered together with the labels.

    `accuracy()` then:
      - L2-normalizes the embeddings;
      - finds, for every sample, the most cosine-similar *other* sample;
      - returns the fraction of samples whose neighbour has the same label.

    `accuracy` takes no arguments, so it is presumably exposed as a
    pre-computed metric value (e.g. fastai's `ValueMetric`) — TODO confirm
    against usage.
    Assumes `model.encoder` returns (batch, dim) embeddings and that `y` holds
    one class label per sample — verify for your setup.
    """
    # NOTE(review): order=8 presumably makes `before_batch` run before the
    # self-supervised callback that replaces `x` with augmented views; confirm.
    order,run_train,run_valid=8,False,True

    def before_batch(self):
        "Snapshot the raw validation batch before later callbacks modify it."
        self.orig_x, self.orig_y = self.x, self.y

    def before_validate(self):
        "Start a fresh embedding/label buffer for this validation pass."
        # Chunks are concatenated once in `accuracy`. Concatenating onto a
        # growing tensor every batch would copy O(n^2) data over an epoch.
        self._emb_chunks, self._targ_chunks = [], []

    def after_pred(self):
        "Embed the raw batch with the encoder and buffer embeddings and labels."
        with torch.no_grad():
            embs = self.model.encoder(self.orig_x)
        self._emb_chunks.append(embs)
        self._targ_chunks.append(self.orig_y)

    def accuracy(self):
        """Return k-NN@1 accuracy over the buffered validation samples.

        Also stores the L2-normalized embeddings in `self.embs` and the labels
        in `self.targs`. Returns a NaN tensor when fewer than two samples were
        collected, since no sample then has a neighbour to compare against.
        """
        emb_chunks = self.__dict__.get('_emb_chunks', [])
        if not emb_chunks: return torch.tensor(float('nan'))
        self.embs = F.normalize(torch.cat(emb_chunks), dim=1)
        self.targs = torch.cat(self._targ_chunks)
        if len(self.embs) < 2: return torch.tensor(float('nan'))
        sim = self.embs @ self.embs.T
        # Each sample is maximally similar to itself (cosine 1.0). Masking the
        # diagonal makes the argmax the nearest *other* sample, even under ties.
        sim.fill_diagonal_(float('-inf'))
        nearest_neighbor = sim.argmax(dim=1)
        return (self.targs == self.targs[nearest_neighbor]).float().mean()

    def after_fit(self):
        "Drop buffered tensors and release cached GPU memory."
        # Any of these may be absent (e.g. no validation ran, or `accuracy` was
        # never called), so only delete what exists.
        for name in ('_emb_chunks', '_targ_chunks', 'embs', 'targs', 'orig_x', 'orig_y'):
            if name in self.__dict__: delattr(self, name)
        torch.cuda.empty_cache()

## Export -

In [None]:
#hide
from nbdev.export import notebook2script
# Convert the project's notebooks into library modules (the "Converted ..." lines below are its output).
notebook2script()

Converted 01 - augmentations.ipynb.
Converted 02 - layers.ipynb.
Converted 03 - distributed.ipynb.
Converted 10 - simclr.ipynb.
Converted 11 - moco.ipynb.
Converted 12 - byol.ipynb.
Converted 13 - swav.ipynb.
Converted 14 - barlow_twins.ipynb.
Converted 20 - clip.ipynb.
Converted 21 - clip-moco.ipynb.
Converted 70 - vision.metrics.ipynb.
Converted index.ipynb.
