**Import CIFAR-10 dataset**

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

In [2]:
def create_CIFAR_data(root='./data', train_batch_size=128, test_batch_size=100, num_workers=2):
    """Build the CIFAR-10 train/test datasets and their data loaders.

    The training split is augmented with a padded random crop and a random
    horizontal flip; both splits are converted to tensors and normalised with
    the same per-channel mean and standard deviation.

    Args:
        root: Directory the dataset is stored in (downloaded there if missing).
        train_batch_size: Batch size of the shuffled training loader.
        test_batch_size: Batch size of the unshuffled test loader.
        num_workers: Number of loader worker processes used by both loaders.

    Returns:
        ``(trainset, trainloader, testset, testloader)``.
    """
    # Mean / std handed to ``transforms.Normalize`` for both splits.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # No augmentation at test time: only tensor conversion and normalisation.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

    return trainset, trainloader, testset, testloader

**Process CIFAR-10 dataset**

In [3]:
def get_binary_label(targets, index):
    """Turn multi-class targets into one-vs-rest labels for a single class.

    Cats have index 3, dogs have index 5.

    Args:
        targets: Tensor of class indices.
        index: Class index that is mapped to 1; every other class maps to 0.

    Returns:
        Tensor with the dtype of ``targets`` holding 1 where
        ``targets == index`` and 0 elsewhere.
    """
    is_selected_class = targets == index
    return is_selected_class.to(dtype=targets.dtype)

In [4]:
import random

def get_branch_indices(targets, classes):
    """Choose which samples of a batch each branch is trained on.

    Samples whose target is in ``classes`` are given to both branches. The
    remaining ("background") samples are split at random: half of them
    (rounded down) go to branch one only, the rest to branch two only.

    Args:
        targets: Sequence or 1-D tensor of class indices for the batch.
        classes: Collection of class indices that both branches must see.

    Returns:
        ``(branch_one_idx, branch_two_idx)``: two ``torch.long`` tensors of
        ascending positions into ``targets``.
    """
    background = [position for position, target in enumerate(targets)
                  if target not in classes]

    # Sets make the membership tests below O(1); with lists every lookup
    # scanned the whole background and the function was quadratic in the
    # batch size. The random.sample call itself is unchanged, so a seeded
    # run still produces the same split.
    branch_one_bg = set(random.sample(background, len(background) // 2))
    branch_two_bg = set(background) - branch_one_bg

    positions = range(len(targets))
    branch_one_idx = [p for p in positions if p not in branch_two_bg]
    branch_two_idx = [p for p in positions if p not in branch_one_bg]

    # An explicit dtype keeps empty results usable as index tensors
    # (torch.tensor([]) would otherwise default to float32).
    return (torch.tensor(branch_one_idx, dtype=torch.long),
            torch.tensor(branch_two_idx, dtype=torch.long))

In [5]:
import random

def create_unbalanced_CIFAR10(trainset, class_sizes=(625, 625, 625, 5000, 625, 5000, 625, 625, 625, 625)):
    """Randomly subsample a dataset in place so each class has a chosen size.

    The per-class sample counts are printed before and after subsampling.

    Args:
        trainset: Object exposing ``targets`` (sequence of integer labels) and
            ``data`` (array indexable by a list of positions). Both attributes
            are overwritten; ``targets`` becomes a NumPy array.
        class_sizes: ``class_sizes[label]`` is the number of samples to keep
            for class ``label``.

    Returns:
        The same ``trainset`` object, now holding only the sampled items.

    Raises:
        ValueError: If a label present in the data has no entry in
            ``class_sizes``, or more samples are requested for a class than
            the class contains.
    """
    labels = np.array(trainset.targets)
    classes, sizes = np.unique(labels, return_counts=True)
    print(sizes)

    imbalanced_indices = []

    # Iterate over the label values that are actually present. Looping over
    # range(len(classes)) treated the position in `classes` as the label and
    # silently dropped classes whenever the labels were not exactly 0..n-1.
    for label, available in zip(classes.tolist(), sizes.tolist()):
        if not 0 <= label < len(class_sizes):
            raise ValueError(
                'class_sizes has no entry for label {}'.format(label))
        class_size = class_sizes[label]
        if class_size > available:
            raise ValueError(
                'Requested {} samples of class {} but only {} are available'.format(
                    class_size, label, available))
        indices = np.where(labels == label)[0].tolist()
        imbalanced_indices.extend(random.sample(indices, class_size))

    trainset.targets = labels[imbalanced_indices]
    trainset.data = trainset.data[imbalanced_indices]
    classes, sizes = np.unique(trainset.targets, return_counts=True)
    print(sizes)

    return trainset

**Creating ResNet model**

In [6]:
from typing import Any, Callable, List, Optional, Tuple, Type, Union

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

In [7]:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding.

    Implementation is taken from the PyTorch GitHub repository
    https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

    The padding is set equal to the dilation and the layer has no bias term.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution.

    Implementation is taken from the PyTorch GitHub repository
    https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

    The layer has no bias term.
    """
    pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return pointwise

In [8]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions and a skip connection.

    Implementation is taken from the PyTorch GitHub repository
    https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    """

    # Ratio between the block's output channels and ``planes``.
    expansion: int = 1

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()

        # This block has no grouped / widened / dilated variant.
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv-norm-relu, conv-norm, add the skip path, then relu."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # The skip path is the raw input unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)

In [9]:
class Bottleneck(nn.Module):
    """Residual block built from a 1x1, a 3x3 and a 1x1 convolution.

    The block outputs ``planes * expansion`` channels.

    Implementation is taken from the PyTorch GitHub repository
    https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    """

    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    # Ratio between the block's output channels and ``planes``.
    expansion: int = 4

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()

        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Channel count of the inner 3x3 convolution.
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = make_norm(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = make_norm(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Run the three conv stages, add the skip path, then apply relu."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))

        # The skip path is the raw input unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)

In [10]:
class ResNetSplitShared(nn.Module):
    '''ResNet with a shared trunk that splits into two independent branches.

    ``conv1``, ``bn1``, ``layer1`` and ``layer2`` are shared. Each branch then
    owns its own ``layer3``, ``layer4`` and fully connected head, so one
    forward pass yields two sets of logits.

    Branching idea:
    https://stackoverflow.com/questions/66786787/pytorch-multiple-branches-of-a-model

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of each branch head.
        zero_init_residual: If True, zero the weight of the last norm layer of
            every residual block.
        groups: Passed to every block.
        width_per_group: Passed to every block as ``base_width``.
        replace_stride_with_dilation: Three flags (stages 2, 3 and 4) that
            replace the stage's stride with dilation; ``None`` means all False.
        norm_layer: Normalisation layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have length 3.
    '''

    def __init__(self, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], num_classes: int = 10, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super(ResNetSplitShared, self).__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1

        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]

        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        self.groups = groups
        self.base_width = width_per_group

        # Parameter-free layers; avgpool is reused by both branches.
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        ##### SHARED LAYERS #####
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.layer1 = self._make_shared_layer(block, 64, layers[0])
        self.layer2 = self._make_shared_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])

        # State of the trunk at the split point. Both branches must start from
        # it: the trunk emits 128 * block.expansion channels (not always 128),
        # and the dilation built up by one branch must not leak into the other.
        trunk_planes = self.inplanes
        trunk_dilation = self.dilation

        ##### BRANCH 1 LAYERS #####
        self.branch1_inplanes = trunk_planes
        self.branch1_dilation = trunk_dilation
        self.branch1layer3 = self._make_branch1_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.branch1layer4 = self._make_branch1_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.branch1fc = nn.Linear(512 * block.expansion, num_classes)

        ##### BRANCH 2 LAYERS #####
        self.branch2_inplanes = trunk_planes
        self.branch2_dilation = trunk_dilation
        self.branch2layer3 = self._make_branch2_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.branch2layer4 = self._make_branch2_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.branch2fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]


    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int, dilate: bool, inplanes_attr: str, dilation_attr: str) -> nn.Sequential:
        '''Build one stage of ``blocks`` residual blocks.

        ``inplanes_attr`` and ``dilation_attr`` name the instance attributes
        holding the running channel count and dilation of the path being
        built (trunk or one branch). Both are read at the start and updated
        before returning.
        '''
        norm_layer = self._norm_layer
        inplanes = getattr(self, inplanes_attr)
        previous_dilation = getattr(self, dilation_attr)
        dilation = previous_dilation
        downsample = None

        if dilate:
            # Trade the stride for dilation.
            dilation *= stride
            stride = 1

        out_planes = planes * block.expansion

        # The skip path needs a projection whenever the block changes the
        # spatial size or the channel count.
        if stride != 1 or inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        # Only the first block strides / projects; it still uses the dilation
        # that was in effect before this stage.
        stage = [block(inplanes, planes, stride, downsample, self.groups,
                       self.base_width, previous_dilation, norm_layer)]
        for _ in range(1, blocks):
            stage.append(block(out_planes, planes, groups=self.groups,
                               base_width=self.base_width, dilation=dilation,
                               norm_layer=norm_layer))

        setattr(self, inplanes_attr, out_planes)
        setattr(self, dilation_attr, dilation)

        return nn.Sequential(*stage)


    def _make_shared_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        '''Build a trunk stage, tracking ``self.inplanes`` / ``self.dilation``.'''
        return self._make_layer(block, planes, blocks, stride, dilate,
                                'inplanes', 'dilation')


    def _make_branch1_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        '''Build a branch-1 stage, tracking the ``branch1_*`` attributes.'''
        return self._make_layer(block, planes, blocks, stride, dilate,
                                'branch1_inplanes', 'branch1_dilation')


    def _make_branch2_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        '''Build a branch-2 stage, tracking the ``branch2_*`` attributes.'''
        return self._make_layer(block, planes, blocks, stride, dilate,
                                'branch2_inplanes', 'branch2_dilation')


    def get_branch_params(self):
        '''Return optimizer parameter groups for the trunk and each branch.

        The three lists are also stored on the instance as ``shared_params``,
        ``branch1_params`` and ``branch2_params``. Parameters are materialised
        as lists so the stored groups can be iterated more than once.

        Returns:
            ``(shared_params, branch1_params, branch2_params)``, each a list
            of ``{'params': [...]}`` dicts.
        '''
        def as_groups(*modules):
            return [{'params': list(module.parameters())} for module in modules]

        self.shared_params = as_groups(self.conv1, self.bn1, self.layer1, self.layer2)
        self.branch1_params = as_groups(self.branch1layer3, self.branch1layer4, self.branch1fc)
        self.branch2_params = as_groups(self.branch2layer3, self.branch2layer4, self.branch2fc)

        return self.shared_params, self.branch1_params, self.branch2_params


    def _forward_shared_branch(self, x:Tensor) -> Tensor:
        '''Run the shared trunk: stem convolution, max-pool, layer1, layer2.'''
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)

        return out


    def _forward_branch_1(self, shared_out: Tensor) -> Tensor:
        '''Map trunk features to branch-1 logits.'''
        branch1_out = self.branch1layer3(shared_out)
        branch1_out = self.branch1layer4(branch1_out)
        branch1_out = self.avgpool(branch1_out)
        branch1_out = torch.flatten(branch1_out, 1)
        branch1_out = self.branch1fc(branch1_out)

        return branch1_out


    def _forward_branch_2(self, shared_out: Tensor) -> Tensor:
        '''Map trunk features to branch-2 logits.'''
        branch2_out = self.branch2layer3(shared_out)
        branch2_out = self.branch2layer4(branch2_out)
        branch2_out = self.avgpool(branch2_out)
        branch2_out = torch.flatten(branch2_out, 1)
        branch2_out = self.branch2fc(branch2_out)

        return branch2_out


    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        '''Return ``(branch_one_logits, branch_two_logits)`` for the batch ``x``.

        The trunk is evaluated once and its output is fed to both branches.
        '''
        shared = self._forward_shared_branch(x)
        branch_one_out = self._forward_branch_1(shared)
        branch_two_out = self._forward_branch_2(shared)

        return branch_one_out, branch_two_out


In [11]:
def ResNetSplit18Shared():
    """Build a two-branch ``ResNetSplitShared`` from ``BasicBlock``s.

    Every stage has two blocks and each branch head outputs two classes.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNetSplitShared(block=BasicBlock, layers=blocks_per_stage, num_classes=2)

**Congestion Avoidance scheduler**


*   Based on accumulated gradients (multiplied by the lr in each epoch)
*   Because it is based on `lr * grad`, it is only truly suitable for standard SGD
* Gradients are reset after a congestion on one branch (for that branch only) -- This treats the new position as a new start point to accumulate gradients from



In [12]:
def congestion_avoid(model, optimizer, branch1_metric, branch2_metric, condition, branch_one_grads, branch_two_grads, min_epochs, mult):
    """Roll one branch back when the other branch's metric falls too far behind.

    If ``branch1_metric < condition * branch2_metric`` and branch two has run
    for at least ``min_epochs`` epochs, ``mult`` times the accumulated
    ``branch_two_grads`` is added back onto the matching model parameters,
    the same share is removed from the accumulator and the global
    ``epoch_count_two`` is reset. Otherwise the mirrored check is made for
    branch two. At most one branch is rolled back per call.

    Reads and writes the module-level ``epoch_count_one`` / ``epoch_count_two``,
    which must exist before the call.

    Returns:
        ``(optimizer, model, boolean_one, boolean_two, branch_one_grads,
        branch_two_grads)`` where ``boolean_one`` / ``boolean_two`` tell which
        condition (if any) fired.
    """
    global epoch_count_one
    global epoch_count_two

    def _give_back(accumulated):
        # Undo `mult` of the accumulated updates on every parameter that has
        # an entry in the accumulator ...
        with torch.no_grad():
            for param_name, param in model.named_parameters():
                if param_name in accumulated:
                    param += mult * accumulated[param_name]
        # ... and keep only the part that was not undone.
        for param_name in accumulated:
            accumulated[param_name] -= mult * accumulated[param_name]

    # Both tests are evaluated up front, before any counter is reset.
    branch1_cond = (branch1_metric < condition * branch2_metric) and (epoch_count_two >= min_epochs)
    branch2_cond = (branch2_metric < condition * branch1_metric) and (epoch_count_one >= min_epochs)

    boolean_one = boolean_two = False

    if branch1_cond:
        boolean_one = True
        print('Branch 1 condition has been met ..... {:.2f}%'.format(100.*condition))
        _give_back(branch_two_grads)
        epoch_count_two = 0
    elif branch2_cond:
        boolean_two = True
        print('Branch 2 condition has been met ..... {:.2f}%'.format(100.*condition))
        _give_back(branch_one_grads)
        epoch_count_one = 0
    else:
        print('No condition is met ..... {:.2f}%'.format(100.*condition))

    return optimizer, model, boolean_one, boolean_two, branch_one_grads, branch_two_grads

**Training the ResNet model**


*   Accumulate the gradient * lr
*   Reset the accumulated gradients to zero on a branch if in the previous epoch we had to roll it back



In [13]:
def _binary_confusion_counts(targets, predictions):
    """Return ``(TP, FP, FN)`` counts for one batch of 0/1 targets.

    A target of 1 is a positive and a target of 0 a negative; any other target
    value is not counted. For a negative, every prediction other than 0 is a
    false positive; for a positive, every prediction other than 1 is a false
    negative.
    """
    is_negative = targets == 0
    is_positive = targets == 1
    true_pos = (is_positive & (predictions == 1)).sum().item()
    false_pos = (is_negative & (predictions != 0)).sum().item()
    false_neg = (is_positive & (predictions != 1)).sum().item()
    return true_pos, false_pos, false_neg


def _precision_recall_f(true_pos, false_pos, false_neg):
    """Return ``(precision, recall, F)`` in percent; 0 when a ratio is undefined."""
    predicted_pos = true_pos + false_pos
    actual_pos = true_pos + false_neg
    precision = 100.*true_pos/predicted_pos if predicted_pos > 0 else 0
    recall = 100.*true_pos/actual_pos if actual_pos > 0 else 0
    if precision + recall > 0:
        f_score = 2 * precision * recall / (precision + recall)
    else:
        f_score = 0
    return precision, recall, f_score


def train_congestion_avoider(trainloader, device, model, optimizer, branch_one_criterion, branch_two_criterion, branch_one_class, branch_two_class, boolean_one, boolean_two, branch_one_grads, branch_two_grads):
    """Train the two-branch model for one epoch and accumulate per-branch ``lr * grad``.

    For every batch each branch gets its own subset of the samples (see
    ``get_branch_indices``) with one-vs-rest labels for its class. The two
    losses are back-propagated separately to record how much each branch
    contributes to every parameter's gradient, then the summed loss is
    back-propagated and the optimizer takes a step.

    Reads and writes the module-level ``epoch_count_one`` / ``epoch_count_two``,
    which must exist before the call; both are incremented once per call.

    Args:
        trainloader: Sized iterable yielding ``(inputs, targets)`` batches.
        device: Device the per-branch inputs and targets are moved to.
        model: Module returning ``(branch_one_outputs, branch_two_outputs)``.
        optimizer: Optimizer stepping the model; the learning rate is read
            from its first parameter group.
        branch_one_criterion: Loss function of branch one.
        branch_two_criterion: Loss function of branch two.
        branch_one_class: Class index that branch one treats as positive.
        branch_two_class: Class index that branch two treats as positive.
        boolean_one: If truthy, ``epoch_count_two`` is reset to 0 first.
        boolean_two: If truthy, ``epoch_count_one`` is reset to 0 first.
        branch_one_grads: Dict ``parameter name -> accumulated lr * grad`` of
            the branch-one loss; updated in place.
        branch_two_grads: Same for the branch-two loss; updated in place.

    Returns:
        ``(branch_one_acc, branch_two_acc, branch_one_precision,
        branch_two_precision, branch_one_recall, branch_two_recall,
        branch_one_F, branch_two_F, branch_one_grads, branch_two_grads)``;
        all metrics are percentages.
    """
    global epoch_count_one
    global epoch_count_two

    # Imported locally so the function does not depend on an import made in
    # some other notebook cell.
    import time

    model.train()
    branch_one_train_loss = 0
    branch_two_train_loss = 0
    branch_one_correct = 0
    branch_two_correct = 0
    branch_one_total = 0
    branch_two_total = 0
    branch_one_TP = 0
    branch_one_FP = 0
    branch_one_FN = 0
    branch_two_TP = 0
    branch_two_FP = 0
    branch_two_FN = 0
    num_batches = 0
    start_time = time.time()

    # A roll-back in the previous epoch restarts the epoch counter of the
    # branch that was rolled back. The accumulated gradients are deliberately
    # left as a rolling sum here.
    # TODO: decide whether the accumulators should be cleared at this point.
    if boolean_two:
        epoch_count_one = 0
    if boolean_one:
        epoch_count_two = 0

    for inputs, targets in trainloader:
        num_batches += 1

        branch_one_targets = get_binary_label(targets, index=branch_one_class)
        branch_two_targets = get_binary_label(targets, index=branch_two_class)

        branch_one_idx, branch_two_idx = get_branch_indices(targets, classes=[branch_one_class, branch_two_class])
        branch_one_inputs = torch.index_select(inputs, 0, branch_one_idx)
        branch_one_targets = torch.index_select(branch_one_targets, 0, branch_one_idx)
        branch_two_inputs = torch.index_select(inputs, 0, branch_two_idx)
        branch_two_targets = torch.index_select(branch_two_targets, 0, branch_two_idx)

        branch_one_inputs, branch_two_inputs = branch_one_inputs.to(device), branch_two_inputs.to(device)
        branch_one_targets, branch_two_targets = branch_one_targets.to(device), branch_two_targets.to(device)
        optimizer.zero_grad()

        branch_one_outputs, _ = model(branch_one_inputs)
        _, branch_two_outputs = model(branch_two_inputs)

        branch_one_loss = branch_one_criterion(branch_one_outputs, branch_one_targets)
        branch_two_loss = branch_two_criterion(branch_two_outputs, branch_two_targets)

        learning_rate = optimizer.param_groups[0]['lr']

        # Pass 1: gradients of the branch-one loss only. The graph is kept
        # because both losses are back-propagated again below.
        branch_one_loss.backward(retain_graph=True)
        branch_one_step = {}
        with torch.no_grad():
            for name, parameter in model.named_parameters():
                # Parameters this loss does not reach may have no gradient.
                if parameter.grad is None:
                    continue
                scaled_grad = parameter.grad * learning_rate
                branch_one_step[name] = scaled_grad
                if name in branch_one_grads:
                    branch_one_grads[name] += scaled_grad
                else:
                    branch_one_grads[name] = scaled_grad.clone()

        # Pass 2: .grad now holds the sum of both losses' gradients, so the
        # branch-two share is the total minus what pass 1 recorded.
        branch_two_loss.backward(retain_graph=True)
        with torch.no_grad():
            for name, parameter in model.named_parameters():
                if parameter.grad is None:
                    continue
                scaled_grad = parameter.grad * learning_rate
                if name in branch_one_step:
                    scaled_grad = scaled_grad - branch_one_step[name]
                if name in branch_two_grads:
                    branch_two_grads[name] += scaled_grad
                else:
                    branch_two_grads[name] = scaled_grad
        optimizer.zero_grad()

        # Actual update: one step on the summed loss.
        total_loss = branch_one_loss + branch_two_loss
        total_loss.backward()
        optimizer.step()

        branch_one_train_loss += branch_one_loss.item()
        branch_two_train_loss += branch_two_loss.item()
        _, branch_one_predicted = branch_one_outputs.max(1)
        _, branch_two_predicted = branch_two_outputs.max(1)
        branch_one_total += branch_one_targets.size(0)
        branch_two_total += branch_two_targets.size(0)
        branch_one_correct += branch_one_predicted.eq(branch_one_targets).sum().item()
        branch_two_correct += branch_two_predicted.eq(branch_two_targets).sum().item()

        true_pos, false_pos, false_neg = _binary_confusion_counts(branch_one_targets, branch_one_predicted)
        branch_one_TP += true_pos
        branch_one_FP += false_pos
        branch_one_FN += false_neg

        true_pos, false_pos, false_neg = _binary_confusion_counts(branch_two_targets, branch_two_predicted)
        branch_two_TP += true_pos
        branch_two_FP += false_pos
        branch_two_FN += false_neg

    epoch_count_one += 1
    epoch_count_two += 1

    # Guard the averages so an empty loader reports zeros instead of failing.
    loss_batches = max(num_batches, 1)
    branch_one_acc = 100.*branch_one_correct/branch_one_total if branch_one_total > 0 else 0
    branch_two_acc = 100.*branch_two_correct/branch_two_total if branch_two_total > 0 else 0

    branch_one_precision, branch_one_recall, branch_one_F = _precision_recall_f(branch_one_TP, branch_one_FP, branch_one_FN)
    branch_two_precision, branch_two_recall, branch_two_F = _precision_recall_f(branch_two_TP, branch_two_FP, branch_two_FN)

    print("total train iters ", len(trainloader), '| time: %.3f sec Cat Loss: %.3f | Cat Acc: %.3f%% (%d/%d) | Dog Loss: %.3f | Dog Acc: %.3f%% (%d/%d)'
        % ((time.time()-start_time), branch_one_train_loss/loss_batches,
           branch_one_acc, branch_one_correct, branch_one_total,
           branch_two_train_loss/loss_batches, branch_two_acc,
           branch_two_correct, branch_two_total))
    print('Cat P: : %.3f%% (%d/%d) | Dog P: %.3f%% (%d/%d)'% (branch_one_precision, branch_one_TP, branch_one_TP + branch_one_FP, branch_two_precision, branch_two_TP, branch_two_TP + branch_two_FP))
    print('Cat R: : %.3f%% (%d/%d) | Dog R: %.3f%% (%d/%d)'% (branch_one_recall, branch_one_TP, branch_one_TP + branch_one_FN, branch_two_recall, branch_two_TP, branch_two_TP + branch_two_FN))
    print('Cat F: : %.3f%%         | Dog F: %.3f%%'% (branch_one_F, branch_two_F))

    return branch_one_acc, branch_two_acc, branch_one_precision, branch_two_precision, branch_one_recall, branch_two_recall, branch_one_F, branch_two_F, branch_one_grads, branch_two_grads

In [None]:
def train_congestion_avoider_debug(trainloader, device, model, optimizer, branch_one_criterion, branch_two_criterion, branch_one_class, branch_two_class, boolean_one, boolean_two, num_batches=10):
    '''
        Debug variant of the training step: trains on only a handful of batches while
        recording, per parameter, the learning-rate-scaled gradient contributed by each
        branch loss and by the summed loss.

        trainloader = Iterable of (inputs, targets) batches; only the first `num_batches` are used
        device = Device the inputs and binary targets are moved to
        model = The model to be trained; called as model(inputs) and must return (branch_one_outputs, branch_two_outputs)
        optimizer = Optimizer that is stepped once per batch; its first param group's 'lr' scales the recorded gradients
        branch_x_criterion = The criterion used to define the loss of branch x
        branch_x_class = Class index that is mapped to label 1 for branch x (every other class becomes 0)
        boolean_one / boolean_two = When True, reset the *other* branch's global gradient sum and epoch counter
                                    (boolean_two resets branch one, boolean_one resets branch two)
        num_batches = Number of batches to pull from trainloader (default 10, the previously hard-coded value)

        Reads and updates the globals branch_one_grads, branch_two_grads, epoch_count_one and epoch_count_two.

        Returns (branch_one_acc, branch_two_acc, branch_one_grads, branch_two_grads, total_grads).
        The accuracies (and the printed losses) describe the LAST processed batch only.

        Raises ValueError if trainloader yields no batches.
    '''
    global branch_one_grads
    global branch_two_grads
    global epoch_count_one
    global epoch_count_two

    import itertools

    model.train()
    branch_two_grads_tmp = {}
    total_grads = {}
    start_time = time.time()

    # NOTE(review): the flags are crossed on purpose in the original code (boolean_two
    # clears branch ONE's sums and vice versa) -- confirm against congestion_avoid.
    if boolean_two:
        # Open question from the original author: reset here, or keep a rolling sum?
        branch_one_grads = {}
        epoch_count_one = 0
    if boolean_one:
        # Open question from the original author: reset here, or keep a rolling sum?
        branch_two_grads = {}
        epoch_count_two = 0

    # Pull the debug batches from ONE pass over the loader. Calling next(iter(loader))
    # repeatedly would build a brand-new iterator (and worker pool) for every batch.
    batches = list(itertools.islice(trainloader, num_batches))
    if not batches:
        raise ValueError('trainloader produced no batches')

    for index, (batch_inputs, batch_targets) in enumerate(batches):
        print('\nIMAGE ', index+1)
        # Use THIS iteration's batch (the previous code always reused the first batch).
        branch_one_targets = get_binary_label(batch_targets, index=branch_one_class)
        branch_two_targets = get_binary_label(batch_targets, index=branch_two_class)
        batch_inputs, branch_one_targets, branch_two_targets = batch_inputs.to(device), branch_one_targets.to(device), branch_two_targets.to(device)
        optimizer.zero_grad()
        branch_one_outputs, branch_two_outputs = model(batch_inputs)
        branch_one_loss = branch_one_criterion(branch_one_outputs, branch_one_targets)
        branch_two_loss = branch_two_criterion(branch_two_outputs, branch_two_targets)
        lr = optimizer.param_groups[0]['lr']

        # Back-propagate the branch-one loss only; .grad now holds branch one's gradients.
        branch_one_loss.backward(retain_graph=True)
        with torch.no_grad():
            for name, parameter in model.named_parameters():
                if parameter.grad is None:
                    # Parameter received no gradient from this loss: nothing to record.
                    continue
                scaled_grad = torch.mul(parameter.grad, lr)
                # Remember branch one's share so it can be subtracted after the second backward.
                branch_two_grads_tmp[name] = scaled_grad
                if name not in branch_one_grads:
                    # clone() so the later in-place += never aliases the temporary above
                    branch_one_grads[name] = scaled_grad.clone()
                else:
                    branch_one_grads[name] += scaled_grad

        # Gradients accumulate, so after this call .grad = branch one + branch two.
        # retain_graph is still required because the summed loss is back-propagated below.
        branch_two_loss.backward(retain_graph=True)
        with torch.no_grad():
            for name, parameter in model.named_parameters():
                if parameter.grad is None:
                    continue
                scaled_grad = torch.mul(parameter.grad, lr)
                if name in branch_two_grads_tmp:
                    # Remove branch one's contribution to isolate branch two's gradient.
                    scaled_grad = scaled_grad - branch_two_grads_tmp[name]
                if name not in branch_two_grads:
                    branch_two_grads[name] = scaled_grad
                else:
                    branch_two_grads[name] += scaled_grad
        optimizer.zero_grad()

        # The actual weight update uses the gradient of the summed loss.
        total_loss = branch_one_loss + branch_two_loss
        total_loss.backward()
        with torch.no_grad():
            for name, parameter in model.named_parameters():
                if parameter.grad is None:
                    continue
                scaled_grad = torch.mul(parameter.grad, lr)
                if name not in total_grads:
                    total_grads[name] = scaled_grad
                else:
                    total_grads[name] += scaled_grad
        optimizer.step()

    # Statistics are taken from the last batch only (as in the original debug code).
    branch_one_train_loss = branch_one_loss.item()
    branch_two_train_loss = branch_two_loss.item()
    _, branch_one_predicted = branch_one_outputs.max(1)
    _, branch_two_predicted = branch_two_outputs.max(1)
    branch_one_total = branch_one_targets.size(0)
    branch_two_total = branch_two_targets.size(0)
    branch_one_correct = branch_one_predicted.eq(branch_one_targets).sum().item()
    branch_two_correct = branch_two_predicted.eq(branch_two_targets).sum().item()

    epoch_count_one += 1
    epoch_count_two += 1

    branch_one_acc = 100.*branch_one_correct/branch_one_total
    branch_two_acc = 100.*branch_two_correct/branch_two_total

    print("total train iters ", len(trainloader), '| time: %.3f sec Cat Loss: %.3f | Cat Acc: %.3f%% (%d/%d) | Dog Loss: %.3f | Dog Acc: %.3f%% (%d/%d)'
        % ((time.time()-start_time), branch_one_train_loss, 
           branch_one_acc, branch_one_correct, branch_one_total, 
           branch_two_train_loss, branch_two_acc, 
           branch_two_correct, branch_two_total))

    return branch_one_acc, branch_two_acc, branch_one_grads, branch_two_grads, total_grads

**Testing the ResNet model**


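*   Congestion condition is intended to be based on F-Score instead of accuracy (note: `test_congestion_avoider` below currently passes each branch's precision to `congestion_avoid`)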
*   Dataset has 1000 cats and 5000 dogs



In [14]:
def linear_cong_condition(min_cond, max_cond, epoch, max_epochs):
    '''Linearly interpolate the congestion condition between min_cond and max_cond.

        min_cond = Value returned when epoch is 0
        max_cond = Value reached when epoch equals max_epochs
        epoch = The current epoch in training
        max_epochs = Number of epochs over which the condition is ramped (must be non-zero)
    '''
    fraction_done = epoch / max_epochs
    spread = max_cond - min_cond
    return min_cond + spread * fraction_done

In [22]:
def _precision_recall_f(true_pos, false_pos, false_neg):
    '''Return (precision, recall, F-score) in percent; any ratio with a zero denominator is reported as 0.'''
    predicted_pos = true_pos + false_pos
    actual_pos = true_pos + false_neg
    precision = 100.*true_pos/predicted_pos if predicted_pos > 0 else 0
    recall = 100.*true_pos/actual_pos if actual_pos > 0 else 0
    # Precision and recall are never negative, so the sum is zero only when both are zero.
    if precision + recall > 0:
        f_score = 2 * precision * recall / (precision + recall)
    else:
        f_score = 0
    return precision, recall, f_score


def _evaluate_two_branch_model(testloader, device, model, branch_one_class, branch_two_class, branch_one_criterion, branch_two_criterion):
    '''Run one gradient-free pass over testloader in eval mode.

    Returns (branch_one_stats, branch_two_stats, num_batches). Each stats dict holds the
    summed batch loss ('loss'), the accuracy counts ('correct', 'total') and the
    confusion counts 'TP', 'FP' and 'FN' for that branch.
    '''
    classes = (branch_one_class, branch_two_class)
    criteria = (branch_one_criterion, branch_two_criterion)
    stats = [dict(loss=0, correct=0, total=0, TP=0, FP=0, FN=0) for _ in classes]
    num_batches = 0

    model.eval()
    with torch.no_grad():
        for inputs, targets in testloader:
            num_batches += 1
            labels_per_branch = [get_binary_label(targets, index=class_index).to(device) for class_index in classes]
            inputs = inputs.to(device)
            branch_one_outputs, branch_two_outputs = model(inputs)

            for branch, labels, outputs, criterion in zip(stats, labels_per_branch, (branch_one_outputs, branch_two_outputs), criteria):
                branch['loss'] += criterion(outputs, labels).item()
                _, predicted = outputs.max(1)
                branch['total'] += labels.size(0)
                branch['correct'] += predicted.eq(labels).sum().item()

                # Vectorised confusion counts (same rules as the former per-sample loop):
                # label 1 -> TP if predicted 1 else FN; label 0 -> FP if predicted non-zero.
                is_positive = labels == 1
                is_negative = labels == 0
                branch['TP'] += (is_positive & (predicted == 1)).sum().item()
                branch['FN'] += (is_positive & (predicted != 1)).sum().item()
                branch['FP'] += (is_negative & (predicted != 0)).sum().item()

    return stats[0], stats[1], num_batches


def test_congestion_avoider(start_time, testloader, device, model, optimizer, scheduler, branch_one_grads, branch_two_grads, branch_one_class, branch_two_class, branch_one_criterion, branch_two_criterion, epoch, max_epochs, min_cond, max_cond, min_epochs, mult):
    '''Same as original with additional function to increase the congestion condition linearly over the epochs

        Evaluates both branches on testloader, calls congestion_avoid with the linearly
        interpolated condition, steps the scheduler, prints the metrics, and finally
        re-evaluates the model to print loss/accuracy after congestion_avoid has run.

        start_time = Reference time.time() value used for the elapsed-time printout
        branch_x_grads = Accumulated gradients passed to (and returned by) congestion_avoid
        branch_x_class = Class index that is mapped to label 1 for branch x
        epoch, max_epochs, min_cond, max_cond = Inputs of linear_cong_condition
        min_epochs, mult = Forwarded unchanged to congestion_avoid

        Returns (optimizer, branch_one_val_acc, branch_two_val_acc, branch_one_precision,
        branch_two_precision, branch_one_recall, branch_two_recall, branch_one_F,
        branch_two_F, boolean_one, boolean_two, branch_one_grads, branch_two_grads).
        All returned metrics come from the FIRST evaluation pass (before congestion_avoid).

        Raises ValueError if testloader yields no batches.
    '''

    def print_loss_and_accuracy(branch_one, branch_two, num_batches):
        print("total test iters ", len(testloader), '| time: %.3f sec Cat Loss: %.3f | Cat Acc: %.3f%% (%d/%d) | Dog Loss: %.3f | Dog Acc: %.3f%% (%d/%d)'
        % ((time.time()-start_time), branch_one['loss']/num_batches, 
           100.*branch_one['correct']/branch_one['total'], branch_one['correct'], branch_one['total'], 
           branch_two['loss']/num_batches, 100.*branch_two['correct']/branch_two['total'], 
           branch_two['correct'], branch_two['total']))

    branch_one, branch_two, num_batches = _evaluate_two_branch_model(testloader, device, model, branch_one_class, branch_two_class, branch_one_criterion, branch_two_criterion)
    if num_batches == 0:
        raise ValueError('testloader produced no batches')

    branch_one_val_acc = 100.*branch_one['correct']/branch_one['total']
    branch_two_val_acc = 100.*branch_two['correct']/branch_two['total']
    branch_one_precision, branch_one_recall, branch_one_F = _precision_recall_f(branch_one['TP'], branch_one['FP'], branch_one['FN'])
    branch_two_precision, branch_two_recall, branch_two_F = _precision_recall_f(branch_two['TP'], branch_two['FP'], branch_two['FN'])

    condition = linear_cong_condition(min_cond, max_cond, epoch, max_epochs)

    # Kept inside no_grad exactly as before, so congestion_avoid runs with gradient tracking disabled.
    # NOTE(review): the branch PRECISIONS are what is compared against the condition here,
    # although the F-scores are also available -- confirm which metric is intended.
    with torch.no_grad():
        optimizer, model, boolean_one, boolean_two, branch_one_grads, branch_two_grads = congestion_avoid(model, optimizer, branch_one_precision, branch_two_precision, condition, branch_one_grads, branch_two_grads, min_epochs, mult)
        scheduler.step()

    print_loss_and_accuracy(branch_one, branch_two, num_batches)
    print('Cat P: : %.3f%% (%d/%d) | Dog P: %.3f%% (%d/%d)'%(branch_one_precision, branch_one['TP'], branch_one['TP'] + branch_one['FP'], branch_two_precision, branch_two['TP'], branch_two['TP'] + branch_two['FP']))
    print('Cat R: : %.3f%% (%d/%d) | Dog R: %.3f%% (%d/%d)'%(branch_one_recall, branch_one['TP'], branch_one['TP'] + branch_one['FN'], branch_two_recall, branch_two['TP'], branch_two['TP'] + branch_two['FN']))
    # Label fixed: this line reports the F-score of branch two (it used to read "Dog R").
    print('Cat F: : %.3f%%         | Dog F: %.3f%%'%(branch_one_F, branch_two_F))

    # RE-EVALUATE THE MODEL ON THE TEST SET AFTER THE WEIGHTS HAVE BEEN UPDATED
    # (uses the model object returned by congestion_avoid; these numbers are only printed)
    branch_one_after, branch_two_after, num_batches_after = _evaluate_two_branch_model(testloader, device, model, branch_one_class, branch_two_class, branch_one_criterion, branch_two_criterion)
    print_loss_and_accuracy(branch_one_after, branch_two_after, num_batches_after)

    return optimizer, branch_one_val_acc, branch_two_val_acc, branch_one_precision, branch_two_precision, branch_one_recall, branch_two_recall, branch_one_F, branch_two_F, boolean_one, boolean_two, branch_one_grads, branch_two_grads

**Producing the results of the training process**

*   The learning-rate schedule is PyTorch's `CyclicLR` in `triangular2` mode, i.e. a cyclic schedule whose peaks decay from one cycle to the next



In [16]:
import time
import torch.optim as optim
from torch.optim.lr_scheduler import CyclicLR
import torch.backends.cudnn as cudnn

In [17]:
def get_cong_avoidance_results(branch_one_class=5, branch_two_class=9, class_sizes_train = [625,625,625,625,625,1000,625,625,625,5000], class_sizes_test = [125,125,125,125,125,1000,125,125,125,1000], epochs=100, min_cond=0.95, max_cond = 0.99, mult_factor=1, lr=0.1, min_epochs = 5):
    '''Train and evaluate the two-branch model for `epochs` epochs, allowing the congestion condition to change linearly over time.

        branch_x_class = Class index learned by branch x
        class_sizes_train / class_sizes_test = Per-class sample counts passed to create_unbalanced_CIFAR10
        epochs = Number of train + test rounds; also the max_epochs of the linear condition
        min_cond, max_cond = End points of the linearly changing congestion condition
        mult_factor, min_epochs = Forwarded to test_congestion_avoider (as mult and min_epochs)
        lr = SGD learning rate and the max_lr of the cyclic schedule

        Returns 18 lists with one entry per epoch, in this order:
        train accuracy, precision, recall, F (branch one then branch two for each metric),
        the same eight lists for the test set, and finally the two per-branch
        congestion booleans returned by test_congestion_avoider.
    '''

    # Accumulated gradients, handed back and forth between the train and test steps.
    grads_one, grads_two = {}, {}

    # DATA: load CIFAR-10, then subsample each split to the requested class sizes.
    trainset, trainloader, testset, testloader = create_CIFAR_data()
    trainset = create_unbalanced_CIFAR10(trainset, class_sizes = class_sizes_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
    testset = create_unbalanced_CIFAR10(testset, class_sizes = class_sizes_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True, num_workers=2)

    # MODEL
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = ResNetSplit18Shared().to(device)
    if device == 'cuda':
        print('CUDA device used...')
        model = torch.nn.DataParallel(model)
        cudnn.benchmark = True

    # LOSSES: each class is weighted by the share of samples that belong to the OTHER class.
    def class_weights(sample_counts):
        total = sum(sample_counts)
        return torch.tensor([(total - count)/total for count in sample_counts]).to(device)

    # Label 0 of a branch = the other branch's class plus 5000 / 2 further samples
    # (presumably half of the remaining background images -- TODO confirm);
    # label 1 = the branch's own class.
    branch_one_criterion = nn.CrossEntropyLoss(
        weight=class_weights([class_sizes_train[branch_two_class] + 5000 / 2, class_sizes_train[branch_one_class]]))
    branch_two_criterion = nn.CrossEntropyLoss(
        weight=class_weights([class_sizes_train[branch_one_class] + 5000 / 2, class_sizes_train[branch_two_class]]))

    # OPTIMIZER AND LR SCHEDULE
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0, weight_decay=5e-4)
    scheduler = CyclicLR(optimizer, base_lr=0.0001, max_lr=lr, step_size_up=10, mode="triangular2")

    # BEGIN RECORDING THE TIME
    start_time = time.time()

    # Eight series per split, in return order: acc1, acc2, P1, P2, R1, R2, F1, F2.
    train_history = [[] for _ in range(8)]
    test_history = [[] for _ in range(8)]
    branch_one_condition, branch_two_condition = [], []

    boolean_one = boolean_two = False

    for epoch in range(epochs):
        print('\n********** EPOCH {} **********'.format(epoch + 1))
        print('Learning rate: ', optimizer.param_groups[0]['lr'])

        (train_acc_one, train_acc_two, train_p_one, train_p_two,
         train_r_one, train_r_two, train_f_one, train_f_two,
         grads_one, grads_two) = train_congestion_avoider(
            trainloader, device, model, optimizer, branch_one_criterion, branch_two_criterion,
            branch_one_class, branch_two_class, boolean_one, boolean_two, grads_one, grads_two)
        train_values = (train_acc_one, train_acc_two, train_p_one, train_p_two,
                        train_r_one, train_r_two, train_f_one, train_f_two)
        for series, value in zip(train_history, train_values):
            series.append(value)

        (optimizer, test_acc_one, test_acc_two, test_p_one, test_p_two,
         test_r_one, test_r_two, test_f_one, test_f_two,
         boolean_one, boolean_two, grads_one, grads_two) = test_congestion_avoider(
            start_time, testloader, device, model, optimizer, scheduler, grads_one, grads_two,
            branch_one_class, branch_two_class, branch_one_criterion, branch_two_criterion,
            epoch, epochs, min_cond, max_cond, min_epochs, mult_factor)
        test_values = (test_acc_one, test_acc_two, test_p_one, test_p_two,
                       test_r_one, test_r_two, test_f_one, test_f_two)
        for series, value in zip(test_history, test_values):
            series.append(value)

        branch_one_condition.append(boolean_one)
        branch_two_condition.append(boolean_two)

    return (*train_history, *test_history, branch_one_condition, branch_two_condition)

**1000 / 5000 dataset (Dogs and Trucks) -- Baseline**


*   Baseline result using dogs and trucks
*   The training set has 1000 images of dogs and 5000 images of trucks; the test set has 1000 images of each



In [18]:
# Baseline run: both condition end points, the multiplier and min_epochs are set to 0.
start_time = time.time()
branch_one_grads, branch_two_grads = {}, {}
epoch_count_one = epoch_count_two = 0

# Per-class sample counts, indexed by CIFAR-10 class id: class 5 keeps 1000 training
# images and class 9 keeps 5000; both keep 1000 test images.
class_sizes_train = [625, 625, 625, 625, 625, 1000, 625, 625, 625, 5000]
class_sizes_test = [125, 125, 125, 125, 125, 1000, 125, 125, 125, 1000]

(b1_train_A, b2_train_A, b1_train_P, b2_train_P,
 b1_train_R, b2_train_R, b1_train_F, b2_train_F,
 b1_test_A, b2_test_A, b1_test_P, b2_test_P,
 b1_test_R, b2_test_R, b1_test_F, b2_test_F,
 b1_condition, b2_condition) = get_cong_avoidance_results(
    branch_one_class=5,
    branch_two_class=9,
    class_sizes_train=class_sizes_train,
    class_sizes_test=class_sizes_test,
    epochs=100,
    min_cond=0,
    max_cond=0,
    mult_factor=0,
    lr=0.1,
    min_epochs=0,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
[5000 5000 5000 5000 5000 5000 5000 5000 5000 5000]
[ 625  625  625  625  625 1000  625  625  625 5000]
[1000 1000 1000 1000 1000 1000 1000 1000 1000 1000]
[ 125  125  125  125  125 1000  125  125  125 1000]
CUDA device used...

********** EPOCH 1 **********
Learning rate:  0.0001


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


total train iters  86 | time: 14.991 sec Cat Loss: 0.674 | Cat Acc: 58.837% (4990/8481) | Dog Loss: 0.673 | Dog Acc: 60.207% (5129/8519)
Cat P: : 16.600% (619/3729) | Dog P: 70.694% (2750/3890)
Cat R: : 61.900% (619/1000) | Dog R: 55.000% (2750/5000)
Cat F: : 26.179%         | Dog F: 61.867%
No condition is met ..... 0.00%
total test iters  24 | time: 16.236 sec Cat Loss: 0.670 | Cat Acc: 66.967% (2009/3000) | Dog Loss: 0.623 | Dog Acc: 67.033% (2011/3000)
Cat P: : 50.391% (580/1151) | Dog P: 50.402% (690/1369)
Cat R: : 58.000% (580/1000) | Dog R: 69.000% (690/1000)
Cat F: : 53.928%         | Dog R: 58.252%
total test iters  24 | time: 17.006 sec Cat Loss: 0.675 | Cat Acc: 66.967% (2009/3000) | Dog Loss: 0.625 | Dog Acc: 67.033% (2011/3000)

********** EPOCH 2 **********
Learning rate:  0.010090000000000009
total train iters  86 | time: 11.443 sec Cat Loss: 0.637 | Cat Acc: 71.447% (6058/8479) | Dog Loss: 0.585 | Dog Acc: 73.712% (6281/8521)
Cat P: : 25.842% (760/2941) | Dog P: 79.200%

In [19]:
import pickle

# Show the two condition tallies, one per line, before anything is written.
# NOTE(review): b1_condition / b2_condition are presumably per-epoch flags
# returned by get_cong_avoidance_results -- confirm against that function.
print(sum(b1_condition), sum(b2_condition), sep='\n')

# The 18 result objects are stored as consecutive, independent pickle records
# in one file. A reader has to call pickle.load() 18 times on the same handle,
# in exactly this order, to get them back.
with open('cyclicLR_Baseline_Dogs_Trucks_1000_5000.pickle', 'wb') as file:
    for _record in (
        # training metrics (branch one, branch two)
        b1_train_A, b2_train_A,
        b1_train_P, b2_train_P,
        b1_train_R, b2_train_R,
        b1_train_F, b2_train_F,
        # test metrics (branch one, branch two)
        b1_test_A, b2_test_A,
        b1_test_P, b2_test_P,
        b1_test_R, b2_test_R,
        b1_test_F, b2_test_F,
        # condition logs
        b1_condition, b2_condition,
    ):
        pickle.dump(_record, file)

0
0


In [20]:
# Reset the notebook-level bookkeeping before starting a fresh run.
# NOTE(review): these globals are presumably read/updated by the training code
# defined earlier in the notebook -- confirm before renaming or removing them.
start_time = time.time()
branch_one_grads, branch_two_grads = {}, {}
epoch_count_one = epoch_count_two = 0

# Per-class sample counts, indexed by CIFAR-10 class id. Every class gets the
# same background size except the two branch classes (ids 5 and 9, matching
# branch_one_class / branch_two_class in the call below).
class_sizes_train = [625] * 10
class_sizes_train[5] = 5000
class_sizes_train[9] = 1000

class_sizes_test = [125] * 10
class_sizes_test[5] = 1000
class_sizes_test[9] = 1000

# All condition-related arguments are zero for this run.
# The 18 return values are unpacked positionally; the order below must match
# the order in which get_cong_avoidance_results returns them.
(
    b1_train_A, b2_train_A,
    b1_train_P, b2_train_P,
    b1_train_R, b2_train_R,
    b1_train_F, b2_train_F,
    b1_test_A, b2_test_A,
    b1_test_P, b2_test_P,
    b1_test_R, b2_test_R,
    b1_test_F, b2_test_F,
    b1_condition, b2_condition,
) = get_cong_avoidance_results(
    branch_one_class=5,
    branch_two_class=9,
    class_sizes_train=class_sizes_train,
    class_sizes_test=class_sizes_test,
    epochs=100,
    min_cond=0,
    max_cond=0,
    mult_factor=0,
    lr=0.1,
    min_epochs=0,
)

Files already downloaded and verified
Files already downloaded and verified
[5000 5000 5000 5000 5000 5000 5000 5000 5000 5000]
[ 625  625  625  625  625 5000  625  625  625 1000]
[1000 1000 1000 1000 1000 1000 1000 1000 1000 1000]
[ 125  125  125  125  125 1000  125  125  125 1000]
CUDA device used...

********** EPOCH 1 **********
Learning rate:  0.0001
total train iters  86 | time: 10.052 sec Cat Loss: 0.712 | Cat Acc: 55.124% (4674/8479) | Dog Loss: 0.660 | Dog Acc: 67.433% (5746/8521)
Cat P: : 65.016% (2587/3979) | Dog P: 19.782% (581/2937)
Cat R: : 51.740% (2587/5000) | Dog R: 58.100% (581/1000)
Cat F: : 57.623%         | Dog F: 29.515%
No condition is met ..... 0.00%
total test iters  24 | time: 11.136 sec Cat Loss: 0.630 | Cat Acc: 64.400% (1932/3000) | Dog Loss: 0.603 | Dog Acc: 73.333% (2200/3000)
Cat P: : 47.424% (626/1320) | Dog P: 59.058% (652/1104)
Cat R: : 62.600% (626/1000) | Dog R: 65.200% (652/1000)
Cat F: : 53.966%         | Dog R: 61.977%
total test iters  24 | time

In [21]:
# Show the two condition tallies, one per line, before anything is written.
# NOTE(review): b1_condition / b2_condition are presumably per-epoch flags
# returned by get_cong_avoidance_results -- confirm against that function.
print(sum(b1_condition), sum(b2_condition), sep='\n')

# Relies on `pickle` having been imported by an earlier cell.
# The 18 result objects are stored as consecutive, independent pickle records
# in one file. A reader has to call pickle.load() 18 times on the same handle,
# in exactly this order, to get them back.
with open('cyclicLR_Baseline_Dogs_Trucks_5000_1000.pickle', 'wb') as file:
    for _record in (
        # training metrics (branch one, branch two)
        b1_train_A, b2_train_A,
        b1_train_P, b2_train_P,
        b1_train_R, b2_train_R,
        b1_train_F, b2_train_F,
        # test metrics (branch one, branch two)
        b1_test_A, b2_test_A,
        b1_test_P, b2_test_P,
        b1_test_R, b2_test_R,
        b1_test_F, b2_test_F,
        # condition logs
        b1_condition, b2_condition,
    ):
        pickle.dump(_record, file)

0
0


**1000/5000 (dogs/trucks) -- Congestion condition based on precision -- 90% threshold -- 50% multiplicative factor -- 5 minimum epochs**

In [23]:
# Reset the notebook-level bookkeeping before starting a fresh run.
# NOTE(review): these globals are presumably read/updated by the training code
# defined earlier in the notebook -- confirm before renaming or removing them.
start_time = time.time()
branch_one_grads, branch_two_grads = {}, {}
epoch_count_one = epoch_count_two = 0

# Per-class sample counts, indexed by CIFAR-10 class id. Every class gets the
# same background size except the two branch classes (ids 5 and 9, matching
# branch_one_class / branch_two_class in the call below).
class_sizes_train = [625] * 10
class_sizes_train[5] = 1000
class_sizes_train[9] = 5000

class_sizes_test = [125] * 10
class_sizes_test[5] = 1000
class_sizes_test[9] = 1000

# This run uses a 0.9 condition threshold (min and max are equal), a 0.5
# multiplicative factor and min_epochs=5.
# The 18 return values are unpacked positionally; the order below must match
# the order in which get_cong_avoidance_results returns them.
(
    b1_train_A, b2_train_A,
    b1_train_P, b2_train_P,
    b1_train_R, b2_train_R,
    b1_train_F, b2_train_F,
    b1_test_A, b2_test_A,
    b1_test_P, b2_test_P,
    b1_test_R, b2_test_R,
    b1_test_F, b2_test_F,
    b1_condition, b2_condition,
) = get_cong_avoidance_results(
    branch_one_class=5,
    branch_two_class=9,
    class_sizes_train=class_sizes_train,
    class_sizes_test=class_sizes_test,
    epochs=100,
    min_cond=0.9,
    max_cond=0.9,
    mult_factor=0.5,
    lr=0.1,
    min_epochs=5,
)

Files already downloaded and verified
Files already downloaded and verified
[5000 5000 5000 5000 5000 5000 5000 5000 5000 5000]
[ 625  625  625  625  625 1000  625  625  625 5000]
[1000 1000 1000 1000 1000 1000 1000 1000 1000 1000]
[ 125  125  125  125  125 1000  125  125  125 1000]
CUDA device used...

********** EPOCH 1 **********
Learning rate:  0.0001
total train iters  86 | time: 9.798 sec Cat Loss: 0.683 | Cat Acc: 58.530% (4961/8476) | Dog Loss: 0.675 | Dog Acc: 60.218% (5133/8524)
Cat P: : 15.875% (585/3685) | Dog P: 69.123% (2908/4207)
Cat R: : 58.500% (585/1000) | Dog R: 58.160% (2908/5000)
Cat F: : 24.973%         | Dog F: 63.169%
No condition is met ..... 90.00%
total test iters  24 | time: 10.875 sec Cat Loss: 0.601 | Cat Acc: 63.200% (1896/3000) | Dog Loss: 0.571 | Dog Acc: 69.267% (2078/3000)
Cat P: : 46.570% (706/1516) | Dog P: 53.356% (620/1162)
Cat R: : 70.600% (706/1000) | Dog R: 62.000% (620/1000)
Cat F: : 56.121%         | Dog R: 57.354%
total test iters  24 | time

In [24]:
# Show the two condition tallies, one per line, before anything is written.
# NOTE(review): b1_condition / b2_condition are presumably per-epoch flags
# returned by get_cong_avoidance_results -- confirm against that function.
print(sum(b1_condition), sum(b2_condition), sep='\n')

# Relies on `pickle` having been imported by an earlier cell.
# NOTE(review): the filename says 'recall' while the section heading for this
# run says the condition is precision-based -- confirm which one is correct.
# The 18 result objects are stored as consecutive, independent pickle records
# in one file. A reader has to call pickle.load() 18 times on the same handle,
# in exactly this order, to get them back.
with open('cond900_recall_mult050_5epochs_1000dog_5000truck.pickle', 'wb') as file:
    for _record in (
        # training metrics (branch one, branch two)
        b1_train_A, b2_train_A,
        b1_train_P, b2_train_P,
        b1_train_R, b2_train_R,
        b1_train_F, b2_train_F,
        # test metrics (branch one, branch two)
        b1_test_A, b2_test_A,
        b1_test_P, b2_test_P,
        b1_test_R, b2_test_R,
        b1_test_F, b2_test_F,
        # condition logs
        b1_condition, b2_condition,
    ):
        pickle.dump(_record, file)

19
1


In [25]:
# Reset the notebook-level bookkeeping before starting a fresh run.
# NOTE(review): these globals are presumably read/updated by the training code
# defined earlier in the notebook -- confirm before renaming or removing them.
start_time = time.time()
branch_one_grads, branch_two_grads = {}, {}
epoch_count_one = epoch_count_two = 0

# Per-class sample counts, indexed by CIFAR-10 class id. Every class gets the
# same background size except the two branch classes (ids 5 and 9, matching
# branch_one_class / branch_two_class in the call below).
class_sizes_train = [625] * 10
class_sizes_train[5] = 1000
class_sizes_train[9] = 5000

class_sizes_test = [125] * 10
class_sizes_test[5] = 1000
class_sizes_test[9] = 1000

# This run uses a 0.8 condition threshold (min and max are equal), a 0.5
# multiplicative factor and min_epochs=5.
# The 18 return values are unpacked positionally; the order below must match
# the order in which get_cong_avoidance_results returns them.
(
    b1_train_A, b2_train_A,
    b1_train_P, b2_train_P,
    b1_train_R, b2_train_R,
    b1_train_F, b2_train_F,
    b1_test_A, b2_test_A,
    b1_test_P, b2_test_P,
    b1_test_R, b2_test_R,
    b1_test_F, b2_test_F,
    b1_condition, b2_condition,
) = get_cong_avoidance_results(
    branch_one_class=5,
    branch_two_class=9,
    class_sizes_train=class_sizes_train,
    class_sizes_test=class_sizes_test,
    epochs=100,
    min_cond=0.8,
    max_cond=0.8,
    mult_factor=0.5,
    lr=0.1,
    min_epochs=5,
)

Files already downloaded and verified
Files already downloaded and verified
[5000 5000 5000 5000 5000 5000 5000 5000 5000 5000]
[ 625  625  625  625  625 1000  625  625  625 5000]
[1000 1000 1000 1000 1000 1000 1000 1000 1000 1000]
[ 125  125  125  125  125 1000  125  125  125 1000]
CUDA device used...

********** EPOCH 1 **********
Learning rate:  0.0001
total train iters  86 | time: 9.794 sec Cat Loss: 0.691 | Cat Acc: 57.236% (4853/8479) | Dog Loss: 0.671 | Dog Acc: 61.753% (5262/8521)
Cat P: : 15.228% (575/3776) | Dog P: 68.838% (3181/4621)
Cat R: : 57.500% (575/1000) | Dog R: 63.620% (3181/5000)
Cat F: : 24.079%         | Dog F: 66.126%
No condition is met ..... 80.00%
total test iters  24 | time: 10.775 sec Cat Loss: 0.640 | Cat Acc: 64.200% (1926/3000) | Dog Loss: 0.557 | Dog Acc: 71.633% (2149/3000)
Cat P: : 47.342% (659/1392) | Dog P: 57.170% (594/1039)
Cat R: : 65.900% (659/1000) | Dog R: 59.400% (594/1000)
Cat F: : 55.100%         | Dog R: 58.264%
total test iters  24 | time

In [26]:
# Show the two condition tallies, one per line, before anything is written.
# NOTE(review): b1_condition / b2_condition are presumably per-epoch flags
# returned by get_cong_avoidance_results -- confirm against that function.
print(sum(b1_condition), sum(b2_condition), sep='\n')

# Relies on `pickle` having been imported by an earlier cell.
# NOTE(review): the filename says 'recall' while the nearest section heading
# above says the condition is precision-based -- confirm which one is correct.
# The 18 result objects are stored as consecutive, independent pickle records
# in one file. A reader has to call pickle.load() 18 times on the same handle,
# in exactly this order, to get them back.
with open('cond800_recall_mult050_5epochs_1000dog_5000truck.pickle', 'wb') as file:
    for _record in (
        # training metrics (branch one, branch two)
        b1_train_A, b2_train_A,
        b1_train_P, b2_train_P,
        b1_train_R, b2_train_R,
        b1_train_F, b2_train_F,
        # test metrics (branch one, branch two)
        b1_test_A, b2_test_A,
        b1_test_P, b2_test_P,
        b1_test_R, b2_test_R,
        b1_test_F, b2_test_F,
        # condition logs
        b1_condition, b2_condition,
    ):
        pickle.dump(_record, file)

7
0
