In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Lambda

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm 
from abc import ABC

num_cores = 8
# Same thread budget for inter-op parallelism (set first) and intra-op parallelism.
for set_thread_count in (torch.set_num_interop_threads, torch.set_num_threads):
    set_thread_count(num_cores)

In [2]:
def c3_to_c1(y):
    """Map a fine CIFAR-10 label to its top-level label.

    Labels 2-7 (bird, cat, deer, dog, frog, horse) map to 1 ('animal');
    everything else (plane 0, car 1, ship 8, truck 9) maps to 0 ('transport').
    """
    outside_animal_range = y < 2 or y > 7
    return 0 if outside_animal_range else 1

def c3_to_c2(y):
    """Map a fine CIFAR-10 label (0-9) to its intermediate (coarse-2) label.

    Fine labels:     0 plane, 1 car, 2 bird, 3 cat, 4 deer, 5 dog, 6 frog,
                     7 horse, 8 ship, 9 truck
    Coarse-2 labels: 0 sky, 1 water, 2 road, 3 bird, 4 reptile, 5 pet, 6 medium

    Raises:
        ValueError: if ``y`` is not a fine label in 0-9. Unknown labels used to
            fall through a catch-all case and were silently mapped to 2 ('road').
    """
    # Index = fine label, value = coarse-2 label.
    fine_to_coarse2 = (0, 2, 3, 5, 6, 5, 4, 6, 1, 2)
    label = int(y)
    if not 0 <= label < len(fine_to_coarse2):
        raise ValueError(f"fine label must be in [0, {len(fine_to_coarse2)}), got {y!r}")
    return fine_to_coarse2[label]

def c2_to_c1(y):
    """Map an intermediate (coarse-2) label to its top-level label.

    sky/water/road (0-2) map to 0 ('transport'); bird, reptile, pet and medium
    (3 and above) map to 1 ('animal').
    """
    return 0 if y < 3 else 1

In [3]:
# Inputs: image -> float tensor in [0, 1] -> each RGB channel rescaled to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Targets: fine CIFAR-10 label -> tensor [coarse-1 label, coarse-2 label, fine label].
coarser = Lambda(lambda label: torch.tensor([c3_to_c1(label), c3_to_c2(label), int(label)]))

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False,
                                        transform=transform, target_transform=coarser)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                          num_workers=num_cores)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False,
                                       transform=transform, target_transform=coarser)
# The test loader is not shuffled, so batches come out in a fixed order.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False,
                                         num_workers=num_cores)

## NA_Layer2 (two-level prototype, kept as an inert string — not executed)

In [4]:
"""class NA_Layer2(ABC, nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        yf = self.layer_f(x[0, :])
        yi = self.layer_f(x[1, :]) + self.layer_i(x[0, :])

        return torch.stack((yf, yi))

    def infinitesimal_gradient(self):
        self.layer_i.weight.grad = self.layer_f.weight.grad
        self.layer_f.weight.grad = None

        if bias:
            self.layer_i.bias.grad = self.layer_f.bias.grad
            self.layer_f.bias.grad = None

class NA_Linear2(NA_Layer2):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):

        super().__init__()
        self.bias = bias
        self.layer_f = nn.Linear(in_features, out_features, bias, device, dtype)
        self.layer_i = nn.Linear(in_features, out_features, bias, device, dtype)

class NA_Conv2d2(NA_Layer2):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, 
                padding_mode='zeros', device=None, dtype=None):

        super().__init__()
        self.bias = bias
        self.layer_f = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, 
                padding_mode, device, dtype)
        self.layer_i = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, 
                padding_mode, device, dtype)"""

"class NA_Layer2(ABC, nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        yf = self.layer_f(x[0, :])\n        yi = self.layer_f(x[1, :]) + self.layer_i(x[0, :])\n\n        return torch.stack((yf, yi))\n\n    def infinitesimal_gradient(self):\n        self.layer_i.weight.grad = self.layer_f.weight.grad\n        self.layer_f.weight.grad = None\n\n        if bias:\n            self.layer_i.bias.grad = self.layer_f.bias.grad\n            self.layer_f.bias.grad = None\n\nclass NA_Linear2(NA_Layer2):\n    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):\n\n        super().__init__()\n        self.bias = bias\n        self.layer_f = nn.Linear(in_features, out_features, bias, device, dtype)\n        self.layer_i = nn.Linear(in_features, out_features, bias, device, dtype)\n\nclass NA_Conv2d2(NA_Layer2):\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dila

## NA_Layer3 (three-level NA layers used by NA_BCNN3)

