In [0]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils import data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
!pip install tensorboardX
from tensorboardX import SummaryWriter

# --- runtime device -------------------------------------------------------
# Prefer the GPU when one is visible; the rest of the notebook moves models
# and tensors to `device`, so it also runs (slowly) on CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- hyper-parameters -----------------------------------------------------
NUM_EPOCHS = 90      # upper bound on training epochs
BATCH_SIZE = 128
IMAGE_DIM = 32       # input side length in pixels
NUM_CLASSES = 10     # CIFAR-10 has 10 classes
DEVICE_IDS = [0]     # GPU indices handed to DataParallel

# --- output locations -----------------------------------------------------
OUTPUT_DIR = 'alexnet_data_out'
LOG_DIR = f'{OUTPUT_DIR}/tblogs'         # TensorBoard event files
CHECKPOINT_DIR = f'{OUTPUT_DIR}/models'  # model checkpoints

# Create the checkpoint directory (and its parent) up front; no-op if present.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

#Hack to debug in the middle of a sequential 
class PrintLayer(nn.Module):
    def __init__(self):
        super(PrintLayer, self).__init__()
    
    def forward(self, x):
        # Do your print / debug stuff here
        print(x.size())
        return x

class AlexNet(nn.Module):
    """
    AlexNet-style CNN for 32x32 RGB inputs with two early-exit side branches.

    The main trunk uses the layer types proposed in the AlexNet paper (five
    convolutions, ReLU, local response normalisation, overlapping max pooling
    and a three-layer fully connected classifier), scaled down so that the
    hard-coded feature sizes below match 32x32 inputs. Two side branches tap
    the trunk after ``conv1`` and after ``conv2``; each ends in its own linear
    classifier, so ``forward`` returns three sets of logits.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        """
        Define and allocate layers for this neural net.
        Args:
            num_classes (int): number of classes predicted by every exit
                (both branch heads and the main classifier)
        """
        super().__init__()

        # Parameter-free helpers reused at several depths of the main trunk.
        self.norm = nn.LocalResponseNorm(size=3, alpha=0.00005, beta=0.75)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)

        # Main trunk convolutions; shapes in comments are for a 3x32x32 input.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2, stride=1)  # -> 64x32x32
        self.conv2 = nn.Conv2d(64, 192, 5, padding=1)   # 64x15x15 -> 192x13x13
        self.conv3 = nn.Conv2d(192, 384, 3, padding=1)  # 192x6x6 -> 384x6x6
        self.conv4 = nn.Conv2d(384, 256, 3, padding=1)  # -> 256x6x6
        self.conv5 = nn.Conv2d(256, 256, 3, padding=1)  # -> 256x6x6
        self.relu = nn.ReLU()

        # A plain list is enough here: each conv is already registered as a
        # submodule through its attribute above; this list only drives init.
        self.mainConvLayers = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]

        # Branch 1: taps conv1's output (64x32x32) and ends at 32x7x7 = 1568 features.
        self.branch1 = nn.Sequential(
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(size=3, alpha=0.00005, beta=0.75),
            nn.Conv2d(64, 32, 3, padding=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.branch1fc = nn.Linear(1568, num_classes)

        # Branch 2: taps conv2's output (192x13x13) and ends at 32x2x2 = 128 features.
        self.branch2 = nn.Sequential(
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(size=3, alpha=0.00005, beta=0.75),
            nn.Conv2d(192, 32, 3, padding=1, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.branch2fc = nn.Linear(128, num_classes)

        # Main classifier on the pooled conv5 features (256x2x2 = 1024).
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=(256 * 2*2), out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
        self.init_bias()  # initialize conv weights and biases

    def init_bias(self):
        """
        Initialise every Conv2d's weights and biases (despite the name).

        All convolutions in both branches and the main trunk get weights drawn
        from a normal distribution with mean 0 and std 0.01, and zero bias.
        As in the original paper, conv2, conv4 and conv5 then get bias 1.
        Linear layers keep PyTorch's default initialisation.
        """
        # Same order as before (branch1, branch2, trunk) so RNG draws match.
        for branch in (self.branch1, self.branch2):
            for layer in branch:
                if isinstance(layer, nn.Conv2d):
                    nn.init.normal_(layer.weight, mean=0, std=0.01)
                    nn.init.constant_(layer.bias, 0)

        for layer in self.mainConvLayers:
            nn.init.normal_(layer.weight, mean=0, std=0.01)
            nn.init.constant_(layer.bias, 0)

        # original paper = 1 for Conv2d layers 2nd, 4th, and 5th conv layers
        nn.init.constant_(self.conv2.bias, 1)
        nn.init.constant_(self.conv4.bias, 1)
        nn.init.constant_(self.conv5.bias, 1)

    def forward(self, x):
        """
        Pass the input through the net.
        Args:
            x (Tensor): input batch; the hard-coded feature sizes assume
                shape (N, 3, 32, 32)
        Returns:
            tuple(Tensor, Tensor, Tensor): raw logits (no softmax) from
            branch 1, branch 2 and the main classifier, each (N, num_classes)
        """
        x = self.conv1(x)

        # BRANCH 1: two 3x3 convs and one FC layer on conv1's output.
        # flatten(1) keeps the batch dimension intact, so a wrong input size
        # fails loudly in the Linear layer instead of being folded into N.
        x1 = self.branch1(x)
        x1 = self.branch1fc(x1.flatten(1))

        x = self.relu(x)
        x = self.norm(x)
        x = self.pool(x)
        x = self.conv2(x)

        # BRANCH 2: one 3x3 conv and one FC layer on conv2's output.
        x2 = self.branch2(x)
        x2 = self.branch2fc(x2.flatten(1))

        x = self.relu(x)
        x = self.norm(x)
        x = self.pool(x)
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.relu(self.conv5(x))
        x = self.pool(x)
        x = x.flatten(1)  # (N, 256*2*2) for 32x32 input
        x = self.classifier(x)
        return x1, x2, x


def count_correct(model, loader, device, output_index=-1):
    """
    Count correct top-1 predictions of one output head over a data loader.

    Args:
        model: callable that returns a tuple/list of logits tensors for a
            batch of images (e.g. ``AlexNet`` or its DataParallel wrapper).
        loader: iterable of ``(images, labels)`` batches; batches may have
            any size, including a short final batch.
        device: device each batch is moved to before the forward pass.
        output_index (int): which element of the model's output to score;
            the default ``-1`` scores the last one.
    Returns:
        tuple(int, int): ``(correct, total)`` over every sample in ``loader``.

    Runs under ``torch.no_grad()``. Switching the model to eval mode (and
    back) is left to the caller.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)[output_index]
            # argmax of the logits equals argmax of the softmax probabilities
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct, total


