# Deep learning single path

## Data setup

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_LAUNCH_BLOCKING"]="1"
os.environ["CUDA_VISIBLE_DEVICES"]= f"{1}"

from singlePathModel import Dict, to_cuda, is_cuda, create_sub_blocks, SuperArchitectureNoLinear, SuperArchitecture
import logging
import torch
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR100, CIFAR10
import torch.nn as nn
import torch.optim as optim

from math import floor
import time
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Logging
filename = "SinglePathCIFAR100"
logging.basicConfig(filename = f"logs/{filename}", filemode = 'w', format='%(asctime)s - %(message)s', level=logging.DEBUG)

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Cuda

In [2]:
if is_cuda():
    print("cuda")
    device = torch.device("cuda")
    print(torch.cuda.current_device())
else:
    print("cpu")
    device = torch.device("cpu")

cuda
0


## Data

In [3]:
load_mnist = False
load_cifar = False
load_cifar100 = True

if load_mnist:
    transform = transforms.Compose([transforms.ToTensor()])
    train_data_set = MNIST("./temp/", train=True, download=True, transform=transform)
    test_data_set = MNIST("./temp/", train=False, download=True, transform=transform)

if load_cifar:
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
        ]
    )
    train_data_set = CIFAR10("./temp/", train=True, download=True, transform=transform)
    test_data_set = CIFAR10("./temp/", train=False, download=True, transform=transform)
if load_cifar100:
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
        ]
    )
    train_data_set = CIFAR100("./temp/", train=True, download=True, transform=transform)
    test_data_set = CIFAR100("./temp/", train=False, download=True, transform=transform)