In [12]:
class NA_Layer(ABC, nn.Module):
    """Common base for the three-level NA layer wrappers.

    Concrete subclasses define ``forward``. They take a tensor whose dim 0
    indexes the NA levels and return the per-level results stacked along dim 0.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Fail loudly rather than silently returning None when a subclass
        # forgets to override forward.
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")

class NA_combiner3(NA_Layer):
    """Three-level NA wrapper for layers whose levels interact (Linear, Conv2d).

    Subclasses must set ``layer_f`` (finite part) and ``layer_i1`` / ``layer_i2``
    (first/second infinitesimal parts), all configured identically. For an input
    whose dim 0 holds the levels x0, x1, x2, ``forward`` returns the stack of

        y0 = f(x0)
        y1 = f(x1) + i1(x0)
        y2 = f(x2) + i1(x1) + i2(x0)
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # NOTE(review): each sub-layer call adds that layer's own bias (when it
        # has one). So y1 also carries layer_f's bias, and y2 carries layer_f's
        # and layer_i1's biases in addition to layer_i2's. Confirm this is intended.
        yf = self.layer_f(x[0, :])
        yi1 = self.layer_f(x[1, :]) + self.layer_i1(x[0, :])
        yi2 = self.layer_f(x[2, :]) + self.layer_i1(x[1, :]) + self.layer_i2(x[0, :])

        return torch.stack((yf, yi1, yi2))

    def infinitesimal_gradient(self, i):
        """Move the gradient accumulated on ``layer_f`` to infinitesimal level ``i``.

        i == 1: ``layer_i1`` takes ``layer_f``'s gradients; ``layer_f``'s are cleared.
        i == 2: ``layer_i2`` takes ``layer_f``'s gradients; ``layer_i1``'s and
                ``layer_f``'s are cleared.
        Any other ``i`` leaves every gradient untouched.

        Bias gradients are moved the same way when the layers have a bias.
        """
        if i == 1:
            target, cleared = self.layer_i1, (self.layer_f,)
        elif i == 2:
            target, cleared = self.layer_i2, (self.layer_i1, self.layer_f)
        else:
            return

        for param_name in ("weight", "bias"):
            # Check the layer itself: layers built with bias=False expose `bias`
            # as None, so there is no gradient to move.
            if getattr(self.layer_f, param_name) is None:
                continue
            # Hand over the tensor before clearing layer_f (which is in `cleared`).
            getattr(target, param_name).grad = getattr(self.layer_f, param_name).grad
            for layer in cleared:
                getattr(layer, param_name).grad = None
            
class NA_separate3(NA_Layer):
    """Three-level NA wrapper that never mixes levels.

    Subclasses set ``layer_f``, ``layer_i1`` and ``layer_i2``. Level k of the
    input (dim 0) goes only through the k-th of those modules. This is used for
    BatchNorm, so each level keeps its own module and statistics.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        per_level_modules = (self.layer_f, self.layer_i1, self.layer_i2)
        return torch.stack(tuple(module(x[level, :])
                                 for level, module in enumerate(per_level_modules)))

class NA_independent(NA_Layer):
    """NA wrapper that runs one shared module, ``self.layer``, on every level.

    Subclasses set ``self.layer``. The input's dim 0 indexes the levels, and the
    per-level outputs are re-stacked along dim 0.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # NOTE(review): the module runs separately on each level. For example,
        # Dropout draws an independent mask per level, and MaxPool picks its
        # maxima per level on its own. Confirm this is the intended NA semantics.
        outputs = [self.layer(x[level, :]) for level in range(x.size(0))]
        return torch.stack(outputs)

class NA_Linear3(NA_combiner3):
    """Three-level NA counterpart of ``nn.Linear``.

    Builds three identically configured linear maps: ``layer_f`` (finite part)
    and ``layer_i1`` / ``layer_i2`` (infinitesimal parts), created in that order.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
        super().__init__()
        # Remember whether the sub-layers were built with a bias term.
        self.bias = bias

        def make_linear():
            return nn.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)

        self.layer_f = make_linear()
        self.layer_i1 = make_linear()
        self.layer_i2 = make_linear()

class NA_Conv2d3(NA_combiner3):
    """Three-level NA counterpart of ``nn.Conv2d``.

    Builds three identically configured convolutions: ``layer_f`` (finite part)
    and ``layer_i1`` / ``layer_i2`` (infinitesimal parts), created in that order.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        super().__init__()
        # Remember whether the sub-layers were built with a bias term.
        self.bias = bias
        conv_options = dict(stride=stride, padding=padding, dilation=dilation, groups=groups,
                            bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)
        self.layer_f = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
        self.layer_i1 = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
        self.layer_i2 = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)

