# Hyperparameter Tuning Using WandDB Sweeps

In [1]:
import torch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from models import CNNStudent, CNNTeacher, Resnet18, Resnet34, Resnet50
import wandb
import argparse
from time import perf_counter

# function to extract cifar-10 datasets and apply transformations
def get_cifar10_datasets(img_size=32):
    train_transforms_cifar = transforms.Compose([
        transforms.Resize(img_size),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(0.1),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.75, scale=(0.02, 0.1), value=1.0, inplace=False)
    ])

    test_transforms_cifar = transforms.Compose([
        transforms.Resize((img_size,img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # load and download data
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms_cifar)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms_cifar)

    return train_dataset, test_dataset

def soft_target_loss(soft_targets, soft_prob, temperature):
    return -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (temperature**2)

# function to train model for 1 epoch using cross-entropy and soft-target loss
def distill_knowledge_ce_stl(teacher, student, optimizer, train_loader, device, temperature, ce_weight, st_weight):
    student.train()
    teacher.eval()

    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
        with torch.no_grad():
            teacher_logits, _ = teacher(inputs)

        # Forward pass with the student model
        student_logits, _ = student(inputs)

        #Soften the student logits by applying softmax first and log() second
        soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=-1)
        soft_prob = nn.functional.log_softmax(student_logits / temperature, dim=-1)

        # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
        soft_targets_loss = soft_target_loss(soft_targets, soft_prob, temperature)

        # Calculate the true label loss
        ce_loss = nn.CrossEntropyLoss()
        label_loss = ce_loss(student_logits, labels)

        # Weighted sum of the two losses
        loss = st_weight * soft_targets_loss + ce_weight * label_loss

        loss.backward()
        optimizer.step()
        wandb.log({"batch_train_loss": loss.item(), "batch_train": distill_knowledge_ce_stl.train_batch_counter})
        distill_knowledge_ce_stl.train_batch_counter += 1
        running_loss += loss.item()

    return running_loss / len(train_loader)

def distill_knowledge_ce_csl(teacher, student, optimizer, train_loader, device, ce_weight, cs_weight):
    student.train()
    teacher.eval()

    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
        with torch.no_grad():
            _, teacher_feats = teacher(inputs)

        # Forward pass with the student model
        student_logits, student_feats = student(inputs)

        # Calculate the cosine loss. Target is a vector of ones. From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase.
        cosine_loss = nn.CosineEmbeddingLoss()
        feats_loss = cosine_loss(student_feats, teacher_feats, target=torch.ones(inputs.size(0)).to(device))

        # Calculate the true label loss
        ce_loss = nn.CrossEntropyLoss()
        label_loss = ce_loss(student_logits, labels)

        # Weighted sum of the two losses
        loss = cs_weight * feats_loss + ce_weight * label_loss

        loss.backward()
        optimizer.step()
        wandb.log({"batch_train_loss": loss.item(), "batch_train": distill_knowledge_ce_csl.train_batch_counter})
        distill_knowledge_ce_csl.train_batch_counter += 1
        running_loss += loss.item()

    return running_loss / len(train_loader)

def distill_knowledge_ce_mse(teacher, student, optimizer, train_loader, device, ce_weight, mse_weight):
    student.train()
    teacher.eval()

    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
        with torch.no_grad():
            _, teacher_feats = teacher(inputs)

        # Forward pass with the student model
        student_logits, student_feats = student(inputs)

        # Calculate the MSE loss. 
        mse_loss = nn.MSELoss()
        feats_loss = mse_loss(student_feats, teacher_feats)

        # Calculate the true label loss
        ce_loss = nn.CrossEntropyLoss()
        label_loss = ce_loss(student_logits, labels)

        # Weighted sum of the two losses
        loss = mse_weight * feats_loss + ce_weight * label_loss

        loss.backward()
        optimizer.step()

        wandb.log({"batch_train_loss": loss.item(), "batch_train": distill_knowledge_ce_mse.train_batch_counter})
        distill_knowledge_ce_mse.train_batch_counter += 1
        running_loss += loss.item()

    return running_loss / len(train_loader)

# function to evaluate model for 1 epoch
def test(model, test_loader, device):
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _ = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            
            batch_total = labels.size(0)
            batch_correct = (predicted == labels).sum().item()
            wandb.log({"batch_test_acc": 100 * batch_correct / batch_total, "batch_test": test.test_batch_counter})
            test.test_batch_counter += 1
            total += batch_total
            correct += batch_correct

    accuracy = 100 * correct / total
    return accuracy

def train():
    wandb.init(
            # set the wandb project where this run will be logged
            project="knowledge-distillation-experiments",
            
            # track hyperparameters and run metadata
            config=wandb.config
        )
    
    args = wandb.config

    if 'resnet' in args.student_mod and 'resnet' in args.teacher_mod:
        img_size = 224
    elif 'cnn' in args.student_mod and 'cnn' in args.teacher_mod:
        img_size = 32

    train_data, test_data = get_cifar10_datasets(img_size=img_size)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.train_batch_size, shuffle=True, num_workers=args.train_num_workers)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False, num_workers=args.test_num_workers)

    # initialize student and teacher models
    start = perf_counter()
    if args.student_mod == 'cnn':
        student = CNNStudent(num_classes=10, dropout=args.dropout)
    elif args.student_mod == 'resnet18':
        student = Resnet18(pretrained=False)
    elif args.student_mod == 'resnet18_pt':
        student = Resnet18(pretrained=True)

    if args.teacher_mod == 'cnn_pt':
        teacher = torch.load('./cnn_teacher.pth')
    elif args.teacher_mod == 'cnn':
        teacher = CNNTeacher(num_classes=10, dropout=args.dropout)
    elif args.teacher_mod == 'resnet34':
        teacher = Resnet34(pretrained=False)
    elif args.teacher_mod == 'resnet34_pt':
        teacher = Resnet34(pretrained=True)
    elif args.teacher_mod == 'resnet50':
        teacher = Resnet50(pretrained=False)
    elif args.teacher_mod == 'resnet50_pt':
        teacher = Resnet50(pretrained=True)

    student = student.to(args.device)
    teacher = teacher.to(args.device)
    end = perf_counter()
    mod_init_time = end - start
    wandb.log({'mods_init_time': mod_init_time})

    # initialize optimizer
    start = perf_counter()
    if args.opt == 'adam':
        optimizer = optim.Adam(student.parameters(), lr=args.lr)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(student.parameters(), lr=args.lr)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(student.parameters(), lr=args.lr)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(student.parameters(), lr=args.lr)
    elif args.opt == 'adadelta':
        optimizer = optim.Adadelta(student.parameters(), lr=args.lr)
    end = perf_counter()
    opt_init_time = end - start
    wandb.log({'opt_init_time': opt_init_time})

    wandb.define_metric("batch_train")
    wandb.define_metric("batch_test")
    wandb.define_metric("epoch")

    wandb.define_metric("batch_train_loss", step_metric="batch_train", summary='last')
    wandb.define_metric("epoch_train_loss", step_metric="epoch", summary='last')
    wandb.define_metric("batch_test_acc", step_metric="batch_test", summary='last')
    wandb.define_metric("epoch_test_acc", step_metric="epoch", summary='last')        

    # train and eval mod
    train_losses = []
    test_accs = []
    distill_knowledge_ce_stl.train_batch_counter = 0
    distill_knowledge_ce_csl.train_batch_counter = 0
    distill_knowledge_ce_mse.train_batch_counter = 0
    test.test_batch_counter = 0
    for i in range(args.epochs):
        if args.loss == 'ce-stl':
            train_loss = distill_knowledge_ce_stl(student=student, teacher=teacher, optimizer=optimizer, train_loader=train_loader, device=args.device, ce_weight=args.ce_w, st_weight=args.st_w, temperature=args.temp)
        elif args.loss == 'ce-csl':
            train_loss = distill_knowledge_ce_csl(teacher=teacher, student=student, optimizer=optimizer, train_loader=train_loader, device=args.device, ce_weight=args.ce_w, cs_weight=args.cs_w)
        elif args.loss == 'ce-mse':
            train_loss = distill_knowledge_ce_mse(teacher=teacher, student=student, optimizer=optimizer, train_loader=train_loader, device=args.device, ce_weight=args.ce_w, mse_weight=args.mse_w)
        test_acc = test(model=student, test_loader=test_loader, device=args.device)
        train_losses.append(train_loss)
        test_accs.append(test_acc)
        wandb.log({"epoch_train_loss": train_loss, "epoch": i})
        wandb.log({"epoch_test_acc": test_acc, "epoch": i})
    
    final_test_acc = sum(test_accs) / len(test_accs)
    wandb.log({'final_test_acc': final_test_acc})