trainloader = DataLoader(train_data_set, batch_size=16, shuffle=True, num_workers=2)
testloader = DataLoader(test_data_set, batch_size=16, shuffle=True, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


## Single path model

In [4]:
hyperparam = {
    'MNIST' : {
        'num_epochs' : 100,
        'lr' : 0.001,
        'momentum': 0.8
    },
    'CIFAR' : {
        'num_epochs' : 50,
        'lr' : 0.01,
        'momentum': 0.8
    }
}
if load_mnist:
    chosen_hyp = hyperparam['MNIST']
if load_cifar or load_cifar100:
    chosen_hyp = hyperparam['CIFAR']

In [5]:
# Data size
x, y = next(iter(trainloader))
batch_size, channels, img_x, img_y = x.shape
data_points = len(trainloader)*batch_size
n_classes = len(set(train_data_set.targets))
# Setup the different layers

# Linear layers
first_linear_params = [
    Dict(in_features=img_x*img_y, out_features=256),
    Dict(in_features=256, out_features=64),
]

# Conv layers
default_param = dict(expand_ratio_list=[3,5,7,9], kernel_size_list=[3,5,7,9])
dims = [1*channels, 4*channels, 8*channels, 16*channels, 32*channels, 64*channels, 126*channels, 250*channels]
middle_conv_params = []
for dim1, dim2 in zip(dims[:-1], dims[1:]):
    middle_conv_params.append(Dict(in_channels=dim1, out_channels=dim2, **default_param, blocks=1, stride=1))
    middle_conv_params.append(Dict(in_channels=dim2, out_channels=dim2, **default_param, blocks=1, stride=1))

# Last conv
last_conv_param = Dict(in_channels=dims[-1], out_channels=dims[-1], kernel_size=1, stride=1)

# Output linear layer
last_linear_params = [
    dict(in_features=dims[-1], out_features=n_classes)
]

In [6]:
net = SuperArchitectureNoLinear(
    #first_linear_params=first_linear_params,
    middle_conv_params=middle_conv_params,
    last_conv_param=last_conv_param,
    last_linear_params=last_linear_params
).cuda()
print(net)

SuperArchitectureNoLinear(
  (middle): Sequential(
    (0): SuperConvBlock(
      (module): Sequential(
        (0): Sequential(
          (0): SuperConv2d()
          (1): BatchNorm2d(27, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6()
        )
        (1): Sequential(
          (0): SuperConv2d()
          (1): BatchNorm2d(27, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6()
        )
        (2): Sequential(
          (0): Conv2d(27, 12, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (shortcut): Sequential(
        (0): Conv2d(3, 12, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): SuperConvBlock(
      (module): Sequential(
        (0): Sequential(
          (0): SuperConv2d()

In [7]:
def train_thresholds(net, val_images, val_labels, threshold_optimizer, criterion):
    """ Trains the threshold parameter"""
    # Zero_grad
    threshold_optimizer.zero_grad()

    # Predict and eval
    val_logits = net(val_images)
    val_loss = criterion(val_logits, val_labels)

    # Optimize
    val_loss.backward()
    threshold_optimizer.step()
    return val_logits, val_loss

def train_weights(net, train_images, train_targets, weight_optimizer, criterion):
    """ Trains the weights"""
    # Zero_grad
    weight_optimizer.zero_grad()
    
    # Predict and eval
    train_logits = net(train_images)
    train_loss = criterion(train_logits, train_targets)

    # Optimize
    train_loss.backward()
    weight_optimizer.step()
    return train_logits, train_loss

# Train model

In [8]:
# Setup
# Init
losses = []
test_acc = []
train_acc = []

# Init network
net = SuperArchitectureNoLinear(
    #first_linear_params=first_linear_params,
    middle_conv_params=middle_conv_params,
    last_conv_param=last_conv_param,
    last_linear_params=last_linear_params
).cuda()

# Eval model doing training
eval_model = True

# Save time
train_time = []

In [9]:
# Optimizer
criterion = nn.CrossEntropyLoss()
weight_optimizer = torch.optim.SGD(
    params=net.weights(),
    lr=chosen_hyp['lr'],
    momentum=chosen_hyp['momentum']
)

threshold_optimizer = torch.optim.Adam(
    params=net.thresholds(),
    lr=0.01,
    weight_decay=0.001
)

In [None]:
print("Started training")
logging.debug("Started training")
for epoch in range(chosen_hyp['num_epochs']):
    start_time = time.perf_counter()
    # Forward -> Backprob -> Update params
    
    ## Train
    cur_loss = 0
    net.train()
    for data in trainloader:
        # Get data
        x, y = data
        x_train, y_train = x[::2], y[::2]
        x_val, y_val = x[1::2], y[1::2]
        # Optimizer and batch
        train_logits, train_loss = train_weights(net, x_train.cuda(), y_train.cuda(), weight_optimizer, criterion)
        val_logits, val_loss = train_thresholds(net, x_val.cuda(), y_val.cuda(), threshold_optimizer, criterion)
        
        cur_loss += train_loss
    
    end_time = time.perf_counter()
    train_time.append(end_time - start_time)
    # Save losses
    losses.append(cur_loss / batch_size)
    
    if eval_model:
        ### Evaluate training
        net.eval()
        print("You can break now!", end='\r')
        logging.debug("You can break now!")
        train_preds, train_targs = [], []
        for data in trainloader:
            x_train, y_train = data
            output = net(x_train.cuda())

            preds = torch.max(output, 1)[1]

            train_targs += list(y_train.numpy())
            train_preds += list(preds.data.cpu().numpy())

        ### Evaluate validation
        test_preds, test_targs = [], []
        for data in testloader:
            x_test, y_test = data
            output = net(x_test.cuda())
            preds = torch.max(output, 1)[1]
            test_targs += list(y_test.numpy())
            test_preds += list(preds.data.cpu().numpy())


        train_acc_cur = accuracy_score(train_targs, train_preds)
        test_acc_cur = accuracy_score(test_targs, test_preds)

        train_acc.append(train_acc_cur)
        test_acc.append(test_acc_cur)
        
        logging.debug("Do not break!")
        print("Do not break!", end='\r')

        if epoch % 1 == 0:
            string = "Epoch %2i : Train Loss %f, Train acc %f, Test acc %f, Epoch train time %.2f min" % (
                    epoch+1, losses[-1], train_acc_cur, test_acc_cur, train_time[-1]/60)
            logging.debug(string)
            print(string)
    else:
        string = "Epoch %2i, Epoch train time %.2f min" % (epoch+1, train_time[-1]/60)
        logging.debug(string)
        print(string)

logging.debug(f"Total time: {sum(train_time)/(60*60)} hours")

Started training
Epoch  1 : Train Loss 787.227905, Train acc 0.163420, Test acc 0.163000, Epoch train time 21.82 min
Epoch  2 : Train Loss 665.342834, Train acc 0.270660, Test acc 0.263100, Epoch train time 21.78 min
Epoch  3 : Train Loss 580.044617, Train acc 0.359820, Test acc 0.343300, Epoch train time 21.57 min
Epoch  4 : Train Loss 513.720642, Train acc 0.430680, Test acc 0.398700, Epoch train time 22.62 min
Epoch  5 : Train Loss 458.821045, Train acc 0.484660, Test acc 0.455100, Epoch train time 21.78 min
Epoch  6 : Train Loss 420.137726, Train acc 0.532360, Test acc 0.481600, Epoch train time 21.70 min
Epoch  7 : Train Loss 383.627319, Train acc 0.560960, Test acc 0.504800, Epoch train time 21.70 min
Epoch  8 : Train Loss 352.930389, Train acc 0.600640, Test acc 0.531400, Epoch train time 21.72 min
Epoch  9 : Train Loss 327.077301, Train acc 0.635160, Test acc 0.561300, Epoch train time 22.00 min
Epoch 10 : Train Loss 305.315216, Train acc 0.657720, Test acc 0.579600, Epoch trai

In [None]:
torch.save(net, f"models/{filename}")

In [None]:
net.eval()
print(f'Training accuracy: {train_acc[-1]*100:.2f}%\nTest accuracy: {test_acc[-1]*100:.2f}%')

In [None]:
if eval_model:
    epoch = list(range(len(train_acc)))
    plt.figure()
    plt.plot(epoch, train_acc, 'r', epoch, test_acc, 'b')
    plt.legend(['Train Accucary','Test Accuracy'])
    plt.xlabel('Updates')
    plt.ylabel('Acc')
    plt.show()

In [8]:
sum([par.numel() for par in net.parameters()])

21486434