if __name__ == '__main__':
    # torch.initial_seed() reports the seed torch picked at startup (nothing
    # is seeded here); print it so the run can be reproduced.
    seed = torch.initial_seed()
    print('Used seed : {}'.format(seed))

    tbwriter = SummaryWriter(log_dir=LOG_DIR)
    print('TensorboardX summary writer created')

    # create model and wrap it in DataParallel over DEVICE_IDS
    alexnet = AlexNet(num_classes=NUM_CLASSES).to(device)
    alexnet = torch.nn.parallel.DataParallel(alexnet, device_ids=DEVICE_IDS)
    print(alexnet)
    print('AlexNet created')

    # Training pipeline: light augmentation, then tensor + normalisation.
    # (CenterCrop is a no-op after RandomResizedCrop(IMAGE_DIM); kept as-is.)
    transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(IMAGE_DIM, scale=(0.9, 1.0), ratio=(0.9, 1.1)),
            transforms.CenterCrop(IMAGE_DIM),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

    dataset = datasets.CIFAR10(root='./data', train=True, transform=transform,
                               download=True)

    # Test pipeline: no augmentation, same normalisation as training.
    testtransform = transforms.Compose([
        transforms.CenterCrop(IMAGE_DIM),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
    testdata = datasets.CIFAR10(root='./data', train=False, transform=testtransform, download=True)
    print('Dataset created')

    dataloader = data.DataLoader(
        dataset,
        shuffle=True,
        pin_memory=True,
        num_workers=8,
        drop_last=True,
        batch_size=BATCH_SIZE)

    # Built once and reused by every validation pass. Order does not matter
    # for accuracy, and the last partial batch is kept so the whole split is
    # scored.
    valdataloader = data.DataLoader(
        testdata,
        shuffle=False,
        pin_memory=True,
        num_workers=8,
        drop_last=False,
        batch_size=BATCH_SIZE)
    print('Dataloader created')

    optimizer = optim.Adam(params=alexnet.parameters(), lr=0.0001)
    print('Optimizer created')

    # multiply LR by 1 / 10 after every 30 epochs
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    print('LR Scheduler created')

    # start training!!
    print('Starting training...')
    alexnet.train()
    total_steps = 1
    end = False
    for epoch in range(NUM_EPOCHS):
        for imgs, classes in dataloader:
            imgs, classes = imgs.to(device), classes.to(device)
            optimizer.zero_grad()
            # Only the main head (last output) enters the loss, so the
            # branch-only layers receive no gradient in this phase.
            output = alexnet(imgs)[-1]
            loss = F.cross_entropy(output, classes)

            # update the parameters
            loss.backward()
            optimizer.step()

            # log training loss / batch accuracy to stdout and TensorBoard
            if total_steps % 10 == 0:
                accuracy = (output.argmax(dim=1) == classes).float().mean().item()
                print('Epoch: {} \tStep: {} \tLoss: {:.4f} \tAcc: {:.4f}'
                    .format(epoch + 1, total_steps, loss.item(), accuracy))
                tbwriter.add_scalar('loss', loss.item(), total_steps)
                tbwriter.add_scalar('accuracy', accuracy, total_steps)

            if total_steps % 300 == 0:
                #~~~~~~~VALIDATION~~~~~~~~~
                alexnet.eval()
                correct_count, total_count = count_correct(alexnet, valdataloader, device)
                alexnet.train()
                val_accuracy = correct_count / total_count
                print("Number Of Images Tested =", total_count)
                print("\nModel Accuracy =", val_accuracy)
                tbwriter.add_scalar('val_accuracy', val_accuracy, total_steps)
                # Early stop once validation accuracy exceeds 83%.
                if val_accuracy > 0.83:
                    end = True
            if end:
                break

            total_steps += 1
        if end:
            break
        lr_scheduler.step()

    tbwriter.flush()



Used seed : 13681243636625788801
TensorboardX summary writer created
DataParallel(
  (module): AlexNet(
    (norm): LocalResponseNorm(3, alpha=5e-05, beta=0.75, k=1.0)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv2): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu): ReLU()
    (branch1): Sequential(
      (0): ReLU()
      (1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): LocalResponseNorm(3, alpha=5e-05, beta=0.75, k=1.0)
      (3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU()
      (5): Conv2d(32, 32, kernel_siz

# Testing

In [0]:
# Score the main (last) classifier on the full CIFAR-10 test split.
testingData = datasets.CIFAR10(root='./data', train=False, transform=testtransform,
                               download=True)
dataloader = data.DataLoader(
    testingData,
    shuffle=False,    # order is irrelevant for accuracy
    pin_memory=True,
    num_workers=8,
    drop_last=False,  # keep the final partial batch so every image is scored
    batch_size=BATCH_SIZE)
correct_count = 0
total_count = 0
alexnet.eval()
with torch.no_grad():  # inference only: no autograd graph
    for batch_idx, (imgs, classes) in enumerate(dataloader, start=1):
        imgs, classes = imgs.to(device), classes.to(device)
        logits = alexnet(imgs)[-1]  # last output = main classifier logits
        # Vectorised top-1 check (argmax of logits == argmax of softmax).
        correct_count += (logits.argmax(dim=1) == classes).sum().item()
        total_count += classes.size(0)
        if batch_idx % 10 == 0:
            print(f"total: {total_count}, correct: {correct_count}")
print("Number Of Images Tested =", total_count)
print("\nModel Accuracy =", (correct_count/total_count))

# Look at the weights

In [0]:
with torch.no_grad():
    # For every named parameter, print and log to TensorBoard (at step
    # total_steps) the mean and histogram of its stored gradient, when one
    # exists, and of its current values.
    print('*' * 10)
    for pname, param in alexnet.named_parameters():
        grad = param.grad
        if grad is not None:
            grad_mean = torch.mean(grad)
            print(f'\t{pname} - grad_avg: {grad_mean}')
            tbwriter.add_scalar(f'grad_avg/{pname}', grad_mean.item(), total_steps)
            tbwriter.add_histogram(f'grad/{pname}', grad.cpu().numpy(), total_steps)

        values = param.data
        if values is not None:
            value_mean = torch.mean(values)
            print(f'\t{pname} - param_avg: {value_mean}')
            tbwriter.add_histogram(f'weight/{pname}', values.cpu().numpy(), total_steps)
            tbwriter.add_scalar(f'weight_avg/{pname}', value_mean.item(), total_steps)

# Testing to confirm branches did NOT learn

In [0]:
# Score each side branch on its own over the CIFAR-10 test split. If only the
# main head has been trained so far, these are expected to be near chance.
testingData = datasets.CIFAR10(root='./data', train=False, transform=testtransform,
                               download=True)
dataloader = data.DataLoader(
    testingData,
    shuffle=False,    # order is irrelevant for accuracy
    pin_memory=True,
    num_workers=8,
    drop_last=False,  # score every test image
    batch_size=BATCH_SIZE)

# One forward pass per batch scores both branches at once.
branch_correct = [0, 0]  # [branch1, branch2]
total_count = 0
alexnet.eval()
with torch.no_grad():  # inference only
    for imgs, classes in dataloader:
        imgs, classes = imgs.to(device), classes.to(device)
        outputs = alexnet(imgs)  # (branch1, branch2, main) logits
        for k in range(2):
            branch_correct[k] += (outputs[k].argmax(dim=1) == classes).sum().item()
        total_count += classes.size(0)

print("\nBranch1 Accuracy =", (branch_correct[0]/total_count))
print("\nBranch2 Accuracy =", (branch_correct[1]/total_count))

# Train now with branches

In [0]:
NUM_EPOCHS = 70

# Training pipeline: light augmentation, then tensor + normalisation.
transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(IMAGE_DIM, scale=(0.9, 1.0), ratio=(0.9, 1.1)),
        transforms.CenterCrop(IMAGE_DIM),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

dataset = datasets.CIFAR10(root='./data', train=True, transform=transform,
                            download=True)

# Test pipeline: no augmentation, same normalisation.
testtransform = transforms.Compose([
    transforms.CenterCrop(IMAGE_DIM),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
testdata = datasets.CIFAR10(root='./data', train=False, transform=testtransform, download=True)
print('Dataset created')

dataloader = data.DataLoader(
    dataset,
    shuffle=True,
    pin_memory=True,
    num_workers=8,
    drop_last=True,
    batch_size=BATCH_SIZE)

# Periodic validation scores a random subset of about MAX_VAL_IMAGES test
# images (shuffled, so a different subset each time) to keep checks cheap.
# The loader is built once and reused.
MAX_VAL_IMAGES = 1500
valdataloader = data.DataLoader(
    testdata,
    shuffle=True,
    pin_memory=True,
    num_workers=8,
    drop_last=False,
    batch_size=128)
print('Dataloader created')

optimizer = optim.Adam(params=alexnet.parameters(), lr=0.0001)
print('Optimizer created')

# multiply LR by 1 / 10 after every 30 epochs
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
print('LR Scheduler created')

# start training!!
print('Starting training...')
alexnet.train()
total_steps = 1
for epoch in range(NUM_EPOCHS):
    for imgs, classes in dataloader:
        imgs, classes = imgs.to(device), classes.to(device)
        optimizer.zero_grad()
        # Joint objective: weighted sum of the three exits' losses
        # (branch 1 weighted 1.0, branch 2 0.6, main 0.3).
        outputb1, outputb2, outputmain = alexnet(imgs)
        lossb1 = F.cross_entropy(outputb1, classes)
        lossb2 = F.cross_entropy(outputb2, classes)
        lossmain = F.cross_entropy(outputmain, classes)
        loss = lossb1 + 0.6*lossb2 + 0.3*lossmain

        # update the parameters
        loss.backward()
        optimizer.step()

        # log progress: joint loss, main-head batch accuracy, per-exit losses
        if total_steps % 10 == 0:
            accuracy = (outputmain.argmax(dim=1) == classes).float().mean().item()
            print('Epoch: {} \tStep: {} \tLoss: {:.4f} \tAcc: {:.4f}'
                .format(epoch + 1, total_steps, loss.item(), accuracy))
            print(f'Branch 1: loss {lossb1.item():.4f}')
            print(f'Branch 2: loss {lossb2.item():.4f}')
            print(f'Main: loss {lossmain.item():.4f}')

        if total_steps % 300 == 0:
            #~~~~~~~VALIDATION~~~~~~~~~
            b1_correct = 0
            b2_correct = 0
            main_correct = 0
            total_count = 0
            alexnet.eval()
            with torch.no_grad():  # inference only
                for image, label in valdataloader:
                    image, label = image.to(device), label.to(device)
                    b1, b2, main = alexnet(image)
                    # Vectorised top-1 checks for all three exits.
                    b1_correct += (b1.argmax(dim=1) == label).sum().item()
                    b2_correct += (b2.argmax(dim=1) == label).sum().item()
                    main_correct += (main.argmax(dim=1) == label).sum().item()
                    total_count += label.size(0)
                    # Stop the whole pass (not just one batch) once enough
                    # images have been scored.
                    if total_count >= MAX_VAL_IMAGES:
                        break
            alexnet.train()
            print("Number Of Images Tested =", total_count)
            print(f"b1 acc: {b1_correct/total_count}, b2 acc: {b2_correct/total_count}, main acc: {main_correct/total_count}")

        total_steps += 1
    lr_scheduler.step()

# Save model

In [0]:
# Mount Google Drive at /content/drive (Colab-only module).
import google.colab.drive as drive

drive.mount('/content/drive')

In [0]:
        # save checkpoints
        # Write the training state (model, optimizer, scheduler and
        # bookkeeping) to Google Drive so the run can be resumed or reloaded.
        # NOTE(review): if alexnet is the DataParallel wrapper, the 'model'
        # state_dict keys carry a 'module.' prefix.
        checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/models"
        # torch.save does not create parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, 'Branchy-alexnet_states.pkl')
        state = {
            'epoch': epoch,
            'total_steps': total_steps,
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'model': alexnet.state_dict(),
            'seed': seed,
        }
        torch.save(state, checkpoint_path)

# Testing all branches 

In [0]:
BATCH_SIZE = 1

# Early-exit evaluation on the CIFAR-10 test split.
# - Every image is scored by all three exits.
# - It then "exits" at branch 1 if that branch's loss is below
#   B1_EXIT_THRESHOLD, otherwise at branch 2 if its loss is below
#   B2_EXIT_THRESHOLD, otherwise at the main classifier.
# - The chosen exit's prediction counts toward the overall accuracy.
#
# NOTE(review): the exit decision uses a loss computed against the
# ground-truth label, so this measures an oracle policy that cannot be
# deployed as-is. A BranchyNet-style exit would threshold a label-free score,
# such as the entropy of the branch's softmax output.
B1_EXIT_THRESHOLD = 0.1
B2_EXIT_THRESHOLD = 0.5

testingData = datasets.CIFAR10(root='./data', train=False, transform=testtransform,
                               download=True)
dataloader = data.DataLoader(
    testingData,
    shuffle=False,
    pin_memory=True,
    num_workers=8,
    drop_last=False,  # score every image even if BATCH_SIZE is raised
    batch_size=BATCH_SIZE)

b1_correct = 0
b2_correct = 0
main_correct = 0
total_count = 0
total_correct = 0
b1exit = 0
b2exit = 0
mainexit = 0
alexnet.eval()
with torch.no_grad():  # inference only
    for image, label in dataloader:
        image, label = image.to(device), label.to(device)
        b1, b2, main = alexnet(image)  # raw logits, one row per image

        # One-hot targets on the labels' device (no hard-coded .cuda()).
        # CIFAR-10 labels are already 0-based, so no index shift is needed.
        target = F.one_hot(label, num_classes=main.size(1)).float()

        # Per-image BCE between softmax probabilities and the one-hot target.
        # binary_cross_entropy needs inputs in [0, 1] and a target of the same
        # shape, so the logits cannot be passed in directly.
        b1loss = F.binary_cross_entropy(F.softmax(b1, dim=1), target, reduction='none').mean(dim=1)
        b2loss = F.binary_cross_entropy(F.softmax(b2, dim=1), target, reduction='none').mean(dim=1)

        # Top-1 hits per exit (argmax of logits == argmax of softmax).
        b1_hit = b1.argmax(dim=1) == label
        b2_hit = b2.argmax(dim=1) == label
        main_hit = main.argmax(dim=1) == label
        b1_correct += b1_hit.sum().item()
        b2_correct += b2_hit.sum().item()
        main_correct += main_hit.sum().item()

        # The first exit whose loss is under its threshold takes the image.
        exit_b1 = b1loss < B1_EXIT_THRESHOLD
        exit_b2 = ~exit_b1 & (b2loss < B2_EXIT_THRESHOLD)
        exit_main = ~(exit_b1 | exit_b2)
        b1exit += exit_b1.sum().item()
        b2exit += exit_b2.sum().item()
        mainexit += exit_main.sum().item()
        total_correct += ((exit_b1 & b1_hit) | (exit_b2 & b2_hit) | (exit_main & main_hit)).sum().item()

        total_count += label.size(0)
print("Number Of Images Tested =", total_count)
print("Overall network accuracy with exits at branches: ", total_correct/total_count)
print(f"b1 acc: {b1_correct/total_count}, b2 acc: {b2_correct/total_count}, main acc: {main_correct/total_count}")
print(f'b1 exit amt: {b1exit}, b2 exit amt: {b2exit}, main exit amt: {mainexit}')