In [2]:
sweep_configuration = {
    "method": "random",
    "metric": {"goal": "maximize", "name": "final_test_acc"},
    "parameters": {
        "epochs": {'values': [1, 10, 20]},
        'train_num_workers': {'values': [2, 4]},
        'test_num_workers': {'values': [2, 4]},
        'train_batch_size': {'values': [128, 256]},
        'test_batch_size': {'values': [128, 256]},
        'lr': {'distribution': 'uniform', "max": 1, "min": 0.0001},
        'dropout': {'distribution': 'uniform', "max": 1, "min": 0.001},
        'loss': {'values': ['ce-stl']},
        'opt': {'values': ['adam', 'sgd', 'adagrad', 'adadelta', 'rmsprop']},
        'student_mod': {'values': ['cnn']},
        'teacher_mod': {'values': ['cnn', 'cnn_pt']},
        'ce_w': {'distribution': 'uniform', "max": 1, "min": 0},
        #'cs_w': {'distribution': 'uniform', "max": 1, "min": 0},
        'st_w': {'distribution': 'uniform', "max": 1, "min": 0},
        #'mse_w': {'distribution': 'uniform', "max": 1, "min": 0},
        'temp': {'distribution': 'uniform', "max": 1, "min": 0},
        'device': {'value': 'cuda'}
    },
}
from pprint import pprint
pprint(sweep_configuration)

