In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Lambda

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm 

# Number of CPU threads used by PyTorch (and by the DataLoader workers below).
num_cores = 8
# set_num_interop_threads may only be called once per process and before any
# inter-op parallel work has started; guard it so re-running this cell (or
# Run All on a warm kernel) does not raise RuntimeError.
if torch.get_num_interop_threads() != num_cores:
    try:
        torch.set_num_interop_threads(num_cores)  # Inter-op parallelism
    except RuntimeError as err:
        print(f'Could not set inter-op threads to {num_cores}: {err}')
torch.set_num_threads(num_cores)  # Intra-op parallelism (safe to call repeatedly)

In [2]:
def imshow(img):
    """Show a normalized (C, H, W) image tensor with matplotlib.

    Undoes ``Normalize(mean=0.5, std=0.5)`` (mapping [-1, 1] back to [0, 1]),
    moves the channel axis last as ``plt.imshow`` expects, and displays it.
    """
    pixels = (img * 0.5 + 0.5).numpy()
    plt.imshow(pixels.transpose(1, 2, 0))
    plt.show()


In [3]:
# PIL image -> float tensor in [0, 1] -> each channel rescaled to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# Label hierarchy used by the three B-CNN branches.
# CIFAR-10 fine labels: 0 plane, 1 car, 2 bird, 3 cat, 4 deer,
#                       5 dog, 6 frog, 7 horse, 8 ship, 9 truck.
# Coarse-2 labels:      0 sky, 1 water, 2 road, 3 bird, 4 reptile, 5 pet, 6 medium.
FINE_TO_C2 = (0, 2, 3, 5, 6, 5, 4, 6, 1, 2)
# Coarse-1 labels:      0 transport (sky/water/road), 1 animal (the rest).
C2_TO_C1 = (0, 0, 0, 1, 1, 1, 1)

# Sanity check: the derived coarse-1 labels must agree with "plane, car,
# ship, truck are transport".
assert [C2_TO_C1[c2] for c2 in FINE_TO_C2] == [0 if y < 2 or y > 7 else 1 for y in range(10)]


def _hierarchical_target(y):
    """Map a fine CIFAR-10 label to the tensor ``[coarse-1, coarse-2, fine]``.

    Column i is the target of branch i + 1 (the training loss reads
    ``labels[:, 0]``, ``labels[:, 1]`` and ``labels[:, 2]``).
    """
    fine = int(y)
    c2 = FINE_TO_C2[fine]
    return torch.tensor([C2_TO_C1[c2], c2, fine])


coarser = Lambda(_hierarchical_target)

batch_size = 128

# Both splits share the image transform and the hierarchical target transform;
# the data must already be present under ./data (download is disabled).
trainset, testset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=False,
                                 transform=transform, target_transform=coarser)
    for is_train in (True, False)
)

# Only the training split is shuffled; evaluation order stays fixed.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=num_cores)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=num_cores)

