<a href="https://colab.research.google.com/github/seglass5/PartIIIProject/blob/master/Part_III_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Install Required Packages


In [None]:
# Python packages
!pip install tensorboardX
!pip install pot
!pip install -U networkx
!pip install matplotlib==3.1.1
!pip install torchvision

# Custom packages from github to allow RandWire to work, and FLOPs counter to work
!pip install git+https://github.com/JiaminRen/CVdevKit.git
!pip install git+git://github.com/sovrasov/flops-counter.pytorch.git@64d38fd47cb0795437056745c64a987d944c1885
!pip install igraph
!pip install ptflops

# 2. Import Packages

In [None]:
import torch
import torchvision
import torchvision.models as models
from ptflops import get_model_complexity_info
import tensorboardX
import yaml
import CVdevKit
import networkx as nx
import matplotlib.pyplot as plt
%matplotlib inline
import ot
import importlib
import numpy as np
import gym
import sys
import csv
import torch
from torch import nn
from torch import optim
import seaborn as sns
import pandas as pd
import json
from google.colab import files


In [None]:
import random

# One seed shared by PyTorch, NumPy and Python's `random`; `seed` is also read
# later as a module-level global (e.g. by `Net` and `performance_measure`).
seed = 3
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# Disable cuDNN autotuning and request deterministic kernels so repeated runs
# pick the same convolution algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# 3. Mount Drive and Import Packages and Functions

In [None]:
# Mount Google Drive at /content/drive. Later cells read project code, graphs
# and write results under /content/drive/MyDrive/... .
# `google.colab` only exists inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# 4. Train a Model of Randomly Wired Neural Networks

In [None]:
def remove_isolated_nodes(graph):
  """Drop every node of degree 0 from ``graph`` in place and return the same graph object."""
  lonely = [n for n in list(graph.nodes()) if graph.degree[n] == 0]
  graph.remove_nodes_from(lonely)
  return graph

In [None]:
import networkx as nx
import collections
import matplotlib.pyplot as plt

# One node of a stage DAG: `id` is the node label, `inputs` the ids of its
# predecessor nodes, `type` its role (0 = input, 1 = output, -1 = inner).
Node = collections.namedtuple('Node', 'id inputs type')

def get_graph_info(graph):
  """Describe an undirected graph with nodes labelled 0..n-1 as a DAG.

  Every edge is oriented from the lower to the higher node id. For each node a
  ``Node(id, inputs, type)`` record is built, where ``inputs`` are its
  lower-numbered neighbours in ascending order. ``type`` is:

  * 0 for input nodes (all neighbours have a higher id);
  * 1 for output nodes (all neighbours have a lower id);
  * -1 otherwise.

  Returns:
    (Nodes, input_nodes, output_nodes): the ``Node`` records indexed by node id,
    and the ascending ids of the input and output nodes.

  Raises:
    ValueError: if a node has no neighbours. Such a node could be neither fed
      nor consumed; remove isolated nodes and relabel to 0..n-1 first.
  """
  input_nodes = []
  output_nodes = []
  Nodes = []
  for node in range(graph.number_of_nodes()):
    neighbours = sorted(graph.neighbors(node))
    if not neighbours:
      raise ValueError(
          "node {} has no neighbours; remove isolated nodes and relabel "
          "the graph to 0..n-1 before building a stage".format(node))

    node_type = -1  # renamed from `type`, which shadowed the builtin
    if node < neighbours[0]:
      input_nodes.append(node)
      node_type = 0
    if node > neighbours[-1]:
      output_nodes.append(node)
      node_type = 1
    Nodes.append(Node(node, [n for n in neighbours if n < node], node_type))

  return Nodes, input_nodes, output_nodes

def build_graph(Nodes, cfg ):
  """Sample a random undirected graph with ``Nodes`` nodes according to ``cfg.GRAPH_MODEL``.

  All models are seeded with ``cfg.RND_SEED``:

  * 'ER': Erdos-Renyi with edge probability ``cfg.ER_P``.
  * 'BA': Barabasi-Albert with ``cfg.BA_M`` edges per new node.
  * 'WS': connected Watts-Strogatz with ``cfg.WS_K`` neighbours and rewiring
    probability ``cfg.WS_P``, using up to 200 attempts.

  Raises:
    ValueError: for any other ``cfg.GRAPH_MODEL``. Returning ``None`` instead
      would only fail later, far from the misconfiguration.
  """
  model = cfg.GRAPH_MODEL
  if model == 'ER':
    return nx.random_graphs.erdos_renyi_graph(Nodes, cfg.ER_P, seed=cfg.RND_SEED)
  if model == 'BA':
    return nx.random_graphs.barabasi_albert_graph(Nodes, cfg.BA_M, seed=cfg.RND_SEED)
  if model == 'WS':
    return nx.random_graphs.connected_watts_strogatz_graph(Nodes, cfg.WS_K, cfg.WS_P, tries=200, seed=cfg.RND_SEED)
  raise ValueError("unknown GRAPH_MODEL {!r}; expected 'ER', 'BA' or 'WS'".format(model))

def save_graph(graph, path):
  """Serialise ``graph`` to ``path`` as YAML.

  ``nx.write_yaml`` was removed in NetworkX 3.0, so PyYAML (imported at the
  top of the notebook) is called directly, as ``nx.write_yaml`` itself did.
  """
  with open(path, 'w', encoding='utf-8') as fh:
    yaml.dump(graph, fh)

def save_graphml(graph, path):
  """Write ``graph`` to the file ``path`` in GraphML format (thin wrapper over networkx)."""
  nx.write_graphml(graph, path)

def load_graph(path):
  """Load a graph written by ``save_graph`` (YAML) and return it.

  ``nx.read_yaml`` was removed in NetworkX 3.0, so PyYAML is used directly.

  SECURITY: ``yaml.Loader`` can instantiate arbitrary Python objects. That is
  needed to rebuild the pickled-style networkx graph, but only use this on
  files this notebook produced itself, never on untrusted input.
  """
  with open(path, 'r', encoding='utf-8') as fh:
    return yaml.load(fh, Loader=yaml.Loader)

def load_graphml(path):
  """Read a GraphML file with integer node labels and return an independent copy of the graph."""
  loaded = nx.read_graphml(path, node_type = int)
  return loaded.copy()



In [None]:
# referred to JiaminRen's implementation
# https://github.com/JiaminRen/RandWireNN

class conv_unit(nn.Module):
    """Depthwise-separable convolution: a 3x3 depthwise conv followed by a 1x1 pointwise conv.

    Args:
        nin: number of input channels.
        nout: number of output channels.
        stride: stride of the 3x3 depthwise convolution.
    """
    def __init__(self, nin, nout, stride):
        super(conv_unit, self).__init__()
        # groups=nin gives one 3x3 filter per input channel (depthwise).
        self.depthwise_separable_conv_3x3 = nn.Conv2d(nin, nin, kernel_size=3, stride=stride, padding=1, groups=nin)
        # The 1x1 conv mixes channels and maps nin -> nout.
        self.pointwise_conv_1x1 = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        out = self.depthwise_separable_conv_3x3(x)
        out = self.pointwise_conv_1x1(out)
        return out


class Triplet_unit(nn.Module):
    """ReLU -> depthwise-separable conv -> BatchNorm, the operation applied at every graph node.

    Args:
        inplanes: input channels.
        outplanes: output channels.
        stride: stride of the depthwise conv (default 1).
    """
    def __init__(self, inplanes, outplanes, stride=1):
        super(Triplet_unit, self).__init__()
        self.relu = nn.ReLU()
        self.conv = conv_unit(inplanes, outplanes, stride)
        self.bn = nn.BatchNorm2d(outplanes)

    def forward(self, x):
        out = self.relu(x)
        out = self.conv(out)
        out = self.bn(out)
        return out


class Node_OP(nn.Module):
    """Operation executed at one node of a random-graph stage.

    Args:
        Node: a node record. ``Node.type == 0`` marks an input node of the
            stage, and ``Node.inputs`` lists the ids of its predecessor nodes.
        inplanes: channels of the stage input (used only by input nodes).
        outplanes: channels produced by every node of the stage.
    """
    def __init__(self, Node, inplanes, outplanes):
        super(Node_OP, self).__init__()
        self.is_input_node = Node.type == 0
        self.input_nums = len(Node.inputs)
        if self.input_nums > 1:
            # One learnable aggregation weight per incoming edge, squashed by a sigmoid in forward.
            self.mean_weight = nn.Parameter(torch.ones(self.input_nums))
            self.sigmoid = nn.Sigmoid()
        if self.is_input_node:
            # Input nodes consume the stage input and halve the spatial resolution.
            self.conv = Triplet_unit(inplanes, outplanes, stride=2)
        else:
            self.conv = Triplet_unit(outplanes, outplanes, stride=1)

    def forward(self, *inputs):
        """Aggregate the incoming feature maps, then apply the node's Triplet_unit.

        With more than one input, the maps are combined as
        ``sum_i sigmoid(mean_weight[i]) * inputs[i]``. Otherwise the single
        input is used as is.
        """
        if self.input_nums > 1:
            out = self.sigmoid(self.mean_weight[0]) * inputs[0]
            # Bug fix: this loop must be nested under the `if`. Dedented, it formed
            # a for/else whose else-branch always ran and replaced the weighted sum
            # with inputs[0], so every input but the first was ignored.
            for i in range(1, self.input_nums):
                out = out + self.sigmoid(self.mean_weight[i]) * inputs[i]
        else:
            out = inputs[0]
        return self.conv(out)

