In [1]:
################################################################################
################# Do not change the code in this cell ##########################
################################################################################

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import tqdm


class Net(nn.Module):
    """Small convolutional classifier: two conv layers, max-pool, two linear layers.

    ``forward`` returns log-probabilities over 10 classes (``log_softmax`` applied
    to the output of ``fc2``).
    """

    def __init__(self, hidden_dim: int) -> None:
        """Create the layers.

        Args:
            hidden_dim: Number of output channels of ``conv1`` and, therefore,
                input channels of ``conv2``.
        """
        super(Net, self).__init__()
        # Positional Conv2d args: in_channels, out_channels, kernel_size, stride.
        self.conv1 = nn.Conv2d(1, hidden_dim, 3, 1)
        self.conv2 = nn.Conv2d(hidden_dim, 64, 3, 1)
        self.dropout = nn.Dropout(0.25)
        # 9216 = 64 * 12 * 12: 64 channels from conv2 times a 12x12 map, which is
        # what two unpadded 3x3 convs plus a 2x2 max-pool leave of a 28x28 input
        # (assumes 1x28x28 images such as MNIST).
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch and return per-class log-probabilities.

        Args:
            x: Input batch; presumably shaped (batch, 1, 28, 28) so that the
                flattened features match ``fc1`` -- confirm against the caller.

        Returns:
            Tensor of log-probabilities with the class scores along dim 1.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        # Keep dim 0 (batch) and flatten everything else for the linear layers.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        # Dropout sits between the two linear layers (after the ReLU).
        x = self.dropout(x)
        x = self.fc2(x)
        # log_softmax output pairs with an NLL-style loss.
        output = F.log_softmax(x, dim=1)
        return output

# MNIST train / test splits stored under ./data (download=True fetches them when
# requested); ToTensor converts each image to a tensor.
train_set = datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
test_set = datasets.MNIST('./data', train=False, download=True, transform=transforms.ToTensor())
# Mini-batch loaders: 32 samples per training batch, 128 per test batch.
# Both loaders are created with shuffle=True.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=True)


In [None]:
import wandb
from wandb.keras import WandbCallback
import torch.onnx

class Config:
    """Mutable bag of settings shared by the sweep driver and its training runs.

    ``model`` and ``optimizer`` start as ``None`` and are assigned for each run;
    the two dataloaders start as ``None`` and must be supplied by the caller.
    ``sweep`` holds the W&B sweep definition (grid search over learning rate,
    hidden dimension and momentum).
    """

    def __init__(self):
        # Every combination of these values is one run of the grid sweep.
        search_space = {
            "learning_rate": {"values": [0.01, 0.001]},
            "hidden_dim": {"values": [32, 1024]},
            "momentum": {"values": [0, 0.9]},
        }

        # Per-run objects, filled in later.
        self.model = None
        self.optimizer = None
        # Negative log-likelihood loss.
        self.criterion = F.nll_loss
        # Data sources, provided by the caller before the sweep starts.
        self.train_dataloader = None
        self.test_dataloader = None
        # Logging frequency handed to wandb.watch.
        self.log_freq = 10
        # Sweep definition passed to wandb.sweep.
        self.sweep = {
            "name": "wandb-tutorial",
            "method": "grid",
            "parameters": search_space,
        }
        

def grid_search(config, entity="cagnur", project="wandb-tutorial"):
    """Run a Weights & Biases sweep that trains and evaluates ``Net`` per grid point.

    A sweep is created from ``config.sweep`` and a W&B agent is started in this
    process. For every hyper-parameter combination the agent trains a fresh
    ``Net`` for five epochs with SGD, logs the training loss, logs the test
    accuracy, exports the model to ``model.onnx`` and uploads that file.

    Args:
        config: Settings object providing ``sweep``, ``criterion``, ``log_freq``,
            ``train_dataloader`` and ``test_dataloader``. Its ``model`` and
            ``optimizer`` attributes are overwritten on every run.
        entity: W&B entity (user or team) that owns the sweep.
        project: W&B project the sweep is created in.

    Raises:
        ValueError: If either dataloader on ``config`` is still ``None``. Inside
            a run, also raised when the test dataloader yields no batches.
    """
    # Fail fast: without this check the problem only surfaces inside every
    # single agent run, after the sweep has already been created on the server.
    if config.train_dataloader is None or config.test_dataloader is None:
        raise ValueError(
            "config.train_dataloader and config.test_dataloader must be set "
            "before calling grid_search"
        )

    # Use the GPU when there is one, otherwise fall back to the CPU instead of
    # crashing on an unconditional .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The training loss is logged every this many optimizer steps.
    loss_log_interval = 5

    sweep_id = wandb.sweep(config.sweep, entity=entity, project=project)

    def train_for_five_epochs():
        """Single sweep run: train for five epochs, evaluate, export to ONNX."""
        wandb.init()

        # Hyper-parameters of this run come from the sweep via wandb.config.
        config.model = Net(wandb.config.hidden_dim).to(device)
        config.optimizer = torch.optim.SGD(
            config.model.parameters(),
            lr=wandb.config.learning_rate,
            momentum=wandb.config.momentum,
        )
        wandb.watch(config.model, config.criterion, log='all', log_freq=config.log_freq)

        # Training
        config.model.train()
        counter = 0
        for _ in range(5):
            for imgs, labels in tqdm.tqdm(config.train_dataloader):
                imgs, labels = imgs.to(device), labels.to(device)
                out = config.model(imgs)
                loss = config.criterion(out, labels)
                config.optimizer.zero_grad()
                loss.backward()
                config.optimizer.step()
                counter += 1
                if counter % loss_log_interval == 0:
                    # .item() logs a plain float rather than a tensor that is
                    # still attached to the autograd graph.
                    wandb.log({'Loss': loss.item()}, step=counter)

        # Test
        config.model.eval()
        correct = 0
        export_input = None  # last test batch, reused as the ONNX example input
        with torch.no_grad():
            for imgs, labels in tqdm.tqdm(config.test_dataloader):
                imgs, labels = imgs.to(device), labels.to(device)
                out = config.model(imgs)
                predictions = out.argmax(dim=1, keepdim=True)
                correct += predictions.eq(labels.view_as(predictions)).sum().item()
                export_input = imgs
        if export_input is None:
            # Nothing was evaluated, so there is neither an accuracy to report
            # nor an example input to trace the model with.
            raise ValueError("config.test_dataloader yielded no batches")
        accuracy = correct / len(config.test_dataloader.dataset)
        wandb.log({"Accuracy": accuracy})

        # Export the trained model (weights included) and upload it to the run.
        torch.onnx.export(
            config.model,        # model being run
            export_input,        # example model input
            "model.onnx",        # where to save the model
            export_params=True,  # store the trained weights inside the file
        )
        wandb.save("model.onnx")

    wandb.agent(sweep_id, function=train_for_five_epochs)

In [None]:
# Authenticate with Weights & Biases before any sweep or run is created.
wandb.login()
# Fresh settings object; hand it the MNIST loaders built in the first cell.
config_info = Config()
config_info.train_dataloader = train_loader
config_info.test_dataloader = test_loader
# Creates the sweep and runs the agent over the grid in config_info.sweep.
grid_search(config_info)