In [4]:
class BCNN3(nn.Module):
    """Branch CNN with three hierarchical classifier heads.

    A trunk of four conv blocks (two 3x3 convs + 2x2 max-pool each) feeds three
    branches: branch 1 after block 2 (``num_class_c1`` outputs), branch 2 after
    block 3 (``num_class_c2`` outputs) and branch 3 after block 4
    (``num_class_c3`` outputs). The training loss is
    ``alpha * CE1 + beta * CE2 + gamma * CE3``.

    Args:
        alpha, beta, gamma: Sequences of loss weights for branches 1/2/3.
            Entry 0 is used from the start; entry ``k + 1`` is switched in at
            epoch ``loss_weight_epochs[k]``.
        learning_rate: Sequence of SGD learning rates. Entry 0 is used from
            the start; entry ``k + 1`` is switched in at epoch
            ``lr_decay_epochs[k]``.
        momentum, nesterov: Passed to ``torch.optim.SGD``.
        trainloader, testloader: Iterables of ``(images, labels)`` where
            ``labels[:, i]`` is the class index for branch ``i + 1``.
        epochs: Number of epochs run by :meth:`fit` / ``train()``.
        num_class_c1, num_class_c2, num_class_c3: Branch output sizes.
        loss_weight_epochs: 0-based epochs at which the loss weights advance.
        lr_decay_epochs: 0-based epochs at which the learning rate advances.

    The branch flatten sizes (``8*8*128`` etc.) assume 3x32x32 inputs.
    """

    def __init__(self, alpha, beta, gamma, learning_rate, momentum, nesterov, trainloader, testloader,
                 epochs, num_class_c1, num_class_c2, num_class_c3,
                 loss_weight_epochs=(9, 19, 29), lr_decay_epochs=(41, 51)):

        super().__init__()
        self.trainloader = trainloader
        self.testloader = testloader
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.nesterov = nesterov
        self.alphas = alpha
        self.betas = beta
        self.gammas = gamma
        self.alpha = self.alphas[0]
        self.beta = self.betas[0]
        self.gamma = self.gammas[0]
        self.activation = F.relu
        self.num_c_1 = num_class_c1
        self.num_c_2 = num_class_c2
        self.num_c_3 = num_class_c3
        self.epochs = epochs
        self.loss_weight_epochs = tuple(loss_weight_epochs)
        self.lr_decay_epochs = tuple(lr_decay_epochs)
        # Mean training loss of the most recently completed epoch.
        self.epoch_error = 0.

        # --- trunk block 1: 3 -> 64 channels, 32x32 -> 16x16 ---
        self.layer1  = nn.Conv2d(3, 64, (3,3), padding = 'same')
        self.layer2  = nn.BatchNorm2d(64)
        self.layer3  = nn.Conv2d(64, 64, (3,3), padding = 'same')
        self.layer4  = nn.BatchNorm2d(64)
        self.layer5  = nn.MaxPool2d((2,2), stride = (2,2))

        # --- trunk block 2: 64 -> 128 channels, 16x16 -> 8x8 ---
        self.layer6  = nn.Conv2d(64, 128, (3,3), padding = 'same')
        self.layer7  = nn.BatchNorm2d(128)
        self.layer8  = nn.Conv2d(128, 128, (3,3), padding = 'same')
        self.layer9  = nn.BatchNorm2d(128)
        self.layer10 = nn.MaxPool2d((2,2), stride = (2,2))

        # --- branch 1 head (coarse-1) on the block-2 features ---
        self.layerb11 = nn.Linear(8*8*128, 256)
        self.layerb12 = nn.BatchNorm1d(256)
        self.layerb13 = nn.Dropout(0.5)
        self.layerb14 = nn.Linear(256, 256)
        self.layerb15 = nn.BatchNorm1d(256)
        self.layerb16 = nn.Dropout(0.5)
        self.layerb17 = nn.Linear(256, self.num_c_1)

        # --- trunk block 3: 128 -> 256 channels, 8x8 -> 4x4 ---
        self.layer11 = nn.Conv2d(128, 256, (3,3), padding = 'same')
        self.layer12 = nn.BatchNorm2d(256)
        self.layer13 = nn.Conv2d(256, 256, (3,3), padding = 'same')
        self.layer14 = nn.BatchNorm2d(256)
        self.layer15 = nn.MaxPool2d((2,2), stride = (2,2))

        # --- branch 2 head (coarse-2) on the block-3 features ---
        self.layerb21 = nn.Linear(4*4*256, 512)
        self.layerb22 = nn.BatchNorm1d(512)
        self.layerb23 = nn.Dropout(0.5)
        self.layerb24 = nn.Linear(512, 512)
        self.layerb25 = nn.BatchNorm1d(512)
        self.layerb26 = nn.Dropout(0.5)
        self.layerb27 = nn.Linear(512, self.num_c_2)

        # --- trunk block 4: 256 -> 512 channels, 4x4 -> 2x2 ---
        self.layer16 = nn.Conv2d(256, 512, (3,3), padding = 'same')
        self.layer17 = nn.BatchNorm2d(512)
        self.layer18 = nn.Conv2d(512, 512, (3,3), padding = 'same')
        self.layer19 = nn.BatchNorm2d(512)
        self.layer20 = nn.MaxPool2d((2,2), stride = (2,2))

        # --- branch 3 head (fine) on the block-4 features ---
        self.layerb31 = nn.Linear(2*2*512, 1024)
        self.layerb32 = nn.BatchNorm1d(1024)
        self.layerb33 = nn.Dropout(0.5)
        self.layerb34 = nn.Linear(1024, 1024)
        self.layerb35 = nn.BatchNorm1d(1024)
        self.layerb36 = nn.Dropout(0.5)
        self.layerb37 = nn.Linear(1024, self.num_c_3)

        self.optimizer = optim.SGD(self.parameters(), lr = self.learning_rate[0],
                                   momentum = self.momentum, nesterov = self.nesterov)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the logits of all three branches.

        Args:
            x: Image batch of shape (N, 3, 32, 32).

        Returns:
            Tuple ``(b1, b2, b3)`` with ``num_class_c1``, ``num_class_c2`` and
            ``num_class_c3`` columns respectively.
        """
        # Every conv / hidden linear layer is followed by ReLU, then BatchNorm.

        # block 1
        z = self.layer1(x)
        z = self.activation(z)
        z = self.layer2(z)
        z = self.layer3(z)
        z = self.activation(z)
        z = self.layer4(z)
        z = self.layer5(z)

        # block 2
        z = self.layer6(z)
        z = self.activation(z)
        z = self.layer7(z)
        z = self.layer8(z)
        z = self.activation(z)
        z = self.layer9(z)
        z = self.layer10(z)

        # branch 1
        b1 = torch.flatten(z, start_dim = 1)
        b1 = self.layerb11(b1)
        b1 = self.activation(b1)
        b1 = self.layerb12(b1)
        b1 = self.layerb13(b1)
        b1 = self.layerb14(b1)
        b1 = self.activation(b1)
        b1 = self.layerb15(b1)
        b1 = self.layerb16(b1)
        b1 = self.layerb17(b1)

        # block 3
        z = self.layer11(z)
        z = self.activation(z)
        z = self.layer12(z)
        z = self.layer13(z)
        z = self.activation(z)
        z = self.layer14(z)
        z = self.layer15(z)

        # branch 2
        b2 = torch.flatten(z, start_dim = 1)
        b2 = self.layerb21(b2)
        b2 = self.activation(b2)
        b2 = self.layerb22(b2)
        b2 = self.layerb23(b2)
        b2 = self.layerb24(b2)
        b2 = self.activation(b2)
        b2 = self.layerb25(b2)
        b2 = self.layerb26(b2)
        b2 = self.layerb27(b2)

        # block 4
        z = self.layer16(z)
        z = self.activation(z)
        z = self.layer17(z)
        z = self.layer18(z)
        z = self.activation(z)
        z = self.layer19(z)
        z = self.layer20(z)

        # branch 3
        b3 = torch.flatten(z, start_dim = 1)
        b3 = self.layerb31(b3)
        b3 = self.activation(b3)
        b3 = self.layerb32(b3)
        b3 = self.layerb33(b3)
        b3 = self.layerb34(b3)
        b3 = self.activation(b3)
        b3 = self.layerb35(b3)
        b3 = self.layerb36(b3)
        b3 = self.layerb37(b3)

        return b1, b2, b3

    def train(self, mode=None):
        """Run the training loop, or set train/eval mode like ``nn.Module.train``.

        ``nn.Module.eval()`` is implemented as ``self.train(False)``, so this
        override must keep honouring the ``mode`` argument or BatchNorm and
        Dropout could never be switched to inference behaviour.

        * ``train()`` (no argument) runs :meth:`fit` and returns ``None``.
        * ``train(True)`` / ``train(False)`` set the module mode and return ``self``.
        """
        if mode is None:
            self.fit()
            return None
        return super().train(mode)

    def fit(self, log_every=100):
        """Train for ``self.epochs`` epochs following the weight/lr schedule.

        Args:
            log_every: Print the mean loss of each window of this many batches
                (0/None disables per-window logging). The mean loss of every
                epoch is always printed and stored in ``self.epoch_error``.
        """
        super().train(True)  # BatchNorm/Dropout in training behaviour

        for epoch in range(self.epochs):
            self._apply_schedule(epoch)

            window_loss, window_batches = 0., 0
            epoch_loss, epoch_batches = 0., 0
            for batch_idx, (batch, labels) in enumerate(self.trainloader, start=1):
                self.optimizer.zero_grad()
                loss = self._weighted_loss(self(batch), labels)
                loss.backward()
                self.optimizer.step()

                value = loss.item()
                window_loss += value
                window_batches += 1
                epoch_loss += value
                epoch_batches += 1
                if log_every and batch_idx % log_every == 0:
                    print(f'[{epoch + 1}, {batch_idx}] loss: {window_loss / window_batches:.3f}')
                    window_loss, window_batches = 0., 0

            if epoch_batches:
                self.epoch_error = epoch_loss / epoch_batches
            print(f'[{epoch + 1}] loss: {self.epoch_error:.3f}')

    def _apply_schedule(self, epoch):
        """Advance the loss weights and/or learning rate if ``epoch`` is a milestone."""
        if epoch in self.loss_weight_epochs:
            stage = self.loss_weight_epochs.index(epoch) + 1
            self.alpha = self.alphas[stage]
            self.beta = self.betas[stage]
            self.gamma = self.gammas[stage]
        if epoch in self.lr_decay_epochs:
            stage = self.lr_decay_epochs.index(epoch) + 1
            # Change lr in place instead of rebuilding the optimizer so the
            # SGD momentum buffers survive the decay.
            for group in self.optimizer.param_groups:
                group['lr'] = self.learning_rate[stage]

    def _weighted_loss(self, predict, labels):
        """Weighted sum of the three branch cross-entropy losses."""
        return (self.alpha * self.criterion(predict[0], labels[:, 0])
                + self.beta * self.criterion(predict[1], labels[:, 1])
                + self.gamma * self.criterion(predict[2], labels[:, 2]))

    def test(self):
        """Evaluate top-1 accuracy of each branch on ``self.testloader``.

        Runs in eval mode without gradients and restores the previous
        train/eval mode afterwards.

        Returns:
            Tuple ``(acc_c1, acc_c2, acc_c3)`` of accuracies in [0, 1]
            (NaN if the loader yields no samples).
        """
        was_training = self.training
        super().train(False)
        correct = [0, 0, 0]
        total = 0
        try:
            with torch.no_grad():
                for batch, labels in self.testloader:
                    predict = self(batch)
                    for level in range(3):
                        hits = predict[level].argmax(dim=1) == labels[:, level]
                        correct[level] += int(hits.sum().item())
                    total += labels.size(0)
        finally:
            super().train(was_training)

        accuracies = tuple(c / total if total else float('nan') for c in correct)
        for name, acc in zip(('coarse 1', 'coarse 2', 'fine'), accuracies):
            print(f'{name} accuracy: {acc:.3f}')
        return accuracies
            

In [5]:
# Loss-weight schedule for the (coarse-1, coarse-2, fine) branches. Index k is
# stage k: stage 0 is used from the start and BCNN3 switches to the later
# stages during training. Every stage is a convex combination (sums to 1);
# stage 2 was (0.1, 0.2, 0.1) — a typo for the B-CNN schedule (0.1, 0.2, 0.7).
alpha = [0.98, 0.1, 0.1, 0.]
beta = [0.01, 0.8, 0.2, 0.]
gamma = [0.01, 0.1, 0.7, 1.]
assert len(alpha) == len(beta) == len(gamma)
assert all(abs(a + b + g - 1.) < 1e-9 for a, b, g in zip(alpha, beta, gamma)), \
    'each loss-weight stage should sum to 1'

learning_rate = [3e-3, 5e-4, 1e-4]  # initial lr, then the two decay stages
momentum = 0.9
nesterov = True
epochs = 60
num_class_c1 = 2
num_class_c2 = 7
num_class_c3 = 10

#--- coarse 1 classes ---
classes_c_1 = ('transport', 'animal')
#--- coarse 2 classes ---
classes_c_2 = ('sky', 'water', 'road', 'bird', 'reptile', 'pet', 'medium')
#--- fine classes ---
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Class-name tuples must match the branch output sizes.
assert (len(classes_c_1), len(classes_c_2), len(classes)) == (num_class_c1, num_class_c2, num_class_c3)

In [6]:
# Build the three-branch model with the hyperparameters and loaders defined above.
cnn = BCNN3(
    alpha=alpha,
    beta=beta,
    gamma=gamma,
    learning_rate=learning_rate,
    momentum=momentum,
    nesterov=nesterov,
    trainloader=trainloader,
    testloader=testloader,
    epochs=epochs,
    num_class_c1=num_class_c1,
    num_class_c2=num_class_c2,
    num_class_c3=num_class_c3,
)

In [7]:
# Calling BCNN3.train() with no argument runs the training loop over
# cnn.trainloader for cnn.epochs epochs (long-running; the last saved run
# was interrupted manually).
cnn.train()

KeyboardInterrupt: 