class StageBlock(nn.Module):
    """One RandWire stage: runs a ``Node_OP`` per node of ``graph`` in DAG order.

    Input nodes receive the stage input ``x``; every other node receives the
    outputs of its predecessor nodes. The stage result is the mean of the
    output nodes' results.

    Args:
        graph: undirected graph with nodes labelled 0..n-1 (decoded by ``get_graph_info``).
        inplanes: channels of the stage input.
        outplanes: channels produced by every node.
    """
    def __init__(self, graph, inplanes, outplanes):
        super(StageBlock, self).__init__()
        self.nodes, self.input_nodes, self.output_nodes = get_graph_info(graph)
        # One op per node, stored at the node's id so nodeop[i] belongs to node i.
        self.nodeop = nn.ModuleList([Node_OP(spec, inplanes, outplanes) for spec in self.nodes])

    def forward(self, x):
        # Input nodes first: each consumes the stage input directly.
        feats = {nid: self.nodeop[nid](x) for nid in self.input_nodes}
        sources = set(self.input_nodes)
        # Remaining nodes in ascending id order. Predecessors always have a
        # smaller id, so their outputs are already available in `feats`.
        for nid, spec in enumerate(self.nodes):
            if nid in sources:
                continue
            feats[nid] = self.nodeop[nid](*[feats[src] for src in spec.inputs])
        # Mean over the output nodes, accumulated in list order.
        total = feats[self.output_nodes[0]]
        for nid in self.output_nodes[1:]:
            total = total + feats[nid]
        return total / len(self.output_nodes)

