# Training models on CIFAR-10/100 datasets, using ***torchdistill***

## 1. Make sure you have access to GPU/TPU
Google Colab: *Runtime* -> *Change runtime type* -> *Hardware accelerator*: "GPU" or "TPU"

In [1]:
!nvidia-smi

Mon Feb  5 12:37:49 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   62C    P8              13W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
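The `nvidia-smi` output above shows that the runtime has a Tesla T4. As a minimal sketch, you can also check from Python whether PyTorch itself can use the GPU. The example training script runs with `device='cuda'` by default (see the `Namespace(...)` line in the training log in Section 4.1). `pick_device` is a hypothetical helper, not part of ***torchdistill***.

```python
import importlib.util


def pick_device() -> str:
    """Return "cuda" if PyTorch is installed and sees a GPU, otherwise "cpu"."""
    # Import torch only when it is installed, so this check also runs in a bare Python environment.
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch

    return "cuda" if torch.cuda.is_available() else "cpu"


print(pick_device())
```

If this prints `cpu` on Colab, revisit the runtime type setting above before training.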
                                                                    

## 2. Install ***torchdistill***

In [2]:
!pip install torchdistill

Collecting torchdistill
  Downloading torchdistill-1.0.0-py3-none-any.whl (93 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.2/93.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchdistill
Successfully installed torchdistill-1.0.0
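To confirm which version is installed from within the Python session, one option is the standard library's `importlib.metadata`. This is a minimal sketch; `installed_version` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# In the Colab session above, this should match the version reported by pip.
print("torchdistill", installed_version("torchdistill"))
```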


## 3. Clone ***torchdistill*** repository to use its example code and configuration files

In [3]:
!git clone https://github.com/Moon-kimchi/torchdistill.git

Cloning into 'torchdistill'...
remote: Enumerating objects: 127, done.[K
remote: Counting objects: 100% (2/2), done.[K
remote: Compressing objects: 100% (2/2), done.[K
remote: Total 127 (delta 0), reused 0 (delta 0), pack-reused 125[K
Receiving objects: 100% (127/127), 247.33 KiB | 8.83 MiB/s, done.
Resolving deltas: 100% (21/21), done.
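As a quick sanity check, you can confirm that the clone produced the example script and the config file used in Section 4 before launching training. The paths below are copied from the training command in this notebook, and `missing_files` is a hypothetical helper. This is only a minimal sketch.

```python
from pathlib import Path

# Paths taken from the training command in Section 4.1, relative to the notebook's working directory.
REQUIRED_FILES = [
    Path("torchdistill/examples/torchvision/image_classification.py"),
    Path("torchdistill/configs/sample/cifar10/ce/resnet20-hyperparameter_tuning.yaml"),
]


def missing_files(paths):
    """Return the subset of `paths` that do not exist as regular files."""
    return [p for p in paths if not p.is_file()]


missing = missing_files(REQUIRED_FILES)
print("All files found" if not missing else f"Missing: {missing}")
```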


## 4. Train models on CIFAR-10

Note that, according to the original papers, the ResNet hyperparameters were chosen using either a train/val split (the 50k training samples split into train:val = 45k:5k) or cross-validation.  
For the final run (once the hyperparameters were finalized), the authors trained on all 50k training images.  
- ResNet: https://github.com/facebookarchive/fb.resnet.torch
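In the 45k:5k protocol, 5,000 of the 50,000 CIFAR-10 training images are held out for validation. The sketch below shows such a split in plain Python, assuming a random hold-out drawn with a fixed seed. The sample config may define its split differently, so treat `split_indices` as a hypothetical illustration of the idea, not ***torchdistill***'s implementation.

```python
import random


def split_indices(num_samples=50_000, num_val=5_000, seed=0):
    """Shuffle sample indices with a fixed seed and split them into (train, val) lists."""
    rng = random.Random(seed)  # local RNG keeps the split reproducible and isolated from global state
    indices = list(range(num_samples))
    rng.shuffle(indices)
    return indices[num_val:], indices[:num_val]


train_idx, val_idx = split_indices()
print(len(train_idx), len(val_idx))  # 45000 5000
```

Fixing the seed means every tuning run is evaluated on the same 5k images, which keeps the comparisons between hyperparameter settings fair. For the final run, the split is dropped and all 50k images are used for training.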

The following examples demonstrate how to 1) tune hyperparameters and 2) perform the final run with ResNet-20 on the CIFAR-10 dataset.

### 4.1 train:val = 45k:5k
Let's start with a small model, ResNet-20, for this tutorial.  

Open `torchdistill/configs/sample/cifar10/ce/resnet20-hyperparameter_tuning.yaml` and update the hyperparameters as you wish, e.g., the number of epochs (*num_epochs*), batch size (*batch_size* in the *train_data_loader* entry), learning rate (*lr* within the *optimizer* entry), and so on.
By default, the hyperparameters in the example config are identical to those in the final run config.
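If you prefer to script these edits rather than change the YAML file by hand, the idea is to load the config, override a few values, and save a modified copy. Below is a minimal sketch on a plain dict. Both the nesting and the placeholder values are hypothetical simplifications, not the contents of the sample config, so check the real file for its actual layout. In practice you would load and dump the file with a YAML parser such as PyYAML and pass the new path to `--config`.

```python
import copy

# Hypothetical, simplified config layout with placeholder values; the real YAML file
# may nest these keys differently, so check it before editing.
base_config = {
    "train": {
        "num_epochs": 100,
        "train_data_loader": {"batch_size": 64},
        "optimizer": {"type": "SGD", "params": {"lr": 0.1, "momentum": 0.9}},
    }
}


def with_overrides(config, num_epochs=None, batch_size=None, lr=None):
    """Return a deep copy of `config` with the given hyperparameters replaced."""
    new_config = copy.deepcopy(config)  # leave the original config untouched
    train = new_config["train"]
    if num_epochs is not None:
        train["num_epochs"] = num_epochs
    if batch_size is not None:
        train["train_data_loader"]["batch_size"] = batch_size
    if lr is not None:
        train["optimizer"]["params"]["lr"] = lr
    return new_config


tuned = with_overrides(base_config, num_epochs=30, lr=0.05)
print(tuned["train"]["num_epochs"], tuned["train"]["optimizer"]["params"]["lr"])  # 30 0.05
```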
  
You will find many module names from the [PyTorch documentation](https://pytorch.org/docs/stable/index.html) and [torchvision](https://pytorch.org/docs/stable/torchvision/), such as [`SGD`](https://pytorch.org/docs/stable/optim.html#torch.optim.SGD), [`MultiStepLR`](https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.MultiStepLR), [`CrossEntropyLoss`](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss), [`CIFAR10`](https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.CIFAR10), and [`RandomCrop`](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomCrop), among others. You can update their parameters or replace these modules with other modules from the same packages. For instance, `SGD` could be replaced with [`Adam`](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam). In that case, you would need to update the parameters under `params` accordingly. At minimum, delete the `momentum` entry, since `Adam` does not accept that parameter.
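The `SGD` to `Adam` swap described above boils down to changing the optimizer name and removing arguments that `Adam` does not accept. Below is a minimal sketch on a plain dict. The entry layout (`type`, `params`) and its values are hypothetical and do not come from a verified ***torchdistill*** schema. `torch.optim.Adam` has no `momentum` argument but does accept `weight_decay`, and its default learning rate is 1e-3.

```python
import copy

# Hypothetical optimizer entry with placeholder values; key names are assumptions.
sgd_entry = {"type": "SGD", "params": {"lr": 0.1, "momentum": 0.9, "weight_decay": 0.0005}}


def to_adam(optimizer_entry, lr=1e-3):
    """Convert an SGD-style optimizer entry into an Adam entry without mutating the input."""
    entry = copy.deepcopy(optimizer_entry)
    entry["type"] = "Adam"
    entry["params"].pop("momentum", None)  # torch.optim.Adam has no `momentum` argument
    entry["params"]["lr"] = lr  # 1e-3 is torch.optim.Adam's default learning rate
    return entry


print(to_adam(sgd_entry))  # {'type': 'Adam', 'params': {'lr': 0.001, 'weight_decay': 0.0005}}
```

The learning rate that worked for `SGD` rarely carries over to `Adam` unchanged. Treat it as a hyperparameter to re-tune on the 45k:5k split.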

In [4]:
!python torchdistill/examples/torchvision/image_classification.py --config torchdistill/configs/sample/cifar10/ce/resnet20-hyperparameter_tuning.yaml --run_log log/cifar10/ce/resnet20-hyperparameter_tuning.log

2024/02/05 12:38:49	INFO	torchdistill.common.main_util	Not using distributed mode
2024/02/05 12:38:49	INFO	__main__	Namespace(config='torchdistill/configs/sample/cifar10/ce/resnet20-hyperparameter_tuning.yaml', device='cuda', run_log='log/cifar10/ce/resnet20-hyperparameter_tuning.log', start_epoch=0, seed=None, disable_cudnn_benchmark=False, test_only=False, student_only=False, log_config=False, world_size=1, dist_url='env://', adjust_lr=False)
2024/02/05 12:38:49	INFO	torchdistill.common.main_util	Getting `CIFAR10` from `torchvision.datasets`
2024/02/05 12:38:49	INFO	torchdistill.common.main_util	Calling `CIFAR10` from `torchvision.datasets` with {'kwargs': {'root': '~/datasets/cifar10', 'train': True, 'download': True}}
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /root/datasets/cifar10/cifar-10-python.tar.gz
100% 170498071/170498071 [00:06<00:00, 28276764.88it/s]
Extracting /root/datasets/cifar10/cifar-10-python.tar.gz to /root/datasets/cifar10
2024/02/05 1