class NA_BatchNorm2d3(NA_separate3):
    """Three-level NA counterpart of ``nn.BatchNorm2d``.

    Each level gets its own, identically configured BatchNorm2d (created in the
    order layer_f, layer_i1, layer_i2). Running statistics and, when
    ``affine=True``, the affine parameters are therefore kept separately per level.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1, affine: bool = True,
                 track_running_stats: bool = True, device=None, dtype=None):
        super().__init__()
        norm_options = dict(eps=eps, momentum=momentum, affine=affine,
                            track_running_stats=track_running_stats, device=device, dtype=dtype)
        for level_name in ("layer_f", "layer_i1", "layer_i2"):
            setattr(self, level_name, nn.BatchNorm2d(num_features, **norm_options))

class NA_BatchNorm1d3(NA_separate3):
    """Three-level NA counterpart of ``nn.BatchNorm1d``.

    Each level gets its own, identically configured BatchNorm1d (created in the
    order layer_f, layer_i1, layer_i2). Running statistics and, when
    ``affine=True``, the affine parameters are therefore kept separately per level.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1, affine: bool = True,
                 track_running_stats: bool = True, device=None, dtype=None):
        super().__init__()
        norm_options = dict(eps=eps, momentum=momentum, affine=affine,
                            track_running_stats=track_running_stats, device=device, dtype=dtype)
        for level_name in ("layer_f", "layer_i1", "layer_i2"):
            setattr(self, level_name, nn.BatchNorm1d(num_features, **norm_options))

class NA_MaxPool2d(NA_independent):
    """Three-level NA counterpart of ``nn.MaxPool2d``: one pooling module applied to every level."""

    def __init__(self, kernel_size, stride = None, padding = 0, dilation = 1, return_indices = False, ceil_mode = False):
        super().__init__()
        # NOTE(review): with return_indices=True the pool returns (values, indices)
        # tuples, which NA_independent.forward cannot torch.stack.
        self.layer = nn.MaxPool2d(kernel_size, stride=stride, padding=padding, dilation=dilation,
                                  return_indices=return_indices, ceil_mode=ceil_mode)

class NA_Dropout(NA_independent):
    """Three-level NA counterpart of ``nn.Dropout``: one dropout module applied to every level."""

    def __init__(self, p: float = 0.5, inplace: bool = False):
        super().__init__()
        self.layer = nn.Dropout(p=p, inplace=inplace)

