In [1]:
from argparse import ArgumentParser
import torch as th
import pytorch_lightning as pl
from pytorch_lightning.utilities.cli import LightningCLI
from src.training_module import TrainingModule, DistilledTrainingModule
from src.dataset import DataSetFactory
from pl_bolts.datamodules import CIFAR10DataModule
from src.dataset import CIFAR100DataModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
import json
import os
import wandb


In [2]:

def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(<string>)``. Any non-empty string
    is truthy, so ``--distill False`` would silently evaluate to ``True``. This
    converter accepts common true/false spellings and raises ``ValueError`` for
    anything else. argparse turns that error into an "invalid value" usage error.

    Args:
        value: The raw string from the command line (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ValueError(f'expected a boolean value, got {value!r}')


parser = ArgumentParser()
# Model hyperparameters (presumably image_size, pre_trained, momentum, ...)
# are registered by the project's TrainingModule; TODO confirm the exact set.
parser = TrainingModule.add_model_specific_args(parser)
parser.add_argument('--train_teacher', type=_str2bool, default=False)
parser.add_argument('--distill', type=_str2bool, default=False)
parser.add_argument('--student_model', type=str, default='resnet34')
parser.add_argument('--teacher_model', type=str, default='resnet34')
parser.add_argument('--prune_target', type=float, default=0.0)

# trainer arguments
parser.add_argument('--default_root_dir', type=str, default='logs')
parser.add_argument('--max_epochs', type=int, default=240)
parser.add_argument('--gpus', type=int, default=(1 if th.cuda.is_available() else 0))
# NOTE(review): 2084 is an unusual batch size; possibly a typo for 2048. Confirm.
parser.add_argument('--batch_size', type=int, default=2084)
parser.add_argument('--num_workers', type=int, default=4)
# Parse an empty argv: inside a notebook only the defaults above are used.
args = parser.parse_args([])

# CIFAR-100 has 100 classes.
args.num_classes = 100


# Hyperparameter search space for the W&B sweep. Keys under "parameters" are
# read back in `train()` via `wandb.config`.
sweep_config = {
  "name" : "babys-first-sweep",
  "method" : "random",

  "parameters" : {
    "epochs" : {
      "values" : [120]
    },
    "learning_rate" :{
      "min": 0.0001,
      "max": 0.05
    },
    "weight_decay":{
      "min": 1e-5,
      "max": 1e-3
    },
    "precision":{
      "values" : [16, 32]
    },
    "mixup": {
      "values" : [True, False]
    }
  }
}

# Registers a NEW sweep on the W&B server each time this cell runs.
sweep_id = wandb.sweep(sweep_config, project="cifar100-sweep")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: 8a1cn73d
Sweep URL: https://wandb.ai/codestar12/cifar100-sweep/sweeps/8a1cn73d


In [3]:
def train():
    """Run one W&B sweep trial: build the model and trainer, then fit on CIFAR-100.

    ``wandb.agent`` calls this once per trial. The swept hyperparameters
    (``learning_rate``, ``epochs``, ``weight_decay``, ``precision`` and, when
    present, ``mixup``) are read from ``wandb.config``. All other settings come
    from the module-level ``args`` namespace built in the configuration cell.
    """
    with wandb.init(project='cifar100-sweep'):
        config = wandb.config
        training_module = TrainingModule(
            model_name='resnet34',
            image_size=args.image_size,
            num_classes=args.num_classes,
            pre_trained=args.pre_trained,
            lr=config['learning_rate'],
            epochs=config['epochs'],
            # A sweep whose parameter space omits `mixup` would raise KeyError
            # with `config['mixup']`; treat a missing key as "mixup disabled".
            mixup=config.get('mixup', False),
            momentum=args.momentum,
            weight_decay=config['weight_decay'],
        )
        lr_monitor = LearningRateMonitor(logging_interval="epoch")
        wandb_logger = WandbLogger()
        # Keyword arguments override the matching fields in `args`
        # (e.g. the sweep's `epochs` replaces `args.max_epochs`).
        trainer = pl.Trainer.from_argparse_args(
            args,
            max_epochs=config['epochs'],
            precision=config['precision'],
            logger=wandb_logger,
            callbacks=[lr_monitor])

        # Honour --num_workers instead of a hard-coded worker count.
        dm = CIFAR100DataModule(
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
        )
        trainer.fit(training_module, datamodule=dm)

In [4]:

# Manual single-run alternative to the sweep (kept for reference; not executed).
# training_module = TrainingModule(
#     model_name='resnet34',
#     image_size=args.image_size,
#     num_classes=args.num_classes,
#     pre_trained=args.pre_trained,
#     lr=args.lr,
#     momentum=args.momentum,
#     weight_decay=args.weight_decay,
# )
# lr_monitor = LearningRateMonitor(logging_interval="epoch")
# wandb_logger = WandbLogger()
# trainer = pl.Trainer.from_argparse_args(
#         args, 
#         max_epochs=30,
#         precision=16, 
#         logger=wandb_logger,
#         callbacks=[lr_monitor])

# dm = CIFAR100DataModule(batch_size=args.batch_size, num_workers=4, pin_memory=True)
# trainer.fit(training_module, datamodule=dm)

In [5]:
# Launch the sweep agent in this kernel: it samples `count` configurations
# from the sweep and invokes `train()` once per sample.
count = 20  # number of runs to execute
wandb.agent(sweep_id, count=count, function=train)

[34m[1mwandb[0m: Agent Starting Run: qurp6jvs with config:
[34m[1mwandb[0m: 	epochs: 120
[34m[1mwandb[0m: 	learning_rate: 0.028864320527081416
[34m[1mwandb[0m: 	precision: 32
[34m[1mwandb[0m: 	weight_decay: 0.0004187876689841033
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcodestar12[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | _model    | ResNet           | 21.3 M
1 | _loss     | CrossEntropyLoss | 0     
2 | val_acc   | Accuracy         | 0     
3 | train_acc | Accuracy         | 0     
-----------------------------------------------
21.3 M    Trainable params
0         Non-trainable params
21.3 M    Total params
85.344    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


  rank_zero_warn(


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]