class Net(nn.Module):
    """RandWire classifier: a conv stem, four stages and a 1x1-conv + average-pool classifier head.

    Args:
        cfg: configuration read both by attribute (``cfg.NN.NODES``,
            ``cfg.MAKE_GRAPH``, ...) and by key (``cfg['NN']['REGIME']``,
            ``cfg["USE_PRUNED_GRAPH"]``).
        measure: sub-directory under ``output_bt/`` from which pruned graphs
            are loaded. Only used when ``cfg["USE_PRUNED_GRAPH"]`` is true.

    Reads the notebook globals ``kval``, ``pval`` and ``seed`` to build graph
    paths. When new graphs are generated, it writes node/edge counts into the
    global ``net_stats`` dict. The ``%cd`` lines are IPython magics that change
    the process working directory, which every relative path below depends on.
    """
    def __init__(self, cfg,measure):
        super(Net, self).__init__()
        # for image color scale
        #color = cfg.NN.COLOR
        # Number of input channels is fixed at 3 rather than read from cfg.
        color=3
        #print(color)
        N = cfg.NN.NODES
        size = cfg.NN.IMG_SIZE
        num_classes = cfg.NN.NUM_CLASSES


        # Source of the stage graphs:
        #   1) previously pruned graphs on Drive,
        #   2) freshly generated random graphs (also saved to Drive),
        #   3) otherwise YAML graphs in ./output/graph (relative to the current directory).
        # NOTE(review): in branch 1 the graphs are only loaded for the "SMALL" and
        # "REGULAR" regimes, and the layers below exist only for those two regimes.
        if (cfg["USE_PRUNED_GRAPH"] == True):
            %cd /content/drive/MyDrive/PRUNING/PRUNED/WS/
            # kval/pval/seed are module-level globals, only read here.
            global pval,kval,seed
            if (cfg['NN']['REGIME'] == "SMALL"):
              graph2 = load_graphml('./WS_K_{}/WS_P_{}/seed_{}/SMALL/output_bt/{}/conv2.graphml'.format(kval,pval,seed,measure))
              graph3 = load_graphml('./WS_K_{}/WS_P_{}/seed_{}/SMALL/output_bt/{}/conv3.graphml'.format(kval,pval,seed,measure))
              graph4 = load_graphml('./WS_K_{}/WS_P_{}/seed_{}/SMALL/output_bt/{}/conv4.graphml'.format(kval,pval,seed,measure))
              graph5 = load_graphml('./WS_K_{}/WS_P_{}/seed_{}/SMALL/output_bt/{}/conv5.graphml'.format(kval,pval,seed,measure))
            if (cfg['NN']['REGIME'] == "REGULAR"):
              graph2 = load_graphml('./WS_K_{}/WS_P_{}/seed_{}/REGULAR/output_bt/{}/conv2.graphml'.format(kval,pval,seed,measure))
              graph3 = load_graphml('./WS_K_{}/WS_P_{}/seed_{}/REGULAR/output_bt/{}/conv3.graphml'.format(kval,pval,seed,measure))
              graph4 = load_graphml('./WS_K_{}/WS_P_{}/seed_{}/REGULAR/output_bt/{}/conv4.graphml'.format(kval,pval,seed,measure))
              graph5 = load_graphml('./WS_K_{}/WS_P_{}/seed_{}/REGULAR/output_bt/{}/conv5.graphml'.format(kval,pval,seed,measure))

        elif cfg.MAKE_GRAPH:

            %cd /content/drive/MyDrive/PRUNING/UNPRUNED/WS

            # Stage 2 gets half as many nodes as stages 3-5. Isolated nodes are
            # dropped and the rest relabelled 0..n-1, which get_graph_info relies on.
            graph2 = nx.convert_node_labels_to_integers(remove_isolated_nodes(build_graph(N//2, cfg)))
            graph3 = nx.convert_node_labels_to_integers(remove_isolated_nodes(build_graph(N, cfg)))
            graph4 = nx.convert_node_labels_to_integers(remove_isolated_nodes(build_graph(N, cfg)))
            graph5 = nx.convert_node_labels_to_integers(remove_isolated_nodes(build_graph(N, cfg)))

            # Only conv2/conv3 sizes are recorded; `net_stats` must already exist as a global dict.
            global net_stats
            net_stats['conv2_nodes']=graph2.number_of_nodes()
            net_stats['conv2_edges']=graph2.number_of_edges()
            net_stats['conv3_nodes']=graph3.number_of_nodes()
            net_stats['conv3_edges']=graph3.number_of_edges()

            # Unpruned graphs are saved with cfg['RND_SEED'] as a file-name suffix.
            if (cfg['NN']['REGIME'] == "SMALL"):
              save_graphml(graph2, './WS_K_{}/WS_P_{}/seed_{}/SMALL/conv2_{}.graphml'.format(kval,pval,seed,cfg['RND_SEED']))
              save_graphml(graph3, './WS_K_{}/WS_P_{}/seed_{}/SMALL/conv3_{}.graphml'.format(kval,pval,seed,cfg['RND_SEED']))
              save_graphml(graph4, './WS_K_{}/WS_P_{}/seed_{}/SMALL/conv4_{}.graphml'.format(kval,pval,seed,cfg['RND_SEED']))
              save_graphml(graph5, './WS_K_{}/WS_P_{}/seed_{}/SMALL/conv5_{}.graphml'.format(kval,pval,seed,cfg['RND_SEED']))
            if (cfg['NN']['REGIME'] == "REGULAR"):
              save_graphml(graph2, './WS_K_{}/WS_P_{}/seed_{}/REGULAR/conv2_{}.graphml'.format(kval,pval,seed,cfg['RND_SEED']))
              save_graphml(graph3, './WS_K_{}/WS_P_{}/seed_{}/REGULAR/conv3_{}.graphml'.format(kval,pval,seed,cfg['RND_SEED']))
              save_graphml(graph4, './WS_K_{}/WS_P_{}/seed_{}/REGULAR/conv4_{}.graphml'.format(kval,pval,seed,cfg['RND_SEED']))
              save_graphml(graph5, './WS_K_{}/WS_P_{}/seed_{}/REGULAR/conv5_{}.graphml'.format(kval,pval,seed,cfg['RND_SEED']))
        else:
            graph2 = load_graph('./output/graph/conv2.yaml')
            graph3 = load_graph('./output/graph/conv3.yaml')
            graph4 = load_graph('./output/graph/conv4.yaml')
            graph5 = load_graph('./output/graph/conv5.yaml')

        if cfg.NN.REGIME == "SMALL":
            print('small regime')
            C = 78
            # Stem: stride-2 depthwise-separable conv + BN (no ReLU here).
            self.conv1 =  nn.Sequential(
                conv_unit(color, C//2, 2),
                nn.BatchNorm2d(C//2)
                )
            # Small regime: conv2 is a single stride-1 Triplet_unit, so graph2 is not used as a layer.
            self.conv2 = Triplet_unit(C//2, C)
            self.conv3 = StageBlock(graph3, C, C)
            self.conv4 = StageBlock(graph4, C, 2*C)
            self.conv5 = StageBlock(graph5, 2*C, 4*C)
            # Downsampling is /16 in total (stem /2, plus /2 in each of conv3-conv5,
            # whose input nodes use stride 2). The pool spans the final map when
            # IMG_SIZE is a multiple of 16.
            self.classifier = nn.Sequential(
                    nn.ReLU(True),
                    nn.Conv2d(4*C, 1280, kernel_size=1),
                    nn.BatchNorm2d(1280),
                    nn.ReLU(True),
                    nn.AvgPool2d(size//16, stride=1),
                )
            self.fc = nn.Linear(1280, num_classes)

        if cfg.NN.REGIME == "REGULAR":
            print('regular regime')
            C = 109
            self.conv1 =  nn.Sequential(
                conv_unit(color, C//2, 2),
                nn.BatchNorm2d(C//2)
                )
            self.conv2 = StageBlock(graph2, C//2,  C)
            self.conv3 = StageBlock(graph3, C,   2*C)
            self.conv4 = StageBlock(graph4, 2*C, 4*C)
            self.conv5 = StageBlock(graph5, 4*C, 8*C)
            # Downsampling is /32 in total (stem /2, plus /2 in each of the four graph stages).
            self.classifier = nn.Sequential(
                    nn.ReLU(True),
                    nn.Dropout(0.2),
                    nn.Conv2d(8*C, 1280, kernel_size=1),
                    nn.BatchNorm2d(1280),
                    nn.ReLU(True),
                    nn.Dropout(0.2),
                    nn.AvgPool2d(size//32, stride=1),
                )
            self.fc = nn.Linear(1280, num_classes)

    def forward(self, x):
        """Run stem -> conv2..conv5 -> classifier head -> linear layer and return the logits."""
        # print(x.shape)
        x = self.conv1(x)
        # print(x.shape)
        x = self.conv2(x)
        # print(x.shape)
        x = self.conv3(x)
        # print(x.shape)
        x = self.conv4(x)
        # print(x.shape)
        x = self.conv5(x)
        # print(x.shape)
        x = self.classifier(x)
        # print(x.shape)
        # Flatten the pooled features to (batch, -1) for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

In [None]:
import os
# Switch into the Drive folder that holds the project code, so the project-local
# imports below (`utils...`, and `RandWireNN_config` inside get_configuration)
# can be found.
# NOTE(review): this relies on the current directory being on sys.path, as it
# is in an IPython/Colab kernel; confirm before running elsewhere.
os.chdir('/content/drive/MyDrive/Codes/')
import torch
from utils.config_helpers import merge_configs
import time
import networkx as nx

def get_configuration():
    """Return the RandWire network config merged with the ImageNet dataset config (via ``merge_configs``)."""
    from RandWireNN_config import cfg as network_cfg
    from utils.configs.ImageNet_config import cfg as dataset_cfg
    parts = [network_cfg, dataset_cfg]
    return merge_configs(parts)


In [None]:
def weights_init(m):
    """Initialise one module in place; meant to be used as ``model.apply(weights_init)``.

    Conv2d, BatchNorm2d and Linear layers get weights drawn from a normal
    distribution with mean 0 and std 1/16, and zero biases. Other modules are
    left untouched. Layers without a weight or bias tensor (e.g. ``bias=False``
    or BatchNorm with ``affine=False``) simply skip that part.

    NOTE: the global torch RNG is reseeded to 1 on every call, i.e. once per
    module visited by ``apply``. As a result, modules of the same shape receive
    identical initial weights, and the RNG state after ``apply`` no longer
    depends on any earlier seed. This is kept as-is because existing results
    were produced with it.
    """
    torch.manual_seed(1)
    if isinstance(m, (nn.Conv2d, nn.BatchNorm2d, nn.Linear)):
        if m.weight is not None:
            nn.init.normal_(m.weight, mean=0, std=1/16)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
import torch
import torch.optim as optim
import time
import os, sys

def train(train_loader, model, criterion, optimizer, epoch, cfg):
    """Run one training epoch of ``model`` over ``train_loader``.

    Tracks batch/data time, loss, top-1 accuracy and top-2 accuracy. The meter
    labelled "Acc@5" receives the top-2 value because of ``topk=(1, 2)``.
    Progress is printed every ``cfg.PRINT_FREQ`` batches and, when
    ``cfg.VISDOM`` is set, the loss is appended to the visdom window.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    meters = (batch_time, data_time, losses, top1, top5)
    progress = ProgressMeter(len(train_loader), *meters, prefix="Epoch: [{}]".format(epoch))

    model.train()

    tic = time.time()
    for step, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - tic)

        images = images.to(cfg.DEVICE)
        labels = labels.to(cfg.DEVICE)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)

        acc1, acc2 = accuracy(logits, labels, topk=(1, 2))
        batch = images.size(0)
        losses.update(loss.item(), batch)
        top1.update(acc1[0], batch)
        top5.update(acc2[0], batch)

        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % cfg.PRINT_FREQ == 0:
            progress.print(step)
            if cfg.VISDOM:
                cfg.vis.line(X=torch.Tensor([epoch+(step/len(train_loader))]).unsqueeze(0).cpu(),
                              Y=torch.Tensor([loss]).unsqueeze(0).cpu(),
                              env='torch',win=cfg.loss_window,name='train_loss',update='append')


def validate(val_loader, model, criterion, cfg):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradients.

    Prints progress every ``cfg.PRINT_FREQ`` batches and a summary line at the
    end. As in ``train``, the meter labelled "Acc@5" receives top-2 accuracy.

    Returns:
        (average loss, average top-1 accuracy in percent) over all samples.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
                             prefix='Test: ')

    model.eval()

    with torch.no_grad():
        tic = time.time()
        for step, (images, labels) in enumerate(val_loader):
            images = images.to(cfg.DEVICE)
            labels = labels.to(cfg.DEVICE)

            logits = model(images)
            loss = criterion(logits, labels)

            acc1, acc2 = accuracy(logits, labels, topk=(1, 2))
            batch = images.size(0)
            losses.update(loss.item(), batch)
            top1.update(acc1[0], batch)
            top5.update(acc2[0], batch)

            batch_time.update(time.time() - tic)
            tic = time.time()

            if step % cfg.PRINT_FREQ == 0:
                progress.print(step)
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return losses.avg, top1.avg

class AverageMeter(object):
    """Track the latest value and the sample-weighted running mean of a metric.

    ``fmt`` is a format spec such as ``':6.3f'``. When the meter is rendered
    with ``str()``, it is applied to both the latest value and the average.
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**vars(self))

class ProgressMeter(object):
    """Print one tab-separated progress line: ``<prefix>[batch/total]`` followed by every meter."""
    def __init__(self, num_batches, *meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def print(self, batch):
        """Print the progress line for batch index ``batch``."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([head, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total, e.g. '[{:3d}/100]'.
        width = len(str(num_batches // 1))
        cell = '{:%dd}' % width
        return '[{}/{}]'.format(cell, cell.format(num_batches))

def adjust_learning_rate(optimizer, epoch, args):
    """Step decay: set every param group's LR to ``args.lr`` divided by 10 once per 30 epochs."""
    decay_steps = epoch // 30
    new_lr = args.lr * 0.1 ** decay_steps
    for group in optimizer.param_groups:
        group['lr'] = new_lr

def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracies, in percent, of ``output`` scores against ``target`` labels.

    Returns one single-element tensor per entry of ``topk``, in the same order.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        _, top_idx = output.topk(max(topk), 1, True, True)
        ranked = top_idx.t()  # row r holds every sample's r-th best class
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))
        return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
                for k in topk]

def prepare(cfg, use_arg_parser=True):
    """Create ./output, ./output/model and ./output/graph if missing, and open the visdom loss plot.

    The plot window is created only when training (``not cfg.TEST_MODE``) with
    ``cfg.VISDOM`` enabled; its handle is stored in ``cfg.loss_window``.
    ``use_arg_parser`` is accepted but not used.
    """
    for folder in ("./output/", "./output/model", "./output/graph"):
        if not os.path.isdir(folder):
            os.mkdir(folder)
    if cfg.TEST_MODE or not cfg.VISDOM:
        return
    stamp = time.strftime("%m/%d %H:%M", time.localtime())
    cfg.loss_window = cfg.vis.line(
        Y=torch.zeros((1)).cpu(),
        X=torch.zeros((1)).cpu(),
        env='torch',
        opts=dict(xlabel='epoch', ylabel='Loss',
                  title=cfg.DATASET_NAME + "_" + stamp,
                  legend=['train_loss', 'val_loss']))

In [None]:
def performance_measure(val_loader, model,bottom_val,top_val,measure,cfg):
  """Evaluate a binary classifier on ``val_loader`` and write metrics, CSVs and plots.

  Decision rule: a sample is predicted as class 0 when its class-0 softmax
  probability, rounded to 2 decimals, is greater than 0.9. Otherwise it is
  predicted as class 1. This corresponds to ``Thresh_list`` entry [0.9, 0.1],
  and results are keyed by 0.1. The heatmap labels classes 0/1 as
  'Non Covid' / 'Covid-19'.

  The second softmax column (rounded to 2 decimals) is collected in ``pred``
  as the score for the ROC and precision-recall curves. NOTE(review): this
  assumes the model has exactly 2 output classes.

  Files are written below the PRUNED or UNPRUNED WS directory (chosen by
  ``cfg['USE_PRUNED_GRAPH']``), in
  ``WS_K_<kval>/WS_P_<pval>/seed_<num>/<REGIME>/<measure>/``:

  * a PERFORMANCE_MEASURE CSV;
  * per threshold, actual/predicted label CSVs and a confusion-matrix PDF;
  * tpr/fpr and precision/recall CSVs;
  * ROC and PR-curve PDFs.

  ``bottom_val`` and ``top_val`` only appear in the file names (``X_<..>_Y_<..>``).

  Uses the notebook globals ``kval``, ``pval``, ``seed`` (written into the CSV)
  and ``num`` (used in every path). The model is not switched to eval mode
  here; the caller is expected to have done so.

  Returns:
    (accuracy, sensitivity, specificity) for threshold key 0.1.

  NOTE(review): the rate formulas divide by class/prediction counts and raise
  ZeroDivisionError when, e.g., the loader contains no positive samples.
  """
  from sklearn.metrics import confusion_matrix
  import pandas as pd
  global seed, pval,kval
  nb_classes = 2  # not used below
  if(cfg['USE_PRUNED_GRAPH']):
      %cd /content/drive/MyDrive/PRUNING/PRUNED/WS
    # Train
  else:
    %cd /content/drive/MyDrive/PRUNING/UNPRUNED/WS

  # Initialize the prediction and label lists(tensors)


  if (not os.path.isdir("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/".format(kval,pval,num,cfg['NN']['REGIME'],measure))):
        os.makedirs("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/".format(kval,pval,num,cfg['NN']['REGIME'],measure))

  count = 0  # not used below

  # Each entry is [class-0 probability cut-off, key used for results/folders].
  Thresh_list=[[0.9,0.1]]
  dic={}
  cf={}
  for k in Thresh_list:
    dic[k[1]]={}
    if (not os.path.isdir("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/Thresh_{}/".format(kval,pval,num,cfg['NN']['REGIME'],measure,k[1]))):
        os.makedirs("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/Thresh_{}/".format(kval,pval,num,cfg['NN']['REGIME'],measure,k[1]))


  # Confusion-matrix counters and the per-sample predicted labels for each threshold.
  for k in Thresh_list:
    dic[k[1]]['TN']=0
    dic[k[1]]['TP']=0
    dic[k[1]]['FN']=0
    dic[k[1]]['FP']=0
    cf[k[1]]=[]
  with torch.no_grad():

    prediction_list=[]
    actual_class_list=[]
    one_count=0
    zero_count=0
    # All true labels across batches (despite the name, these come from val_loader).
    y_train=[]
    y_train.clear()

    # Class-1 scores across batches, used for the ROC / PR curves.
    pred=[]
    pred.clear()

    for inputs, classes in val_loader:

        inputs = inputs.to(cfg.DEVICE)
        classes = classes.to(cfg.DEVICE)

        for i,x in enumerate(classes):

          if x.item()==1:
            one_count+=1
          else:
            zero_count+=1
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        probability = torch.nn.functional.softmax(outputs, 1)
        probabilities=probability.to(cfg.DEVICE)
        # NOTE(review): `k` here is the leftover loop variable from the
        # Thresh_list loops, and `predictions` (like `preds`) is never used.
        threshold = torch.tensor(k).unsqueeze(0)
        thresholds=threshold.to(cfg.DEVICE)

        predictions = probabilities >= thresholds
        actual_class_list.clear()

        for i,x in enumerate(classes):
          actual_class_list.append(x.item())
          y_train.append(x.item())

        temp_c=0  # not used below
        for k in Thresh_list:
          #print("Length",len(prediction_list)," ",len(actual_class_list))
          prediction_list.clear()
          access_tensor=0
          for i,x in enumerate(probabilities):
            # Only the first probability (class 0) is inspected: `break` exits after it.
            for t in x:
              if round(t.item(),2) >k[0]:
                prediction_list.append(0)
                cf[k[1]].append(0)
              else:
                prediction_list.append(1)
                cf[k[1]].append(1)
              break

            # Skip the first probability and store the second (class-1) one as the ROC/PR score.
            for t in x:
              if k[0]==0.9:
                if access_tensor==0:
                  access_tensor+=1
                  continue
                else:
                  pred.append(round(t.item(),2))
                  access_tensor=0

         # print(len(prediction_list)," ",len(actual_class_list))
          # Tally this batch's confusion-matrix cells (class 1 is the positive class).
          for i in range(len(prediction_list)):
            #print(i)
            if prediction_list[i] ==0 and actual_class_list[i]==0:
              dic[k[1]]['TN']+=1
            if prediction_list[i] ==0 and actual_class_list[i]==1:
              dic[k[1]]['FN']+=1
            if prediction_list[i] ==1 and actual_class_list[i]==0:
              dic[k[1]]['FP']+=1
            if prediction_list[i] ==1 and actual_class_list[i]==1:
              dic[k[1]]['TP']+=1
          #else:
            # Nothing=0
  #print(dic)
  # Derived rates. The misspelled keys 'Recal/Sensitivity' and 'Specifity' are
  # what the return statement reads; the CSV header is renamed via `labels` below.
  for k in Thresh_list:

    dic[k[1]]['TPR']=dic[k[1]]['TP']/(dic[k[1]]['TP']+dic[k[1]]['FN'])
    dic[k[1]]['TNR']=dic[k[1]]['TN']/(dic[k[1]]['TN']+dic[k[1]]['FP'])
    dic[k[1]]['FPR']=dic[k[1]]['FP']/(dic[k[1]]['TN']+dic[k[1]]['FP'])
    dic[k[1]]['FNR']=dic[k[1]]['FN']/(dic[k[1]]['TP']+dic[k[1]]['FN'])
    dic[k[1]]['Accuracy']=(dic[k[1]]['TP']+dic[k[1]]['TN'])/(dic[k[1]]['TP']+dic[k[1]]['FN']+dic[k[1]]['TN']+dic[k[1]]['FP'])
    dic[k[1]]['Precision']=dic[k[1]]['TP']/(dic[k[1]]['TP']+dic[k[1]]['FP'])
    dic[k[1]]['Recal/Sensitivity']=dic[k[1]]['TP']/(dic[k[1]]['TP']+dic[k[1]]['FN'])
    dic[k[1]]['Specifity']=dic[k[1]]['TN']/(dic[k[1]]['TN']+dic[k[1]]['FP'])
    dic[k[1]]['F1-Score']=(2*dic[k[1]]['Precision']*dic[k[1]]['Recal/Sensitivity'])/(dic[k[1]]['Precision']+dic[k[1]]['Recal/Sensitivity'])



  print(dic)
  print("Zero_Count",zero_count)
  print("One_Count",one_count)
  print(len(y_train)," ",len(pred))
  # Column names are assigned positionally; they must follow the insertion order of the dic keys above.
  labels=['Threshold', 'TN', 'TP', 'FN', 'FP', 'TPR', 'TNR', 'FPR', 'FNR', 'Accuracy', 'Precision', 'Recall/Sensitivity', 'Specificity', 'F1-Score']

  # One row per threshold key (moved into the 'Threshold' column by reset_index).
  # NOTE(review): DataFrame.swapaxes is deprecated in recent pandas; `.T` is the equivalent.
  df=pd.DataFrame(dic)
  data=df.swapaxes("rows", "columns").reset_index()
  data.columns=labels
  data['Seed']=pd.Series([seed]*data.shape[0])
  data['Graph-Model']=pd.Series(['WS']*data.shape[0])
  data['K-value']=pd.Series([kval]*data.shape[0])
  data['P-value']=pd.Series([pval]*data.shape[0])

  data.to_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/PERFORMANCE_MEASURE_X_{}_Y_{}.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure,bottom_val,top_val), encoding='utf-8', sep = '\t', header= True, index = False)


  import numpy as np
  from sklearn import metrics

  import matplotlib
  import matplotlib.pyplot as plt
  from matplotlib.ticker import MaxNLocator
  from sklearn.metrics import confusion_matrix
  import seaborn as sns

  # Embed TrueType fonts in PDF/PS output. NOTE(review): "arial" may not be
  # installed on a Colab runtime, in which case matplotlib falls back with a warning.
  matplotlib.rcParams['pdf.fonttype'] = 42
  matplotlib.rcParams['ps.fonttype'] = 42
  matplotlib.rcParams['font.family'] = "arial"
  plt.rcParams.update({'font.size': 10})

  plt.rcParams["figure.figsize"] = (3,3)

  # Per threshold: save actual and predicted labels, then plot and save the confusion matrix.
  for k in Thresh_list:
    Confusion_matrix=confusion_matrix(y_train, cf[k[1]])
    pd.Series(y_train).to_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/Thresh_{}/Confusion_matrix_act_val_X_{}_Y_{}.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure,k[1],bottom_val,top_val),index=False)
    pd.Series(cf[k[1]]).to_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/Thresh_{}/Confusion_matrix_pred_val_X_{}_Y_{}.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure,k[1],bottom_val,top_val),index=False)


    sns.heatmap(Confusion_matrix, annot=True, annot_kws={"size": 15},fmt="d",cbar=False, linewidths=.5, cmap="Blues",xticklabels=['Non Covid', 'Covid-19'], yticklabels=['Non Covid', 'Covid-19']) # font size
    plt.title('Confusion Matrix')
    plt.savefig("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/Thresh_{}/CONFUSION_MATRIX_X_{}_Y_{}.pdf".format(kval,pval,num,cfg['NN']['REGIME'],measure,k[1],bottom_val,top_val))
    plt.show()
    #print('\n')

  plt.rcParams["figure.figsize"] = (10,10)
  plt.rcParams.update({'font.size': 15})
  plt.rcParams["axes.grid"] = False
  plt.rcParams['axes.facecolor']='white'
  plt.rcParams['savefig.facecolor']='white'
  plt.rcParams['figure.facecolor'] = 'white'
  matplotlib.rc('axes',edgecolor='k')

  ax=plt.figure().gca()
  ax.spines['bottom'].set_linewidth(0.5)
  ax.spines['left'].set_linewidth(0.5)
  ax.tick_params(direction='out', length=0.5, width=2)

  # ROC curve from the collected class-1 scores; tpr/fpr are also saved as CSV.
  y = np.array(y_train)
  prd = np.array(pred)
  fpr, tpr, thresholds = metrics.roc_curve(y, prd)
  pd.DataFrame(tpr).to_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/tpr_X_{}_Y_{}.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure,bottom_val,top_val),index=False)
  pd.DataFrame(fpr).to_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/fpr_X_{}_Y_{}.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure,bottom_val,top_val),index=False)
  RC=metrics.auc(fpr, tpr)
  #print("RC",RC)
  plt.title('Receiver Operating Characteristic')
  plt.plot(fpr, tpr,color='red',label='AUC = %0.4f'% RC)
  plt.legend(loc='lower right')