sweep_id = wandb.sweep(sweep=sweep_configuration, project="knowledge-distillation-experiments")

{'method': 'random',
 'metric': {'goal': 'maximize', 'name': 'final_test_acc'},
 'parameters': {'ce_w': {'distribution': 'uniform', 'max': 1, 'min': 0},
                'device': {'value': 'cuda'},
                'dropout': {'distribution': 'uniform', 'max': 1, 'min': 0.001},
                'epochs': {'values': [1, 10, 20]},
                'loss': {'values': ['ce-stl']},
                'lr': {'distribution': 'uniform', 'max': 1, 'min': 0.0001},
                'opt': {'values': ['adam',
                                   'sgd',
                                   'adagrad',
                                   'adadelta',
                                   'rmsprop']},
                'st_w': {'distribution': 'uniform', 'max': 1, 'min': 0},
                'student_mod': {'values': ['cnn']},
                'teacher_mod': {'values': ['cnn', 'cnn_pt']},
                'temp': {'distribution': 'uniform', 'max': 1, 'min': 0},
                'test_batch_size': {'values': [128, 256]},
  

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: 8pt4a1yv
Sweep URL: https://wandb.ai/amanichopra/knowledge-distillation-experiments/sweeps/8pt4a1yv


In [3]:
wandb.agent(sweep_id, function=train, count=25)

[34m[1mwandb[0m: Agent Starting Run: ezdls0oy with config:
[34m[1mwandb[0m: 	ce_w: 0.3579533881775676
[34m[1mwandb[0m: 	device: cuda
[34m[1mwandb[0m: 	dropout: 0.29087420236988465
[34m[1mwandb[0m: 	epochs: 20
[34m[1mwandb[0m: 	loss: ce-stl
[34m[1mwandb[0m: 	lr: 0.6636154488696536
[34m[1mwandb[0m: 	opt: adam
[34m[1mwandb[0m: 	st_w: 0.4335573999910032
[34m[1mwandb[0m: 	student_mod: cnn
[34m[1mwandb[0m: 	teacher_mod: cnn_pt
[34m[1mwandb[0m: 	temp: 0.6676902240172179
[34m[1mwandb[0m: 	test_batch_size: 128
[34m[1mwandb[0m: 	test_num_workers: 4
[34m[1mwandb[0m: 	train_batch_size: 256
[34m[1mwandb[0m: 	train_num_workers: 4
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mamanichopra[0m. Use [1m`wandb login --relogin`[0m to force relogin


Files already downloaded and verified
Files already downloaded and verified
