# [Do Deep Nets Really Need to be Deep?](https://arxiv.org/pdf/1312.6184)

## Intro

* A small net can mimic the function learned by a complex model, so that function was never too complex for the small net to learn
    * Showing that shallow models can learn the same function as deep nets debunks the myth that the function learned by a deep net has to be deep. <br><br>
* It is better to train the student model on logits than on softmax probabilities: different logits can map to the same distribution under softmax, so the probabilities lose information the complex model learned
    * Also, softmax can produce a few values that are large relative to the rest, which causes cross entropy to focus on them and ignore the others <br><br>
* Shallow, wide models are slow to train because their large weight matrix contains many highly correlated parameters, so gradient descent converges slowly
    * One remedy is a linear bottleneck (weight matrix factorization), which only speeds up training and doesn't increase representational power <br><br>
* Training on the teacher model's predictions can work better than training directly on the original dataset because:
    * The teacher model can eliminate label errors by predicting those examples correctly (thanks to generalization)
    * The teacher model's soft targets carry more information than hard targets, such as which classes are confusable <br><br>
* Model compression works best when the unlabeled data set is much larger than the train set (to reduce the gap between teacher and student) and when the unlabeled examples aren't training points (the teacher model is more likely to have overfit these)
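
The points above on logit targets and the linear bottleneck can be sketched in a few lines of PyTorch. This is a minimal illustration, not the paper's setup: the layer sizes, the randomly initialized stand-in `teacher`, the random "unlabeled" inputs, and the Adam optimizer are all assumptions made for the example. The student regresses onto the teacher's logits with squared error (`nn.MSELoss`, which averages over all elements), and its input weight matrix is factorized into two linear maps with no non-linearity in between.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
D, H, K, C = 20, 256, 16, 3  # input dim, hidden units, bottleneck width, classes

# Stand-in teacher: in practice this would be a trained deep model (or ensemble).
teacher = nn.Sequential(
    nn.Linear(D, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, C),
)

# Shallow student with a single non-linear layer. The H x D input weight matrix
# is factorized as (H x K)(K x D); the first Linear has no activation after it.
student = nn.Sequential(
    nn.Linear(D, K, bias=False),  # linear bottleneck
    nn.Linear(K, H),
    nn.ReLU(),
    nn.Linear(H, C),
)

x = torch.randn(512, D)  # stands in for the unlabeled transfer set
with torch.no_grad():
    teacher_logits = teacher(x)  # regression targets: logits, not probabilities

optimizer = torch.optim.Adam(student.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
losses = []
for _ in range(200):
    optimizer.zero_grad()
    loss = loss_fn(student(x), teacher_logits)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")

# The bottleneck adds no depth: both linear maps merge into one H x D matrix.
with torch.no_grad():
    merged = student[1].weight @ student[0].weight  # shape (H, D)
    pre_activation_factored = student[1](student[0](x))
    pre_activation_merged = x @ merged.T + student[1].bias
```

After training, `merged` can replace the two factor layers, which is why the factorization changes training speed but not what the student can represent.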

In [48]:
import torch
torch.nn.functional.softmax(torch.tensor([-10.0, 0.0, 10.0]), dim=-1)

tensor([2.0611e-09, 4.5398e-05, 9.9995e-01])

In [50]:
torch.nn.functional.softmax(torch.tensor([10.0, 20.0, 30.0]), dim=-1)  # Different logits, same softmax output

tensor([2.0611e-09, 4.5398e-05, 9.9995e-01])

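Say the target is $[3.0385 \times 10^{-7}, 6.6928 \times 10^{-3}, 9.9331 \times 10^{-1}]$ and the prediction is $[\frac{1}{3}, \frac{1}{3}, \frac{1}{3}]$

$CE_{Loss} \approx -[3.0 \times 10^{-7} \ \log(\frac{1}{3}) + 6.7 \times 10^{-3} \ \log(\frac{1}{3}) + 9.93 \times 10^{-1} \ \log(\frac{1}{3})] \approx 3.3 \times 10^{-7} + 7.4 \times 10^{-3} + 1.091 \approx 1.10$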

We can see that most of the loss comes from the largest target, so the model would focus on getting that one right and ignore the other targets.
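
The arithmetic can be checked term by term. The sketch below uses only the standard library and the same target and uniform prediction as above; the variable names are illustrative.

```python
import math

target = [3.0385e-7, 6.6928e-3, 9.9331e-1]
prediction = [1 / 3, 1 / 3, 1 / 3]

# Per-class contribution to the cross entropy: -t_i * log(p_i)
terms = [-t * math.log(p) for t, p in zip(target, prediction)]
total = sum(terms)
shares = [term / total for term in terms]

for t, term, share in zip(target, terms, shares):
    print(f"target={t:.4e}  term={term:.4e}  share of loss={share:.4%}")
print(f"total cross entropy = {total:.4f}")
```

Because the prediction is uniform, each term is just the target weight times $\log 3$, so the total is about $\log 3 \approx 1.0986$ and more than 99% of it comes from the third class.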