#  plt.plot([0,0],[0,1],'r--')
  plt.xlim([-0.001, 1.01])
  plt.ylim([-0.01, 1.01])
  plt.ylabel('True Positive Rate')
  plt.xlabel('False Positive Rate')
  plt.savefig("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/ROC_X_{}_Y_{}.pdf".format(kval,pval,num,cfg['NN']['REGIME'],measure,bottom_val,top_val))
  plt.show()

  # Precision-recall curve from the same scores; precision/recall are also saved as CSV.
  precision, recall, thresholds =metrics.precision_recall_curve(y, prd)
  #plt.plot([0, 1],linestyle='--')

  pd.DataFrame(precision).to_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/precision_X_{}_Y_{}.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure,bottom_val,top_val),index=False)
  pd.DataFrame(recall).to_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/recall_X_{}_Y_{}.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure,bottom_val,top_val),index=False)

  from sklearn.metrics import average_precision_score
  average_precision = average_precision_score(y, prd)
  plt.title('Precision Recall Curve')
  plt.plot(recall, precision,color='green',label= 'Average Precision = %0.2f'% average_precision)
  plt.legend(loc='lower left')
  #plt.plot([0,0],[1,0],'r--')
  plt.xlim([-0.001, 1.01])
  plt.ylim([-0.01, 1.01])
  plt.ylabel('Precision')
  plt.xlabel('Recall')
  plt.savefig("./WS_K_{}/WS_P_{}/seed_{}/{}/{}/PR_CURVE_X_{}_Y_{}.pdf".format(kval,pval,num,cfg['NN']['REGIME'],measure,bottom_val,top_val))
  plt.show()
  return dic[0.1]['Accuracy'],dic[0.1]['Recal/Sensitivity'],dic[0.1]['Specifity']

