Assignment10

In [2]:
import torch
from torchvision import datasets, transforms
from torch import nn
import optuna
import numpy as np
from optuna.pruners import ThresholdPruner
import torch.nn.functional as F

In [3]:
criterion = nn.CrossEntropyLoss()
class CifarNet(nn.Module):
    def __init__(self):
        super(CifarNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 64, 5)
        self.conv4 = nn.Conv2d(64, 128, 3)
        self.conv5 = nn.Conv2d(128, 128, 3)
        self.conv6 = nn.Conv2d(128, 256, 3)
        self.conv7 = nn.Conv2d(256, 256, 3)
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(36864, 512)
        self.fc2 = nn.Linear(512, 128)
        self.dropout = nn.Dropout(p=0.3)
        self.fc3 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = self.flat(x)
        # print(x.shape)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.dropout(x)        
        x = self.fc3(x)
        return x

In [5]:
def train(trial,num_epochs,optimizer,training_generator,validation_generator,model,device):
    train_accuracy = []
    valid_accuracy = []
    for epoch in range(num_epochs):
        model.train()
        running_accuracy = []
        for local_batch, local_labels in training_generator:
            local_batch, local_labels = local_batch.to(device), local_labels.to(device)
            optimizer.zero_grad()
            outputs = model(local_batch)
            loss = criterion(outputs, local_labels)
            loss.backward()
            optimizer.step()
            _, predicted = torch.max(outputs, 1)
            running_accuracy.append((predicted == local_labels).sum().item()/training_generator.batch_size)
        # Compute Training Accuracy
        train_accuracy.append(np.mean(running_accuracy))
        model.eval()
        running_accuracy = []
        with torch.no_grad():
            for local_batch, local_labels in validation_generator:
                local_batch, local_labels = local_batch.to(device), local_labels.to(device)
                outputs = model(local_batch)
                _, predicted = torch.max(outputs, 1)
                running_accuracy.append((predicted == local_labels).sum().item()/training_generator.batch_size)
        valid_accuracy.append(np.mean(running_accuracy))
        print('[{:d}] | train accuracy {:5.2f} | validation accuracy {:5.2f}'.format(epoch + 1,train_accuracy[-1]*100,valid_accuracy[-1]*100))
        # Study Parameters
        if len(valid_accuracy) > 4:
            trial.report(np.var(valid_accuracy[-5:]), epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
    return valid_accuracy[-1]

In [6]:
data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

train_ds = datasets.CIFAR10(root ='data/',
                            train=True,
                            download=True,
                            transform=data_transform)

val_ds = datasets.CIFAR10(root ='data/',
                          train=False,
                          download=True,
                          transform=data_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/cifar-10-python.tar.gz to data/
Files already downloaded and verified


In [10]:
def objective(trial):
    device = torch.device("cuda:0")
    learning_rate = trial.suggest_uniform("lr", 0.0001, 0.005)
    batch_size = trial.suggest_categorical('batch_size',[16, 32, 64])
    num_epochs = trial.suggest_categorical('num_epoch',[30,50,70])
    training_generator = torch.utils.data.DataLoader(train_ds,
                                             batch_size=batch_size,
                                              shuffle=True)
    validation_generator = torch.utils.data.DataLoader(val_ds,
                                             batch_size=batch_size,
                                              shuffle=True)
    model = CifarNet()
    model.to(device)
    optimizer = torch.optim.RMSprop(model.parameters(),lr=learning_rate)
    print(f'lr {learning_rate} \t batch size {batch_size} \t epochs {num_epochs}')
    intermediate_value=train(trial,
                             num_epochs,
                             optimizer,
                             training_generator,
                             validation_generator,
                             model,
                             device)

    return intermediate_value

In [12]:
study = optuna.create_study(pruner=ThresholdPruner(lower=0.001),study_name='StudyCifar10.v1')
study.optimize(objective, n_trials=5)

[32m[I 2021-02-09 17:23:04,893][0m A new study created in memory with name: StudyCifar10.v1[0m


lr 0.003655516965340838 	 batch size 16 	 epochs 30
[1] | train accuracy 10.01 | validation accuracy 10.00
[2] | train accuracy  9.83 | validation accuracy 10.00
[3] | train accuracy  9.94 | validation accuracy 10.00
[4] | train accuracy 10.14 | validation accuracy 10.00


[32m[I 2021-02-09 17:27:58,162][0m Trial 0 pruned. [0m


[5] | train accuracy  9.69 | validation accuracy 10.00
lr 0.0015876406731336905 	 batch size 16 	 epochs 70
[1] | train accuracy 23.79 | validation accuracy 26.76
[2] | train accuracy 33.63 | validation accuracy 37.54
[3] | train accuracy 39.27 | validation accuracy 42.24
[4] | train accuracy 42.55 | validation accuracy 43.78
[5] | train accuracy 44.27 | validation accuracy 46.25


[32m[I 2021-02-09 17:33:51,743][0m Trial 1 pruned. [0m


[6] | train accuracy 45.62 | validation accuracy 44.66
lr 0.0025841408896875133 	 batch size 16 	 epochs 50
[1] | train accuracy 10.00 | validation accuracy 10.00
[2] | train accuracy  9.79 | validation accuracy 10.00
[3] | train accuracy  9.95 | validation accuracy 10.00
[4] | train accuracy  9.88 | validation accuracy 10.00


[32m[I 2021-02-09 17:38:44,471][0m Trial 2 pruned. [0m


[5] | train accuracy  9.91 | validation accuracy 10.00
lr 0.0017544041134735207 	 batch size 16 	 epochs 50
[1] | train accuracy  9.84 | validation accuracy 10.00
[2] | train accuracy  9.91 | validation accuracy 10.00
[3] | train accuracy  9.85 | validation accuracy 10.00
[4] | train accuracy  9.95 | validation accuracy 10.00


[32m[I 2021-02-09 17:43:37,222][0m Trial 3 pruned. [0m


[5] | train accuracy  9.90 | validation accuracy 10.00
lr 0.003964402769159942 	 batch size 64 	 epochs 50
[1] | train accuracy  9.84 | validation accuracy  9.95
[2] | train accuracy  9.78 | validation accuracy  9.95
[3] | train accuracy  9.90 | validation accuracy  9.95
[4] | train accuracy  9.93 | validation accuracy  9.95


[32m[I 2021-02-09 17:46:13,595][0m Trial 4 pruned. [0m


[5] | train accuracy  9.76 | validation accuracy  9.95


In [None]:
study.trials_dataframe()

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_batch_size,params_lr,params_num_epoch,state
0,0,0.000638372,2021-02-02 09:40:46.641518,2021-02-02 09:47:42.573412,0 days 00:06:55.931894,16,0.000822,30,PRUNED
1,1,0.0,2021-02-02 09:47:42.574333,2021-02-02 09:50:59.855900,0 days 00:03:17.281567,32,0.002795,50,PRUNED
2,2,0.0,2021-02-02 09:50:59.856847,2021-02-02 09:56:52.297398,0 days 00:05:52.440551,16,0.002225,30,PRUNED
3,3,0.0,2021-02-02 09:56:52.298821,2021-02-02 10:01:44.227261,0 days 00:04:51.928440,16,0.002739,70,PRUNED
4,4,0.0,2021-02-02 10:01:44.228838,2021-02-02 10:04:57.608039,0 days 00:03:13.379201,32,0.002807,50,PRUNED
5,5,0.0,2021-02-02 10:04:57.609395,2021-02-02 10:08:51.688044,0 days 00:03:54.078649,32,0.00261,70,PRUNED
6,6,0.0,2021-02-02 10:08:51.689012,2021-02-02 10:13:45.545729,0 days 00:04:53.856717,16,0.002021,70,PRUNED
7,7,0.0,2021-02-02 10:13:45.546596,2021-02-02 10:16:59.508462,0 days 00:03:13.961866,32,0.003743,50,PRUNED
8,8,0.0,2021-02-02 10:16:59.509280,2021-02-02 10:19:32.782728,0 days 00:02:33.273448,64,0.000628,30,PRUNED
9,9,0.0,2021-02-02 10:19:32.783921,2021-02-02 10:22:46.307482,0 days 00:03:13.523561,32,0.000925,30,PRUNED