In [13]:
class NA_BCNN3(nn.Module):
    """Branch CNN built from three-level NA layers.

    Architecture:
    - Two conv blocks, each (conv -> ReLU -> BatchNorm) twice, then a 2x2 max-pool.
    - A fully connected branch ending in ``max_children`` outputs per level.

    The SGD optimizer and cross-entropy criterion are created here and stored
    as ``self.optimizer`` / ``self.criterion``.

    Args:
        learning_rate: sequence of learning rates; ``learning_rate[0]`` initialises SGD.
        momentum, nesterov: forwarded to ``optim.SGD``.
        trainloader, testloader: data loaders stored on the model. ``trainloader``
            must have a fixed ``batch_size``; it is used to compute ``track_size``.
        epochs: number of epochs, stored on the model.
        num_class_c1, num_class_c2, num_class_c3: class counts per hierarchy level, stored on the model.
        labels_c_1, labels_c_2, labels_c_3: class names per hierarchy level, stored on the model.
        every_print: reporting period in batches. It must be a positive power of
            two, because ``self.every_print`` stores ``every_print - 1`` as a bit mask.
        training_size: number of training samples, used to size ``track_size``.

    Raises:
        ValueError: if ``every_print`` is not a positive power of two, or if
            ``trainloader`` has no fixed ``batch_size``.
    """

    def __init__(self, learning_rate, momentum, nesterov, trainloader, testloader, 
                 epochs, num_class_c1, num_class_c2, num_class_c3, labels_c_1, labels_c_2, labels_c_3, 
                 every_print = 512, training_size = 50000):

        super().__init__()
        # The reporting mask (every_print - 1) only works for powers of two.
        if every_print <= 0 or every_print & (every_print - 1):
            raise ValueError(f"every_print must be a positive power of two, got {every_print}")
        # Take the batch size from the loader itself instead of a notebook global.
        loader_batch_size = trainloader.batch_size
        if loader_batch_size is None:
            raise ValueError("trainloader must be created with a fixed batch_size")

        self.trainloader = trainloader
        self.testloader = testloader
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.nesterov = nesterov
        self.activation = F.relu
        self.class_levels = 3
        self.max_children = 4  # output width of the last linear layer (layerb17)
        self.num_c_1 = num_class_c1
        self.num_c_2 = num_class_c2
        self.num_c_3 = num_class_c3
        self.epochs = epochs
        self.labels_c_1 = labels_c_1
        self.labels_c_2 = labels_c_2
        self.labels_c_3 = labels_c_3
        self.every_print = every_print - 1  # bit mask; valid because every_print is a power of 2
        self.track_size = int(training_size / loader_batch_size / every_print)

        # Layers are created in a fixed order so parameter initialisation consumes RNG identically.
        # block 1: 32x32 -> 16x16
        self.layer1  = NA_Conv2d3(3, 64, (3,3), padding = 'same')
        self.layer2  = NA_BatchNorm2d3(64)
        self.layer3  = NA_Conv2d3(64, 64, (3,3), padding = 'same')
        self.layer4  = NA_BatchNorm2d3(64)
        self.layer5  = NA_MaxPool2d((2,2), stride = (2,2))

        # block 2: 16x16 -> 8x8
        self.layer6  = NA_Conv2d3(64, 128, (3,3), padding = 'same')
        self.layer7  = NA_BatchNorm2d3(128)
        self.layer8  = NA_Conv2d3(128, 128, (3,3), padding = 'same')
        self.layer9  = NA_BatchNorm2d3(128)
        self.layer10 = NA_MaxPool2d((2,2), stride = (2,2))

        # fully connected branch on the flattened 8*8*128 features
        self.layerb11 = NA_Linear3(8*8*128, 256)
        self.layerb12 = NA_BatchNorm1d3(256)
        self.layerb13 = NA_Dropout(0.5)
        self.layerb14 = NA_Linear3(256, 256)
        self.layerb15 = NA_BatchNorm1d3(256)
        self.layerb16 = NA_Dropout(0.5)
        self.layerb17 = NA_Linear3(256, self.max_children)

        self.optimizer = optim.SGD(self.parameters(), lr = self.learning_rate[0], 
                                   momentum = self.momentum, nesterov = self.nesterov)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run the network on a stacked NA input.

        Assumes ``x`` is (levels=3, batch, 3, 32, 32): dim 0 holds the NA levels,
        and the 8*8*128 input of layerb11 implies 32x32 images after two 2x2
        pools. Returns (3, batch, max_children).
        """
        # NOTE(review): F.relu is applied elementwise to every stacked level
        # independently; confirm this matches the intended NA activation.

        # block 1
        z = self.layer1(x)
        z = self.activation(z)
        z = self.layer2(z)
        z = self.layer3(z)
        z = self.activation(z)
        z = self.layer4(z)
        z = self.layer5(z)

        # block 2
        z = self.layer6(z)
        z = self.activation(z)
        z = self.layer7(z)
        z = self.layer8(z)
        z = self.activation(z)
        z = self.layer9(z)
        z = self.layer10(z)

        # branch 1: keep dims 0 (level) and 1 (batch), flatten the rest
        b1 = torch.flatten(z, start_dim = 2)
        b1 = self.layerb11(b1)
        b1 = self.activation(b1)
        b1 = self.layerb12(b1)
        b1 = self.layerb13(b1)
        b1 = self.layerb14(b1)
        b1 = self.activation(b1)
        b1 = self.layerb15(b1)
        b1 = self.layerb16(b1)
        b1 = self.layerb17(b1)

        return b1

In [14]:
# --- optimisation ---
learning_rate = [3e-3, 5e-4, 1e-4]  # NA_BCNN3 initialises SGD with learning_rate[0]
momentum = 0.9
nesterov = True
epochs = 60
every_print = 32  # must be a power of two: NA_BCNN3 turns it into a bit mask

# --- class names per hierarchy level (index = label id) ---
labels_c_1 = ('transport', 'animal')
labels_c_2 = ('sky', 'water', 'road', 'bird', 'reptile', 'pet', 'medium')
labels_c_3 = ('plane', 'car', 'bird', 'cat', 'deer',
              'dog', 'frog', 'horse', 'ship', 'truck')

# Class counts follow from the label tuples, so they cannot drift apart.
num_class_c1 = len(labels_c_1)
num_class_c2 = len(labels_c_2)
num_class_c3 = len(labels_c_3)

In [15]:
# Lift one test batch to a 3-level NA input: the images form the finite level,
# and both infinitesimal levels are all-zero tensors of the same shape.
img, _ = next(iter(testloader))
img = torch.stack([img, *(torch.zeros_like(img) for _level in range(2))])

In [16]:
cnn = NA_BCNN3(learning_rate, momentum, nesterov, trainloader, testloader,
               epochs, num_class_c1, num_class_c2, num_class_c3,
               labels_c_1, labels_c_2, labels_c_3, every_print)

# Shape smoke test on the untrained model.
# - no_grad() stops `z` from keeping the whole autograd graph alive.
# - eval() stops this test batch from updating the BatchNorm running
#   statistics, and disables dropout.
# Training mode is restored afterwards, so later cells see the default state.
cnn.eval()
with torch.no_grad():
    z = cnn(img)
cnn.train()

In [17]:
# Expected layout: (NA levels, batch, max_children) for the stacked test batch.
assert z.shape == (img.size(0), img.size(1), cnn.max_children), z.shape
z.shape

torch.Size([3, 128, 4])