In [None]:
class set_environment():
  def __init__(self):
    self.acc_threshold = 0.8518
  def reset(self,seed):
    # Function to reset graphs to unpruned state
    %cd /content/drive/MyDrive/output/graph/
    conv_2_unpruned = nx.read_graphml('conv2_{}.graphml'.format(seed),node_type=int)
    conv_3_unpruned = nx.read_graphml('conv3_{}.graphml'.format(seed),node_type=int)
    conv_4_unpruned = nx.read_graphml('conv4_{}.graphml'.format(seed),node_type=int)
    conv_5_unpruned = nx.read_graphml('conv5_{}.graphml'.format(seed),node_type=int)
    # Non-iterative pruning, overwrite pruned graphs with unpruned at the start of each episode
    nx.write_graphml_lxml(conv_2_unpruned, 'conv2.graphml')
    nx.write_graphml_lxml(conv_3_unpruned, 'conv3.graphml')
    nx.write_graphml_lxml(conv_4_unpruned, 'conv4.graphml')
    nx.write_graphml_lxml(conv_5_unpruned, 'conv5.graphml')

  def step(self,cfg,seed,train_loader,val_loader,x=0,y=0,measure = None):
    if cfg['USE_PRUNED_GRAPH']:
      %cd /content/drive/MyDrive/PRUNING/PRUNED/WS
    # Train
    else:
      %cd /content/drive/MyDrive/PRUNING/UNPRUNED/WS
    
    model = Net(cfg,measure)
    model.apply(weights_init)

    if torch.cuda.device_count() > 1:
      print("Let's use", torch.cuda.device_count(), "GPUs!")
      model = torch.nn.DataParallel(model)
    model.to(cfg.DEVICE)

    criterion = torch.nn.CrossEntropyLoss().to(cfg.DEVICE)
    optimizer = torch.optim.SGD(model.parameters(),cfg.LEARNING_RATE, cfg.MOMENTUM, cfg.WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.EPOCH)

    if cfg.LOAD_TRAINED_MODEL:
      model.load_state_dict(torch.load(cfg.TRAINED_MODEL_LOAD_DIR))

    if not cfg.TEST_MODE:
      start = time.time()
      for epoch in range(cfg.EPOCH+1):
        #print("Epoch")
        train(train_loader, model, criterion, optimizer, epoch, cfg)
        scheduler.step()
        if epoch % cfg.VAL_FREQ == 0:
          val_loss, acc = validate(val_loader, model, criterion, cfg)

          #print("Accuracy",acc)
          # maxacc=max(maxacc,acc.cpu().detach().numpy())
          if cfg.VISDOM:
            cfg.vis.line(X=torch.Tensor([epoch+1]).unsqueeze(0).cpu(),Y=torch.Tensor([val_loss]).unsqueeze(0).cpu(),env='torch',win=cfg.loss_window,name='val_loss',update='append')
            cfg.vis.line(X=torch.Tensor([epoch+1]).unsqueeze(0).cpu(),Y=torch.Tensor([acc/100]).unsqueeze(0).cpu(),env='torch',win=cfg.loss_window,name='val_acc',update='append')
      end = (time.time() - start)//60
      #print("train time: {}D {}H {}M".format(end//1440, (end%1440)//60, end%60))

    lasavg,top1a=validate(val_loader, model, criterion, cfg)

    prune_accuracy,prune_sensitivity,prune_specificity=performance_measure(val_loader,model,x,y,measure,cfg)

    with torch.cuda.device(0):
      flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=False, print_per_layer_stat=False)
      global net_stats
      net_stats['flops']=flops
      net_stats['params']=params
      df=pd.DataFrame(net_stats.items())
      data=df.swapaxes("rows", "columns")
      labels=['conv2_nodes','conv2_edges','conv3_nodes','conv3_edges','flops','params']
      data['Seed']=pd.Series([seed]*data.shape[0])
      data['Graph-Model']=pd.Series(['WS']*data.shape[0])
      data['K-value']=pd.Series([kval]*data.shape[0])
      data['P-value']=pd.Series([pval]*data.shape[0])

      data.to_csv( "./WS_K_{}/WS_P_{}/seed_{}/{}/{}/NET_STATS_X_{}_Y_{}.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure,x,y), encoding='utf-8', sep = '\t', header= True, index = False)

    #print("flops : ", flops, "    params :", params )

    return prune_accuracy,prune_sensitivity,prune_specificity, flops,params

In [None]:
def calc_unpruned_values(cfg,seed,pval,kval):
  %cd /content/drive/MyDrive/PRUNING/UNPRUNED/WS
  Performance  = pd.read_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/None/PERFORMANCE_MEASURE_X_0_Y_0.csv".format(kval,pval,num,cfg['NN']['REGIME']),sep = '\t')


  baseline_accuracy=float(round(Performance[Performance['Threshold']==0.1]['Accuracy'].copy(),4))
  #print(float(round(baseline_accuracy,4)))
  baseline_sensitivity=float(round(Performance[Performance['Threshold']==0.1]['Recall/Sensitivity'].copy(),4))
  #print(float(round(baseline_sensitivity,4)))
  baseline_specificity=float(round(Performance[Performance['Threshold']==0.1]['Specificity'].copy(),4))
  #print(float(round(baseline_specificity,4)))


  stats= pd.read_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/None/NET_STATS_X_0_Y_0.csv".format(kval,pval,num,cfg['NN']['REGIME'],num,pval),sep = '\t')
  conv2nodes=float(stats.at[1,'0'])
  conv2edges=float(stats.at[1,'1'])
  conv3nodes=float(stats.at[1,'2'])
  conv3edges=float(stats.at[1,'3'])
  flops_value=float(stats.at[1,'4'])
  param_value=float(stats.at[1,'5'])
  return baseline_accuracy,baseline_sensitivity,baseline_specificity,conv2nodes,conv2edges,conv3nodes,conv3edges,flops_value,param_value


In [None]:
import networkx as nx

def surgery(directed_graph, threshold,edgelist):
  """Remove the first ``threshold`` edges of ``edgelist`` from ``directed_graph``.

  ``edgelist`` must already be ordered so that the edges to drop come first;
  prune_graph passes edges sorted by ascending normalised weight. Standard
  slice semantics apply to ``threshold`` (0 removes nothing).

  The graph is modified in place and also returned.
  """
  ordered_edges = list(edgelist)
  doomed = ordered_edges[:threshold]
  directed_graph.remove_edges_from(doomed)
  return directed_graph

In [None]:
import networkx as nx
# Function to return the list of nodes in a graph with no connected edges
def get_isolated_nodes(directed_graph):
  """Return, in node order, the nodes that have no predecessor and no successor."""
  def _is_cut_off(node):
    # Successors are only inspected when there are no predecessors.
    return not list(directed_graph.predecessors(node)) and not list(directed_graph.successors(node))

  return [node for node in list(directed_graph.nodes()) if _is_cut_off(node)]

In [None]:
# Prepare a pruned graph for feeding back into RandWire, i.e. remove the input and output nodes and convert it to an undirected graph

import networkx as nx
import numpy as np

def prepare_new_graph(directed_graph1):
  """Turn a pruned DAG back into an undirected graph that RandWire can rebuild.

  Steps:
    1. Remove the source node (label 0) and the sink node (the highest
       label), i.e. the input/output nodes added by ``un_to_dag``.
    2. Remove every node left with neither incoming nor outgoing edges.
    3. Copy the remaining nodes and edges into an undirected ``nx.Graph``.
    4. Relabel the nodes 0..n-1 following the sorted order of the old labels.

  NOTE: ``directed_graph1`` itself is modified (nodes are removed from it).

  Returns:
    The relabelled undirected graph.
  """
  directed_graph = directed_graph1
  # The sink carries the highest label (len(undirected) + 1 in un_to_dag).
  # max() finds it even if the labels are not contiguous; len() - 1 only
  # works when they are. default=-1 keeps the old result for an empty graph.
  top_node = max(directed_graph.nodes(), default=-1)
  print("top node :",top_node)

  # Remove the input and output nodes even if they are still connected.
  for terminal in (0, top_node):
    if directed_graph.has_node(terminal):
      directed_graph.remove_node(terminal)

  # Removing the terminals can leave nodes without any edge; drop them.
  # Removing isolated nodes deletes no edges, so no new ones can appear.
  directed_graph.remove_nodes_from(get_isolated_nodes(directed_graph))

  new_graph = nx.Graph()
  new_graph.add_nodes_from(directed_graph.nodes())
  new_graph.add_edges_from(directed_graph.edges())
  # Sorted relabelling makes the labels start at 0. A uniform shift of the
  # old labels would not change the result, so none is applied.
  return nx.convert_node_labels_to_integers(new_graph, ordering='sorted')


In [None]:
def prune_graph(measure, seed,x,y,topcount,bottomcount,cfg):
  %cd /content/drive/MyDrive/PRUNING/UNPRUNED/WS/
  global kval,pval

  # weights = {}
  # else:
  conv_2 = (nx.read_graphml('./WS_K_{}/WS_P_{}/seed_{}/{}/conv2u.graphml'.format(kval,pval,seed,cfg['NN']['REGIME']),node_type=int))
  conv_3 = (nx.read_graphml('./WS_K_{}/WS_P_{}/seed_{}/{}/conv3u.graphml'.format(kval,pval,seed,cfg['NN']['REGIME']),node_type=int))
  conv_4 = (nx.read_graphml('./WS_K_{}/WS_P_{}/seed_{}/{}/conv4u.graphml'.format(kval,pval,seed,cfg['NN']['REGIME']),node_type=int))
  conv_5 = (nx.read_graphml('./WS_K_{}/WS_P_{}/seed_{}/{}/conv5u.graphml'.format(kval,pval,seed,cfg['NN']['REGIME']),node_type=int))

  weights_2= nx.get_edge_attributes(conv_2,measure+"_norm")
  w2 = dict(sorted(weights_2.items(), key=lambda i:(i[1],i[0]))).keys()
  threshold2 = int(np.floor(np.percentile(np.arange(len(w2)), y)))
  # print("threshold 2 :",threshold2, "len w2 : ",len(w2))
  # print(w2)

  weights_3= nx.get_edge_attributes(conv_3,measure+"_norm")
  w3 = dict(sorted(weights_3.items(), key=lambda i:(i[1],i[0]))).keys()
  threshold3 = int(np.floor(np.percentile(np.arange(len(w3)), x)))

  weights_4= nx.get_edge_attributes(conv_4,measure+"_norm")
  w4 = dict(sorted(weights_4.items(), key=lambda i:(i[1],i[0]))).keys()
  threshold4 = int(np.floor(np.percentile(np.arange(len(w4)), x)))

  weights_5= nx.get_edge_attributes(conv_5,measure+"_norm")
  w5 = dict(sorted(weights_5.items(), key=lambda i:(i[1],i[0]))).keys()
  threshold5 = int(np.floor(np.percentile(np.arange(len(w5)), x)))

  conv_2_pruned = surgery(conv_2, threshold2,w2)
  # print("conv 2 pruned :",conv_2_pruned.nodes(),conv_2_pruned.edges())
  conv_3_pruned = surgery(conv_3, threshold3,w3)
  # print("conv 3 pruned :",conv_3_pruned.nodes(),conv_3_pruned.edges())
  conv_4_pruned = surgery(conv_4, threshold4,w4)
  conv_5_pruned = surgery(conv_5, threshold5,w5)

  new_conv_2 = prepare_new_graph(conv_2_pruned)
  # print("new conv 2 : ", new_conv_2.nodes(),new_conv_2.edges())
  new_conv_3 = prepare_new_graph(conv_3_pruned)
  # print("new conv 3 : ", new_conv_3.nodes(),new_conv_3.edges())
  new_conv_4 = prepare_new_graph(conv_4_pruned)
  new_conv_5 = prepare_new_graph(conv_5_pruned)

  %cd /content/drive/MyDrive/PRUNING/PRUNED/WS

  nx.write_graphml_lxml(new_conv_2, './WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}/conv2.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],measure))
  nx.write_graphml_lxml(new_conv_3, './WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}/conv3.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],measure))
  nx.write_graphml_lxml(new_conv_4, './WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}/conv4.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],measure))
  nx.write_graphml_lxml(new_conv_5, './WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}/conv5.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],measure))

  nx.write_graphml_lxml(new_conv_2, './WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}/conv2_{}_b{}_t{}.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],measure,measure,bottomcount,topcount))
  nx.write_graphml_lxml(new_conv_3, './WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}/conv3_{}_b{}_t{}.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],measure,measure,bottomcount,topcount))
  nx.write_graphml_lxml(new_conv_4, './WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}/conv4_{}_b{}_t{}.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],measure,measure,bottomcount,topcount))
  nx.write_graphml_lxml(new_conv_5, './WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}/conv5_{}_b{}_t{}.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],measure,measure,bottomcount,topcount))

  print(len(list(new_conv_2.nodes())), len(list(new_conv_3.nodes())), len(list(new_conv_2.edges())), len(list(new_conv_3.edges())), nx.is_connected(new_conv_2), nx.is_connected(new_conv_3) )
  return len(list(new_conv_2.nodes())), len(list(new_conv_3.nodes())), len(list(new_conv_2.edges())), len(list(new_conv_3.edges())), nx.is_connected(new_conv_2), nx.is_connected(new_conv_3)

In [None]:
# This function converts an undirected graph into a DAG by orienting every edge
# from the node with the lower integer id to the node with the higher integer id.


def to_dag(undirected_graph):
  """Return a new ``nx.DiGraph`` whose edges point from the smaller to the larger label.

  Only the edges are copied: nodes without edges and all node/edge
  attributes are dropped. The input graph is only read, never modified.
  """
  oriented_edges = [(min(a, b), max(a, b)) for a, b in undirected_graph.edges()]
  dag = nx.DiGraph()
  dag.add_edges_from(oriented_edges)
  return dag

In [None]:
# This function converts the undirected graph into a DAG by:
# step 1: adding a source node (0) and a sink node (N + 1, where N is the number of nodes), shifting every other node label by +1
# step 2: assigning each edge the direction from the node with the lower integer id to the node with the higher one
# step 3: connecting every node without predecessors to the source node (0)
# step 4: connecting every node without successors to the sink node (N + 1, e.g. 16 or 33 depending on whether the graph belongs to conv2 or to conv3/4/5)

def un_to_dag(undirected_graph):
  """Convert an undirected RandWire graph into a DAG with a single source and sink.

  Assumes the nodes of ``undirected_graph`` are labelled 0..N-1, because they
  are looked up via ``range(N)``. In the returned ``nx.DiGraph``:
    * every original node i becomes node i + 1;
    * every edge points from the smaller to the larger label;
    * node 0 (source) feeds every node that has no predecessor;
    * node N + 1 (sink) is fed by every node that has no successor.
  A node without any undirected edge is wired 0 -> node -> N + 1. Previously
  such a node made ``successors()`` raise NetworkXError.
  """
  directed_graph = nx.DiGraph()
  num_nodes = len(list(undirected_graph))
  sink = num_nodes + 1

  for i in range(num_nodes):
    for n in list(undirected_graph.neighbors(i)):
      if n > i:
        directed_graph.add_edge(i+1, n+1)
      elif i > n:
        directed_graph.add_edge(n+1, i+1)
      # n == i (an undirected self-loop) is skipped on purpose.

  for m in range(1, num_nodes+1):
    # add_node is a no-op for existing nodes and keeps their order; it only
    # matters for nodes with no undirected edges, which were never added.
    directed_graph.add_node(m)
    n_O = len(list(directed_graph.successors(m)))
    n_I = len(list(directed_graph.predecessors(m)))
    if n_O == 0:
      directed_graph.add_edge(m, sink)
    if n_I == 0:
      directed_graph.add_edge(0, m)

  # Defensive clean-up, done once instead of per iteration and without the
  # debug print. Materialise the generator before removing edges, because
  # mutating the graph while iterating it is unsafe.
  directed_graph.remove_edges_from(list(nx.selfloop_edges(directed_graph)))
  return directed_graph

In [None]:
"""
============================================================================================
A program to compute the Ollivier-Ricci curvature of a given directed unweighted graph.

References:
1) E. Saucan, R.P. Sreejith, R.P. Vivek-Ananth, J. Jost & A. Samal, Discrete Ricci curvatures for directed networks, Chaos, Solitons & Fractals, 118: 347-360 (2019).
2) Ni, C.-C., Lin, Y.-Y., Gao, J., Gu, X., & Saucan, E. (2015). Ricci curvature of the Internet topology (Vol. 26, pp. 2758-2766). Presented at the 2015 IEEE Conference on Computer Communications (INFOCOM), IEEE

The following is a modified version of the code that can be found in Chien-Chun Ni's github repository : https://github.com/saibalmars/GraphRicciCurvature

============================================================================================
"""


# Print a 75-character separator line when this cell runs.
print ("="*75)
from multiprocessing import Pool,cpu_count
import sys
import importlib
import time
import cvxpy as cvx
import networkx as nx
import numpy as np
import datetime

# # Display date and time
# now = datetime.datetime.now()
# print ("Current date and time\n%s\n"%(now.strftime("%Y-%m-%d %H:%M:%S")))
# starttime = time.time()

# Opening output files for edge and node
# EF = open(sys.argv[2], 'w')
# NF = open(sys.argv[3], 'w')

#Creating the directed graph from the edge file
# Graph=nx.DiGraph()
# for i in open(sys.argv[1], 'r'):
#   e = i.strip().split('\t')
#   if e[0] != e[1]:
#     Graph.add_edge(e[0], e[1])

# edgesize=Graph.number_of_edges()
# nodesize=Graph.number_of_nodes()
# print ("> Created a graph with (%d) edges and (%d) nodes"%(edgesize, nodesize))

# Function for computing the Ollivier-Ricci curvature of a single directed edge.
#============================================================================================

def ricciCurvature_Edge(G, source, target, length, verbose):
  """Ollivier-Ricci curvature of the directed edge ``source -> target``.

  The mass of ``source`` is spread uniformly over its in-neighbours and
  itself. The mass of ``target`` is spread uniformly over its out-neighbours
  and itself. The optimal-transport cost between the two distributions
  (ground distance ``length[u][v]``) is solved with cvxpy, and the curvature
  is ``1 - cost / length[source][target]``.

  Args:
    G: directed graph.
    source, target: endpoints of the edge.
    length: all-pairs shortest-path lengths, indexed ``length[u][v]``.
    verbose: print solve time and neighbourhood sizes.

  Returns:
    ``{(source, target): curvature}``.
  """
  # Neighbourhoods: predecessors of source, successors of target.
  source_nbr = list(G.predecessors(source))
  target_nbr = list(G.successors(target))

  # Uniform mass over the in-neighbours of source plus source itself.
  if source_nbr:
    x = [1.0/(G.in_degree(source)+1)] * (len(source_nbr) + 1)
  else:
    x = [1]
  source_nbr.append(source)

  # Uniform mass over the out-neighbours of target plus target itself.
  if target_nbr:
    y = [1.0/(G.out_degree(target)+1)] * (len(target_nbr) + 1)
  else:
    y = [1]
  target_nbr.append(target)

  # Cost matrix: shortest-path length from each source-side to each target-side node.
  d = np.zeros((len(x), len(y)))
  for i, s in enumerate(source_nbr):
    for j, t in enumerate(target_nbr):
      assert t in length[s], "Target node not in list, should not happen, pair (%d, %d)" % (s, t)
      d[i][j] = length[s][t]

  # The mass that the source neighbourhood initially owns.
  x = np.array([x]).T
  # The mass that the target neighbourhood needs to receive.
  y = np.array([y]).T

  # t0 used to be an undefined global, so verbose=True raised NameError.
  t0 = time.time()

  # The transportation plan rho.
  rho = cvx.Variable(shape=(len(target_nbr), len(source_nbr)))
  # Objective d(x, y) * rho * x, using an element-wise multiply.
  obj = cvx.Minimize(cvx.sum(cvx.multiply(np.multiply(d.T, x.T), rho)))
  # \sigma_i rho_{ij} = [1, 1, ..., 1]
  source_sum = cvx.sum(rho, axis=0)
  # '@' is cvxpy's matrix product; using '*' for it is deprecated.
  constrains = [rho @ x == y, source_sum == np.ones(len(source_nbr)), 0 <= rho, rho <= 1]
  prob = cvx.Problem(obj, constrains)
  try:
    m = prob.solve(solver='ECOS')
  except cvx.error.SolverError:
    # ECOS is not always installed with recent cvxpy versions (or may fail
    # numerically); fall back to cvxpy's default solver.
    m = prob.solve()

  if verbose:
    print(time.time() - t0, " secs for cvxpy.")

  # Normalise by the length of the edge itself.
  result = 1 - (m / length[source][target])
  if verbose:
    print("#source_nbr: %d, #target_nbr: %d, Ricci curvature = %f" % (len(source_nbr), len(target_nbr), result))
  return {(source, target): result}


def _wrapRicci(stuff):
  """Pool.map adapter: expand one argument tuple into a ricciCurvature_Edge call."""
  positional = tuple(stuff)
  return ricciCurvature_Edge(*positional)


# Function for computing the Ollivier-Ricci curvature of the edges of G (stored as edge attribute 'ORC').
#============================================================================================

def ricciCurvature(G, proc=cpu_count(), edge_list=None, verbose=False):
  """Compute the Ollivier-Ricci curvature of the edges of directed graph ``G``.

  Each edge's curvature is computed in a worker process by
  ``ricciCurvature_Edge`` and stored as ``G[u][v]['ORC']``.

  Args:
    G: directed graph; modified in place.
    proc: number of worker processes (evaluated once, at definition time).
    edge_list: edges to evaluate. ``None`` or empty means every edge of G.
      The default used to be a shared mutable ``[]``; behaviour is unchanged.
    verbose: forwarded to ``ricciCurvature_Edge``.

  Returns:
    ``G``.
  """
  # All-pairs shortest paths; weight=None makes every edge count as 1.
  length = dict(nx.all_pairs_dijkstra_path_length(G, weight=None))

  # If no edges are given, compute all edges instead.
  if not edge_list:
    edge_list = G.edges()
  args = [(G, source, target, length, verbose) for source, target in edge_list]

  # The context manager tears the workers down even if a task raises.
  with Pool(processes=proc) as pool:
    results = pool.map(_wrapRicci, args)

  # Assign the edge Ricci curvature from the results to graph G.
  for rc in results:
    for (source, target), curvature in rc.items():
      G[source][target]['ORC'] = curvature

  # The per-node in/out curvature sums that used to be computed here were
  # never stored or returned, so that dead loop was removed. The message
  # below is kept for output compatibility.
  print ("\nOlliver-Ricci curvature for node")
  return G

#============================================================================================
# Calling the main function to compute the Ollivier-Ricci curvature of the graph

# print ("\nComputation for Olliver-Ricci curvature for edge started")
# Elist=Graph.edges()# Edge list of the graph
# ricciCurvature(Graph, proc=10, edge_list=[], verbose=False)

In [None]:
def FR_Dir(G):
  """Store the directed Forman-Ricci curvature of every edge as ``G[u][v]['FR']``.

  For an edge u -> v the value is 2 - (number of edges entering u)
  - (number of edges leaving v). ``G`` is modified in place and returned.
  """
  for u, v in G.edges():
    entering_u = sum(1 for _ in G.in_edges(u))
    leaving_v = sum(1 for _ in G.out_edges(v))
    G[u][v]['FR'] = 2 - entering_u - leaving_v
  return G

In [None]:
import networkx as nx
def calc_edge_measures(G):
  """Orient ``G`` into a DAG and annotate every edge with three scores.

  Edge attributes set on the new DAG: 'ORC' (Ollivier-Ricci curvature),
  'FR' (directed Forman-Ricci curvature) and 'EBC' (edge betweenness
  centrality). Returns the DAG built by ``to_dag``, not ``G`` itself.
  """
  dag = FR_Dir(ricciCurvature(to_dag(G)))
  betweenness = nx.edge_betweenness_centrality(dag)
  nx.set_edge_attributes(dag, betweenness, 'EBC')
  return dag

In [None]:
def normalize_weights(G):
  """Add a rescaled copy ``<measure>_norm`` of each edge measure, mapped into (0, 1).

  Measures handled: 'FR', 'AFR', 'ORC', 'EBC'. The min and max are taken over
  the edges that carry the measure:
    * EBC:    (x - min + 1) / (max - min + 2), so larger EBC gives a larger score;
    * others: (max - x + 1) / (max - min + 2), so larger curvature gives a smaller score.
  prune_graph removes the lowest-scoring edges first.

  A measure that no edge carries is skipped; np.min on the empty list would
  raise ValueError. This applies to 'AFR', which calc_edge_measures does not
  compute. Edges lacking a measure get no ``_norm`` value for it.
  ``G`` is modified in place and returned.
  """
  for measure in ['FR','AFR','ORC','EBC']:
    # Live attribute dicts of the edges that carry this measure.
    carriers = [data for _, _, data in G.edges(data=True) if measure in data]
    if not carriers:
      continue
    values = [data[measure] for data in carriers]
    minval = np.min(values)
    maxval = np.max(values)
    span = maxval - minval + 2
    for data in carriers:
      x = data[measure]
      if measure == 'EBC':
        data[measure+"_norm"] = (x-minval+1)/span
      else:
        data[measure+"_norm"] = (maxval-x+1)/span
  return G

In [None]:
def normalized_edge_measures_dir(cfg,seed):

  %cd /content/drive/MyDrive/PRUNING/UNPRUNED/WS/
  global kval,pval

  conv_2 = un_to_dag(nx.read_graphml('./WS_K_{}/WS_P_{}/seed_{}/{}/conv2_{}.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],cfg['RND_SEED']),node_type=int))
  conv_3 = un_to_dag(nx.read_graphml('./WS_K_{}/WS_P_{}/seed_{}/{}/conv3_{}.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],cfg['RND_SEED']),node_type=int))
  conv_4 = un_to_dag(nx.read_graphml('./WS_K_{}/WS_P_{}/seed_{}/{}/conv4_{}.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],cfg['RND_SEED']),node_type=int))
  conv_5 = un_to_dag(nx.read_graphml('./WS_K_{}/WS_P_{}/seed_{}/{}/conv5_{}.graphml'.format(kval,pval,seed,cfg['NN']['REGIME'],cfg['RND_SEED']),node_type=int))


  conv_2 = calc_edge_measures(conv_2).copy()
  conv_3=  calc_edge_measures(conv_3).copy()
  conv_4 = calc_edge_measures(conv_4).copy()
  conv_5 = calc_edge_measures(conv_5).copy()

  conv_2_ = normalize_weights(conv_2)
  conv_3_ = normalize_weights(conv_3)
  conv_4_ = normalize_weights(conv_4)
  conv_5_ = normalize_weights(conv_5)


  nx.write_graphml(conv_2_, "./WS_K_{}/WS_P_{}/seed_{}/{}/conv2u.graphml".format(kval,pval,seed,cfg['NN']['REGIME']))
  nx.write_graphml(conv_3_, "./WS_K_{}/WS_P_{}/seed_{}/{}/conv3u.graphml".format(kval,pval,seed,cfg['NN']['REGIME']))
  nx.write_graphml(conv_4_, "./WS_K_{}/WS_P_{}/seed_{}/{}/conv4u.graphml".format(kval,pval,seed,cfg['NN']['REGIME']))
  nx.write_graphml(conv_5_, "./WS_K_{}/WS_P_{}/seed_{}/{}/conv5u.graphml".format(kval,pval,seed,cfg['NN']['REGIME']))


In [None]:
def Edgeparams_BT(cfg, seed,train_loader,val_loader):

  %cd /content/drive/MyDrive/PRUNING/PRUNED/WS

  global measures,pval,kval

  baseline_acc,baseline_sensitivity,baseline_specificity,conv2nodes,conv2edges,conv3nodes,conv3edges,flops_value,param_value = calc_unpruned_values(cfg,seed,pval,kval)
  #baseline_acc = 0.8085
  normalized_edge_measures_dir(cfg, seed)
  # prune_graph("FR",14, 50,1)
  df1 = pd.DataFrame(columns=["randseed","measure","bottom_depth",'top_depth', "xvalue","yvalue", "nodes2","edges2","nodes3","edges3","act_nodes_2","act_edges2","act_nodes_3","act_edges3","connectivity_conv2","connectivity_conv3","pruning_accuracy","baseline_acc","pruning_sensitivity","baseline_sensitivity","pruning_specificity","baseline_specificity","pruning_flops","baseline_flops","pruning_parameters","baseline_parameters","%edges","%flops","%parameters"])
  i = 0
  %cd /content/drive/MyDrive/PRUNING/PRUNED/WS
  for measure in measures:
    if (not os.path.isdir("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}".format(kval,pval,num,cfg['NN']['REGIME'],measure))):
        os.makedirs("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}".format(kval,pval,num,cfg['NN']['REGIME'],measure))
    print("MEASURE : ",measure)
    #count = 1

    top_count = 1
    best_acc = baseline_acc
     #To run from the last breaking point
    if os.path.exists("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}_torun_x.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure)):
      with open("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}_torun_x.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure)) as f:
        lis = [line.split() for line in f]
        print(lis)
        xmin=float(lis[0][0])
        xmax=float(lis[1][0])
        x=float(lis[2][0])
        best_x=float(lis[3][0])
        bottom_count=int(lis[4][0])
        i=int(lis[5][0])
        f.close()
    else:
      xmin = 0
      xmax = 100
      x = (xmin + xmax)/2
      # prev_acc= baseline_acc
      best_x = 0
      bottom_count = 1

    if os.path.exists("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}_stats.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure)):
      df1=pd.read_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}_stats.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure))
      print("Inside if")
      print(df1)

    best_x = 0

    y = 0
    while (bottom_count <= 5):
      print("measure: ",measure, " topcount : ", top_count, " bottomcount : ", bottom_count," ; x value : ", x)
      nodes2, nodes3, edges2, edges3, connectivity_conv2, connectivity_conv3 = prune_graph(measure,seed, x,y,top_count,bottom_count,cfg)

      top1s,sensitivity,specificity, flops,param = env.step(cfg,seed,train_loader,val_loader ,x,y, measure)

      if round(top1s,4) >= baseline_acc and round(sensitivity,4)>=baseline_sensitivity and round(specificity,4)>=baseline_specificity:
        if (x > best_x):
            best_x = x
      print("best_x : ", best_x)

      df1 = df1.append(pd.DataFrame([[seed,measure,bottom_count,top_count,x,y,nodes2,edges2,nodes3,edges3,conv2nodes,conv2edges,conv3nodes, conv3edges,connectivity_conv2,connectivity_conv3,top1s,baseline_acc,sensitivity,baseline_sensitivity,specificity,baseline_specificity,flops_value,flops,param,param_value,((edges2+(3*edges3))/(conv2edges + (3*conv3edges)))*100, (flops/flops_value)*100,(param/param_value)*100]],columns=list(df1.columns)),ignore_index=True)
      if (bottom_count < 5):
          
        if round(top1s,4) >= baseline_acc and round(sensitivity,4)>=baseline_sensitivity and round(specificity,4)>=baseline_specificity:
          xmin = x
          xmax = xmax
          x = (xmin + xmax)/2
        else:
          xmin = xmin
          xmax = x
          x = (xmin + xmax)/2

      bottom_count += 1
      i += 1
      #writing the last updated value
      nex_run=[[xmin],[xmax],[x],[best_x],[bottom_count],[i]]
      nex_file = open("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}_torun_x.csv".format(kval,pval,num,cfg['NN']['REGIME'],measure), 'w', newline ='')
      with nex_file:
        write = csv.writer(nex_file)
        write.writerows(nex_run)
      nex_file.close()
      df1.to_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}_stats.csv".format(kval,pval,seed,cfg['NN']['REGIME'],measure), header = True, index = False)
    best_val=np.array([best_x,y])

    Best_DF = pd.DataFrame(best_val)
    Best_DF.to_csv("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/{}_Best.csv".format(kval,pval,seed,cfg['NN']['REGIME'],measure), header = True, index = False)


In [None]:
# These files hold whole pickled Python objects (presumably the DataLoaders
# passed to env.step), not plain tensors. torch >= 2.6 defaults to
# weights_only=True, which refuses to unpickle them, so opt out explicitly
# (the keyword exists since torch 1.13). Only do this for files you created.
train_loader=torch.load('/content/drive/MyDrive/Outputs/train_data.pth', weights_only=False)
val_loader=torch.load('/content/drive/MyDrive/Outputs/val_data.pth', weights_only=False)

In [None]:
## Training

import numpy as np
from datetime import datetime
import os
%cd /content/drive/MyDrive/
randseeds = [3, 16, 34, 57, 59, 61, 66, 72, 92, 97]
now = datetime.now()
current_time = now.strftime("%H:%M:%S")
print("Start  Time =", current_time)
net_stats={}
env = set_environment()
cfg = get_configuration()
%cd /content/drive/MyDrive/PRUNING/UNPRUNED/WS
cfg["USE_PRUNED_GRAPH"] = False
cfg["MAKE_GRAPH"] = True
cfg["EPOCH"] = 100 
cfg["GRAPH_MODEL"] = "WS"  # 'ER', 'BA'
cfg["WS_K"] = 4
cfg["WS_P"]=0.75
pval = cfg['WS_P']
kval = cfg['WS_K']
cfg['NN']['REGIME'] = 'SMALL'
prepare(cfg)
seed = 0
# count=2
if (cfg['MAKE_GRAPH']):
  global seed
  for num in randseeds:
    #print("Seed :",num)
    global net_stats
    net_stats={}
    cfg['RND_SEED'] = num
    seed = num

    if not os.path.isdir("/content/drive/MyDrive/PRUNING/UNPRUNED/WS/WS_K_{}/WS_P_{}/".format(kval,pval)):
      os.makedirs("/content/drive/MyDrive/PRUNING/UNPRUNED/WS/WS_K_{}/WS_P_{}/".format(kval,pval))

    if not os.path.isdir("/content/drive/MyDrive/PRUNING/UNPRUNED/WS/WS_K_{}/WS_P_{}/seed_{}".format(kval,pval,num)):
      os.makedirs("/content/drive/MyDrive/PRUNING/UNPRUNED/WS/WS_K_{}/WS_P_{}/seed_{}".format(kval,pval,num))

    if not os.path.isdir("/content/drive/MyDrive/PRUNING/UNPRUNED/WS/WS_K_{}/WS_P_{}/seed_{}/{}".format(kval,pval,num,cfg['NN']['REGIME'])):
      os.makedirs("/content/drive/MyDrive/PRUNING/UNPRUNED/WS/WS_K_{}/WS_P_{}/seed_{}/{}".format(kval,pval,num,cfg['NN']['REGIME']))

    top_1s,sensetivity,specificity, flops, params = env.step(cfg,num,train_loader,val_loader)

now = datetime.now()
current_time = now.strftime("%H:%M:%S")
print("End Time =", current_time)


In [None]:
## Pruning

import numpy as np
from datetime import datetime
import os
%cd /content/drive/MyDrive/
randseeds = [3, 16, 34, 57, 59, 61, 66, 72, 92, 97]
measures= ['FR'] #'ORC', 'EBC'
now = datetime.now()
current_time = now.strftime("%H:%M:%S")
print("Start  Time =", current_time)
net_stats={}
env = set_environment()
cfg = get_configuration()
%cd /content/drive/MyDrive/PRUNING/PRUNED/WS
cfg["USE_PRUNED_GRAPH"] = True
cfg["MAKE_GRAPH"] = False
cfg["EPOCH"] = 2
cfg["GRAPH_MODEL"] = "WS"
cfg["WS_K"] = 4
cfg["WS_P"]=0.75
pval = cfg['WS_P']
kval = cfg['WS_K']
cfg['NN']['REGIME'] = 'SMALL'
prepare(cfg)
seed = 0
# count=2

if not cfg["MAKE_GRAPH"]:
  for num in randseeds:
    seed=num
    print("RANDOM SEED : ", num)
    if not os.path.isdir("/content/drive/MyDrive/PRUNING/PRUNED/WS/WS_K_{}/".format(cfg['WS_K'])):
      os.makedirs("/content/drive/MyDrive/PRUNING/PRUNED/WS/WS_K_{}/".format(cfg['WS_K']))

    if not os.path.isdir("/content/drive/MyDrive/PRUNING/PRUNED/WS/WS_K_{}/WS_P_{}/".format(kval,pval)):
      os.makedirs("/content/drive/MyDrive/PRUNING/PRUNED/WS/WS_K_{}/WS_P_{}/".format(kval,pval))

    if (not os.path.isdir("./WS_K_{}/WS_P_{}/seed_{}".format(kval,pval,num))):
        os.makedirs("./WS_K_{}/WS_P_{}/seed_{}".format(kval,pval,num))

    if (not os.path.isdir("./WS_K_{}/WS_P_{}/seed_{}/{}".format(kval,pval,num,cfg['NN']['REGIME']))):
        os.makedirs("./WS_K_{}/WS_P_{}/seed_{}/{}".format(kval,pval,num,cfg['NN']['REGIME']))

    if (not os.path.isdir("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/".format(kval,pval,num,cfg['NN']['REGIME']))):
        os.makedirs("./WS_K_{}/WS_P_{}/seed_{}/{}/output_bt/".format(kval,pval,num,cfg['NN']['REGIME']))

    cfg['RND_SEED'] = num
    Edgeparams_BT(cfg, num,train_loader,val_loader)
    print("Random seed: ",num," over")

now = datetime.now()
current_time = now.strftime("%H:%M:%S")
print("End Time =", current_time)
