In [1]:
#####################################################################
############################## imports ##############################
#####################################################################

import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import ssl
import torch.nn as nn
from torchvision import datasets
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import time
import pickle

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#####################################################################
########################## HyperParameters ##########################
#####################################################################

BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 30
ARCH_WIDTH = 1
NUM_CLASSES = 100 # Classification of CIFAR 100
PART_IDX = 1
NUM_ENSEMBLE = 16
NUM_EMBED = [32,1024] #Number of vectors in the codebook.
NUM_PARTS = 1 # Split the vectors in the codebook into ### parts.
COMMITMENT_W = 0.1 # Weight to the VQ_LOSS

#####################################################################
############################## DATASET ##############################
#####################################################################

def get_test_transforms():
    test_transform = transforms.Compose(
                [transforms.ToTensor(),
      		transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    return test_transform

def get_train_transforms():
    transform = transforms.Compose(
                [transforms.RandomCrop(32, padding=4),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    return transform

train_transform = get_train_transforms()
test_transform = get_test_transforms()

ssl._create_default_https_context = ssl._create_unverified_context
path = "/tmp/cifar100"
trainset = datasets.CIFAR100(root = path, train=True, download=True, transform=train_transform)
testset = datasets.CIFAR100(root = path, train=False, download=True, transform=test_transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

#####################################################################
########################### ARCHITECTURE ############################
#####################################################################

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, input):
        return input.view((input.shape[0], -1))



def weight_noise(m):
    ## Reset all the parameters of the new 'Decoder'.
    ## For creating an ensembles of decoders.
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.weight.data = m.weight.data + torch.randn((m.weight).shape)*0.05
        if m.bias is not None:
            m.bias.data = m.bias.data + torch.randn((m.bias.shape))*0.02


def SplitNet(width=1, pretrained=True, num_classes=1000, stop_layer=4,decoder_copies=1):
    if width != 1:
        pretrained = False

    encoder_layers = []
    decoder_layers = []
    EncDec_dict = dict(encoder=[], decoders=[])
    inverted_residual_setting=[[1, 16, 1, 1],[6, 24, 2, 1],[6, 32, 3, 1],[6, 64, 4, 2],[6, 96, 3, 1],[6, 160, 3, 1],[6, 320, 1, 1]]
    num_channels_per_layer = [32, 16, 24, 24, 32, 32, 32, 64, 64, 64, 64, 96, 96, 96, 160, 160, 160, 320, 1280]

    mobilenetv2 = models.mobilenet_v2(pretrained=pretrained, num_classes=1000,width_mult=width,inverted_residual_setting=inverted_residual_setting)
    res_stop = 5
    for layer_idx, l in enumerate(mobilenetv2.features):
        if layer_idx <= res_stop:
            encoder_layers.append(l)
        else:
            decoder_layers.append(l)

    dropout = nn.Dropout(0.2,inplace=True)
    fc = nn.Linear(in_features=1280,out_features=num_classes,bias=True)
    classifier = nn.Sequential(dropout,fc)
    pool = nn.AdaptiveAvgPool2d(1)
    decoder_layers.append(pool)
    decoder_layers.append(Flatten())
    decoder_layers.append(classifier)

    EncDec_dict['encoder'] = nn.Sequential(*encoder_layers)
    EncDec_dict['decoders'] = [nn.Sequential(*decoder_layers)] # listed for a list of decoders

    ## Creating a list of different Decoders
    while decoder_copies > 1:
        new_decoder = copy.deepcopy(nn.Sequential(*decoder_layers))
        # new_decoder.apply(weight_noise)
        EncDec_dict['decoders'].append(new_decoder)
        decoder_copies -= 1

    return EncDec_dict


#####################################################################
############################# QUANTIZER #############################
#####################################################################

class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, codebook_size, commitment, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._num_vectors = codebook_size
        self._size_vectors = num_embeddings

        self._embedding = nn.Embedding(self._num_vectors, self._size_vectors)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment

        self.register_buffer('_ema_cluster_size', torch.zeros(self._num_vectors))
        self._ema_w = nn.Parameter(torch.Tensor(self._num_vectors, self._size_vectors))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        # inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._size_vectors)

        # Calculate distances
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_vectors, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                    (self._ema_cluster_size + self._epsilon)
                    / (n + self._num_vectors * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.contiguous(), encodings



#####################################################################
############################### MODEL ###############################
#####################################################################

class SideFuncs(nn.Module):
    def __init__(self, encoder, decoder, primary_loss, n_embed, decay=0.8, commitment=1., eps=1e-5,
                 skip_quant=False, learning_rate=1e-3,training=True):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps
        self.primary_loss = primary_loss
        self.commitment_w = commitment
        dummy_input = torch.zeros((1, 3, 100, 100)) # Check number of channels the encoder outputs
        self.quant_dim = encoder(dummy_input).shape[1]
        self.quantizer = VectorQuantizerEMA(num_embeddings=self.quant_dim,
                                            codebook_size=self.n_embed,  # size of the dictionary
                                            commitment=1.0,
                                            # the weight on the commitment loss (==1 cause we want control))
                                            decay=self.decay)
                                            # the exponential moving average decay, lower means the dictionary will change faster

        self.skip_quant = skip_quant
        self.learning_rate = learning_rate


    def encode(self, x):
        z_e = self.encoder(x)
        z_e = z_e.view((z_e.shape[0], z_e.shape[2], z_e.shape[3], z_e.shape[1]))
        return z_e

    def quantize(self,z_e):
        if not self.skip_quant:
            commit_loss, z_q, indices = self.quantizer(z_e)
        else:
            z_q, indices, commit_loss = z_e, None, 0
        return z_q, indices, commit_loss


    def decode(self, z):
        predictions = []
        for i, decoder in enumerate(self.decoder):
            if i == 0:
                predictions.append(decoder(z[0]))
            else:
                predictions.append(decoder(z[1]))
        return predictions

    def calculate_prime_loss(self, y_hat_list, y):
        loss = 0
        for y_hat in y_hat_list:
            loss += self.primary_loss(y_hat, y)
        return loss / len(y_hat_list)

    def ensemble_calculator(self, preds_list):
        return torch.mean(torch.stack(preds_list), axis=0)

    def accuracy(self,y,y_pred,ensemble_y_pred):
        ens_pred = torch.max(ensemble_y_pred.data, 1)[1]
        batch_ens_corr = (ens_pred == y).sum()
        predicted = []
        batch_corr = []
        for vec in range(len(y_pred)):
            predicted.append(torch.max(y_pred[vec].data, 1)[1])
            batch_corr.append((predicted[vec] == y).sum())
        return batch_corr, batch_ens_corr



class NeuraQuantModel(SideFuncs):
    def __init__(self, encoder, decoder, primary_loss, n_embed=1024, decay=0.8, commitment=1., eps=1e-5,
                 skip_quant=False, learning_rate=1e-3):
        super().__init__(encoder, decoder, primary_loss,n_embed,decay, commitment, eps,
                 skip_quant, learning_rate)
        self.encoder = encoder
        self.decoder = nn.ModuleList(self.decoder)

    def process_batch(self,batch):
        x, y = batch
        z_e = self.encoder(x)
        z_e_for_quant = z_e.view((z_e.shape[0], z_e.shape[2], z_e.shape[3], z_e.shape[1]))
        commit_loss, z_q, indices = self.quantizer(z_e_for_quant)
        z_q = z_q.view((z_q.shape[0], z_q.shape[3], z_q.shape[1], z_q.shape[2]))
        z_q = [z_e,z_q]
        y_hat = self.decode(z_q)
        ensemble_y_hat = self.ensemble_calculator(y_hat)
        batch_acc, batch_acc_ensemble = self.accuracy(y, y_hat, ensemble_y_hat)
        prime_loss = self.calculate_prime_loss(y_hat,y)
        result_dict = {'loss': prime_loss + commit_loss, 'preds': y_hat, 'gts': y}
        return result_dict, batch_acc, batch_acc_ensemble, y_hat



#####################################################################
############################# TRAINING ##############################
#####################################################################

train_losses = []
test_losses = []
test_correct = []
epoch_acc = []
ensemble_epoch_acc = []
dec_acc = []
np.random.seed(100)


# Number of parameters
def get_n_params(model):
    pp = 0
    for p in list(model.parameters()):
        nn = 1
        for s in list(p.size()):
            nn = nn * s
        pp += nn
    return pp


def train(model,EncDec_dict,optimizer):
    start_time = time.time()
    val_acc = []
    for epc in range(EPOCHS):
        trn_corr = [0]*len(EncDec_dict['decoders'])
        ensemble_corr = 0
        losses = 0
        for batch_num, (Train, Labels) in enumerate(trainloader):
            batch_num += 1
            batch = (Train.to(device), Labels.to(device))
            result_dict, batch_acc, batch_acc_ensemble, _ = model.process_batch(batch)
            loss = result_dict['loss']
            losses += loss.item()
            ensemble_corr += batch_acc_ensemble
            for num in range(len(batch_acc)):
                trn_corr[num] += batch_acc[num]

    #        if grad_clip:
    #           nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            # Update parameters
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if batch_num % 100 == 0:
                print(f'epoch: {epc+1:2}  batch: {batch_num:2} [{BATCH_SIZE * batch_num:6}/{len(trainset)}]  total loss: {loss.item():10.8f}  \
                time = [{(time.time() - start_time)/60}] minutes')

        # scheduler.step()
        ### Accuracy ###
        loss = losses/batch_num
        train_losses.append(loss)

        ensemble_epoch_acc.append((ensemble_corr.item()/(BATCH_SIZE*batch_num))*100)

        for acc in range(len(trn_corr)):
            dec_acc[acc].append((trn_corr[acc].item()/(BATCH_SIZE*batch_num))*100)

        num_ens_correct=0
        num_dec_correct = [0]*len(EncDec_dict['decoders'])
        model.eval()
        test_losses_val = 0
        with torch.no_grad():
            for b, (X_test, y_test) in enumerate(testloader):
                # Apply the model
                b += 1
                batch = (X_test.to(device), y_test.to(device))
                result_dict, test_batch_acc, test_batch_acc_ensemble, y_hat = model.process_batch(batch)
                test_loss = result_dict['loss']
                test_losses_val += test_loss.item()
                num_ens_correct += test_batch_acc_ensemble
                for lss in range(len(test_batch_acc)):
                    num_dec_correct[lss] += test_batch_acc[lss]

            test_losses.append(test_losses_val/b)
        print(f'Train Ensemble Accuracy at epoch {epc + 1} is {100*ensemble_corr/len(trainset)}%')
        print(f'Validation Ensemble Accuracy at epoch {epc+1} is {100*num_ens_correct/len(testset)}%')
        val_acc.append(100*num_ens_correct/len(testset))


        model.train()
    duration = time.time() - start_time
    print(f'Training took: {duration / 3600} hours')
    return max(val_acc), y_hat


num_parameters = []
accuracy = []



EncDec_dict = SplitNet(width=ARCH_WIDTH,
                        pretrained=True,
                        num_classes=NUM_CLASSES,
                        stop_layer=PART_IDX,
                        decoder_copies=8)

criterion = nn.CrossEntropyLoss()

model = NeuraQuantModel(encoder=EncDec_dict['encoder'],
                        decoder=EncDec_dict['decoders'],
                        primary_loss=criterion,
                        n_embed=256,
                        commitment=COMMITMENT_W)

model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


for dec in range(len(EncDec_dict['decoders'])):
    dec_acc.append([])


def train_ens(ens,model,EncDec_dict,optimizer):
    params = get_n_params(model)
    num_parameters.append(params)
    acc, y_hat = train(model,EncDec_dict,optimizer)
    torch.save(y_hat,"y_hat.pt")
    accuracy.append(acc)
    print(f'Number of parameters for {ens + 1} users is : {params}')
    print(f'Accuracy of {ens + 1} users is : {acc}')
    return y_hat

y_hat = train_ens(8,model,EncDec_dict,optimizer)


print(num_parameters)
print(accuracy)



Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /tmp/cifar100/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:04<00:00, 42083852.12it/s]


Extracting /tmp/cifar100/cifar-100-python.tar.gz to /tmp/cifar100
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 164MB/s]


epoch:  1  batch: 100 [  6400/50000]  total loss: 3.09962845                  time = [0.7358582059542338] minutes
epoch:  1  batch: 200 [ 12800/50000]  total loss: 2.60135627                  time = [1.356647789478302] minutes
epoch:  1  batch: 300 [ 19200/50000]  total loss: 2.29938149                  time = [1.9890548467636109] minutes
epoch:  1  batch: 400 [ 25600/50000]  total loss: 2.33984923                  time = [2.631840443611145] minutes
epoch:  1  batch: 500 [ 32000/50000]  total loss: 2.30282784                  time = [3.2686692476272583] minutes
epoch:  1  batch: 600 [ 38400/50000]  total loss: 2.28998470                  time = [3.908490558465322] minutes
epoch:  1  batch: 700 [ 44800/50000]  total loss: 1.71812284                  time = [4.5471675554911295] minutes
Train Ensemble Accuracy at epoch 1 is 39.33599853515625%
Validation Ensemble Accuracy at epoch 1 is 50.939998626708984%
epoch:  2  batch: 100 [  6400/50000]  total loss: 2.02938581                  time = 

In [91]:
from google.colab import drive
drive.mount('/content/drive')
model_path = "/content/drive/MyDrive/model_ks_gpu.pth"

device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.save(model.to(device).state_dict(), model_path)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [109]:
model.load_state_dict(torch.load(model_path))

def get_val_acc(testloader):
  model.eval()
  user_single_acc = [0]*8
  y_hats = [[],[],[],[],[],[],[],[]]
  trues = []
  with torch.no_grad():
      for b, (X_test, y_test) in enumerate(testloader):
          # Apply the model
          b += 1
          batch = (X_test.to(device), y_test.to(device))
          _, test_batch_acc, _, y_hat = model.process_batch(batch)
          trues.append(y_test)
          for lss in range(len(test_batch_acc)):
              user_single_acc[lss] += test_batch_acc[lss]
              y_hats[lss].append(y_hat[lss])

  user_single_acc = [100*acc / (b*BATCH_SIZE) for acc in user_single_acc]
  return user_single_acc, y_hats, trues


def calc_relative_acc(accuracies, sharpening):
    relative_accuracies = []
    sharped_accuracies = [acc**(sharpening) for acc in accuracies]
    for i in range(len(sharped_accuracies)):
      before_i = sharped_accuracies[:i]
      after_i = sharped_accuracies[i + 1:]
      rest = before_i + after_i

      single_sharp_acc = sharped_accuracies[i] / sum(rest)
      relative_accuracies.append(single_sharp_acc)
    return relative_accuracies


def get_weighted_val_acc(preds, true):
  ens_pred = torch.max(preds.data, 2)[1]
  num_corr = 0
  for i in range(ens_pred.shape[0]):
    num_corr += (ens_pred[i] == true[i]).sum()
  val_acc = 100*num_corr/(len(testset)-16)
  return val_acc


In [110]:
def validate_with_k(weight_factor, sharpening=8):
    user_single_acc, y_hats, trues = get_val_acc(testloader)
    relative_accuracies = calc_relative_acc(user_single_acc, sharpening)

    first_user_acc = relative_accuracies[0]
    rest_users_acc = relative_accuracies[1:]
    first_user = y_hats[0]
    rest_users = y_hats[1:]

    alpha_1 = first_user_acc / (first_user_acc + weight_factor * sum(rest_users_acc))
    alpha_rest = [(weight_factor * rest_user_acc) / (first_user_acc + weight_factor * sum(rest_users_acc)) for
                  rest_user_acc in rest_users_acc]

    result_list = [alpha_rest[0] * tensor for j, tensor in enumerate(rest_users[0]) if j < 156]
    rest_weighted_users = result_list
    # Calculate the dot product
    summed_list = []
    for i in range(1,len(rest_users) - 1):
        result_list = [alpha_rest[i] * tensor for j, tensor in enumerate(rest_users[i]) if j < 156]

        for tensor1, tensor2 in zip(rest_weighted_users, result_list):
            summed_list.append(tensor1 + tensor2)
        rest_weighted_users = summed_list
        summed_list = []

    weightning_preds = []
    for tensor1, tensor2 in zip(first_user[:156], rest_weighted_users):
        weightning_preds.append(alpha_1 * tensor1 + tensor2)

    val_acc = get_weighted_val_acc(torch.stack(weightning_preds).to(device), torch.stack(trues[:156]).to(device))
    print(val_acc)
    return val_acc



In [112]:
sqrt_8 = np.sqrt(8)
weight_factors = [1/(8*sqrt_8),1/(7*sqrt_8),1/(6*sqrt_8),1/(5*sqrt_8),1/(4*sqrt_8),1/(3*sqrt_8),1/(2*sqrt_8),1/(sqrt_8),2/(sqrt_8),3/(sqrt_8),4/(sqrt_8),5/(sqrt_8),6/(sqrt_8)
                  ,7/(sqrt_8),8/(sqrt_8),9/(sqrt_8)]

accuracies = []
for i in range(len(weight_factors)):
    acc = validate_with_k(weight_factors[i])
    accuracies.append(acc)

tensor(73.2873, device='cuda:0')
tensor(73.4375, device='cuda:0')
tensor(73.6278, device='cuda:0')
tensor(73.8682, device='cuda:0')
tensor(74.0885, device='cuda:0')
tensor(74.5393, device='cuda:0')
tensor(75.2204, device='cuda:0')
tensor(75.8213, device='cuda:0')
tensor(75.7212, device='cuda:0')
tensor(75.5208, device='cuda:0')
tensor(75.4908, device='cuda:0')
tensor(75.4107, device='cuda:0')
tensor(75.3305, device='cuda:0')
tensor(75.3305, device='cuda:0')
tensor(75.2704, device='cuda:0')
tensor(75.2204, device='cuda:0')


In [113]:
print(accuracies)

[tensor(73.2873, device='cuda:0'), tensor(73.4375, device='cuda:0'), tensor(73.6278, device='cuda:0'), tensor(73.8682, device='cuda:0'), tensor(74.0885, device='cuda:0'), tensor(74.5393, device='cuda:0'), tensor(75.2204, device='cuda:0'), tensor(75.8213, device='cuda:0'), tensor(75.7212, device='cuda:0'), tensor(75.5208, device='cuda:0'), tensor(75.4908, device='cuda:0'), tensor(75.4107, device='cuda:0'), tensor(75.3305, device='cuda:0'), tensor(75.3305, device='cuda:0'), tensor(75.2704, device='cuda:0'), tensor(75.2204, device='cuda:0')]
